-rw-r--r-- 8888 cryptattacktester-20231020/aes128_enum.cpp raw
#include <cassert>
#include "bit_vector.h"
#include "aes128_enum.h"
using namespace std;
// One byte represented as 8 circuit bits; index 0 is the x^0 coefficient
// (least-significant bit), matching xtime() and the round constant.
// NOTE(review): with `using namespace std` under C++17, std::byte can make
// unqualified `byte` ambiguous if <cstddef> is visible -- confirm build -std.
using byte = vector<bit>;
static byte byte_xor(byte c,byte d)
{
  byte result;
  for (bigint i = 0;i < 8;++i)
    result.push_back(c.at(i)^d.at(i));
  return result;
}
// The GF(2^8) element x, i.e. the byte 0x02, stored least-significant bit
// first (index 1 set).
// NOTE(review): not referenced anywhere in this file; candidate for removal.
static vector<bit> two{0,1,0,0,0,0,0,0};
static byte xtime(byte c)
{
  bit c0 = c.at(0);
  bit c1 = c.at(1);
  bit c2 = c.at(2);
  bit c3 = c.at(3);
  bit c4 = c.at(4);
  bit c5 = c.at(5);
  bit c6 = c.at(6);
  bit c7 = c.at(7);
  bit h0 = c7;
  bit h1 = c0^c7;
  bit h2 = c1;
  bit h3 = c2^c7;
  bit h4 = c3^c7;
  bit h5 = c4;
  bit h6 = c5;
  bit h7 = c6;
  
  byte result;
  result.push_back(h0);
  result.push_back(h1);
  result.push_back(h2);
  result.push_back(h3);
  result.push_back(h4);
  result.push_back(h5);
  result.push_back(h6);
  result.push_back(h7);
  return result;
}
// AES S-box (SubBytes) applied to one byte, as a fixed straight-line circuit
// of XOR, AND (32 gates) and XNOR operations on bits.
//
// c : one byte, c[0] being the least-significant bit.
// Returns the substituted byte in the same bit order.
//
// The naming (inputs U0..U7 with U0 = most-significant bit; top linear layer
// y*; middle section t*; AND layer z*; bottom linear layer tc* / S*) appears
// to follow the Boyar-Peralta depth-16 S-box circuit.
// NOTE(review): confirm against the published gate list before editing --
// every gate is load-bearing and one wrong operand silently breaks the S-box.
static byte byte_sub(byte c)
{
  // Rename input bits so that U0 is the most-significant bit.
  bit U0 = c.at(7);
  bit U1 = c.at(6);
  bit U2 = c.at(5);
  bit U3 = c.at(4);
  bit U4 = c.at(3);
  bit U5 = c.at(2);
  bit U6 = c.at(1);
  bit U7 = c.at(0);
  // Top linear layer (XOR only).
  bit y14 = U3 ^ U5;
  bit y13 = U0 ^ U6;
  bit y9 = U0 ^ U3;
  bit y8 = U0 ^ U5;
  bit t0 = U1 ^ U2;
  bit y1 = t0 ^ U7;
  bit y4 = y1 ^ U3;
  bit y12 = y13 ^ y14;
  bit y2 = y1 ^ U0;
  bit y5 = y1 ^ U6;
  bit y3 = y5 ^ y8;
  bit t1 = U4 ^ y12;
  bit y15 = t1 ^ U5;
  bit y20 = t1 ^ U1;
  bit y6 = y15 ^ U7;
  bit y10 = y15 ^ t0;
  bit y11 = y20 ^ y9;
  bit y7 = U7 ^ y11;
  bit y17 = y10 ^ y11;
  bit y19 = y10 ^ y8;
  bit y16 = t0 ^ y11;
  bit y21 = y13 ^ y16;
  bit y18 = U0 ^ y16;
  // Middle (nonlinear) section: ANDs and XORs over the y values.
  bit t2 = y12 & y15;
  bit t3 = y3 & y6;
  bit t4 = t3 ^ t2;
  bit t5 = y4 & U7;
  bit t6 = t5 ^ t2;
  bit t7 = y13 & y16;
  bit t8 = y5 & y1;
  bit t9 = t8 ^ t7;
  bit t10 = y2 & y7;
  bit t11 = t10 ^ t7;
  bit t12 = y9 & y11;
  bit t13 = y14 & y17;
  bit t14 = t13 ^ t12;
  bit t15 = y8 & y10;
  bit t16 = t15 ^ t12;
  bit t17 = t4 ^ y20;
  bit t18 = t6 ^ t16;
  bit t19 = t9 ^ t14;
  bit t20 = t11 ^ t16;
  bit t21 = t17 ^ t14;
  bit t22 = t18 ^ y19;
  bit t23 = t19 ^ y21;
  bit t24 = t20 ^ y18;
  bit t25 = t21 ^ t22;
  bit t26 = t21 & t23;
  bit t27 = t24 ^ t26;
  bit t28 = t25 & t27;
  bit t29 = t28 ^ t22;
  bit t30 = t23 ^ t24;
  bit t31 = t22 ^ t26;
  bit t32 = t31 & t30;
  bit t33 = t32 ^ t24;
  bit t34 = t23 ^ t33;
  bit t35 = t27 ^ t33;
  bit t36 = t24 & t35;
  bit t37 = t36 ^ t34;
  bit t38 = t27 ^ t36;
  bit t39 = t29 & t38;
  bit t40 = t25 ^ t39;
  bit t41 = t40 ^ t37;
  bit t42 = t29 ^ t33;
  bit t43 = t29 ^ t40;
  bit t44 = t33 ^ t37;
  bit t45 = t42 ^ t41;
  // AND layer: each z multiplies a middle-section value by a top-layer value.
  bit z0 = t44 & y15;
  bit z1 = t37 & y6;
  bit z2 = t33 & U7;
  bit z3 = t43 & y16;
  bit z4 = t40 & y1;
  bit z5 = t29 & y7;
  bit z6 = t42 & y11;
  bit z7 = t45 & y17;
  bit z8 = t41 & y10;
  bit z9 = t44 & y12;
  bit z10 = t37 & y3;
  bit z11 = t33 & y4;
  bit z12 = t43 & y13;
  bit z13 = t40 & y5;
  bit z14 = t29 & y2;
  bit z15 = t42 & y9;
  bit z16 = t45 & y14;
  bit z17 = t41 & y8;
  // Bottom linear layer (XOR / XNOR). The four XNOR outputs S7, S6, S2, S1
  // land in output bits 0, 1, 5, 6 -- exactly the set bits of the AES affine
  // constant 0x63.
  bit tc1 = z15 ^ z16;
  bit tc2 = z10 ^ tc1;
  bit tc3 = z9 ^ tc2;
  bit tc4 = z0 ^ z2;
  bit tc5 = z1 ^ z0;
  bit tc6 = z3 ^ z4;
  bit tc7 = z12 ^ tc4;
  bit tc8 = z7 ^ tc6;
  bit tc9 = z8 ^ tc7;
  bit tc10 = tc8 ^ tc9;
  bit tc11 = tc6 ^ tc5;
  bit tc12 = z3 ^ z5;
  bit tc13 = z13 ^ tc1;
  bit tc14 = tc4 ^ tc12;
  bit S3 = tc3 ^ tc11;
  bit tc16 = z6 ^ tc8;
  bit tc17 = z14 ^ tc10;
  bit tc18 = tc13 ^ tc14;
  bit S7 = z12.xnor(tc18);
  bit tc20 = z15 ^ tc16;
  bit tc21 = tc2 ^ z11;
  bit S0 = tc3 ^ tc16;
  bit S6 = tc10.xnor(tc18);
  bit S4 = tc14 ^ S3;
  bit S1 = S3.xnor(tc16);
  bit tc26 = tc17 ^ tc20;
  bit S2 = tc26.xnor(z17);
  bit S5 = tc21 ^ tc17;
  // Reassemble least-significant bit first (S7 is bit 0, S0 is bit 7).
  byte result;
  result.push_back(S7);
  result.push_back(S6);
  result.push_back(S5);
  result.push_back(S4);
  result.push_back(S3);
  result.push_back(S2);
  result.push_back(S1);
  result.push_back(S0);
  return result;
}
// First AES key-schedule round constant, Rcon = 0x01, stored least-significant
// bit first. encrypt() doubles it with xtime() after each use.
static vector<bit> initialroundconstant{1,0,0,0,0,0,0,0};
// Encrypt one 128-bit block with AES-128, expressed as a bit circuit.
//
// in : plaintext bits; only in[0..127] are read.
// k  : key bits; only k[0..127] are read.
// Returns the 128 ciphertext bits.
//
// Byte n of a 128-bit string occupies bits [8n, 8n+8). Within a byte, bit 0
// is the x^0 coefficient, the convention used by xtime() and the round
// constant. Bytes map to the 4x4 state column by column: byte 4*j+i is
// row i, column j. The output is serialized the same way.
static vector<bit> encrypt(const vector<bit> &in,const vector<bit> &k)
{
  // Key schedule: 44 four-byte words, stored as columns of a 4-row array.
  vector<vector<byte>> expanded(4,vector<byte> (44));
  vector<vector<byte>> state(4,vector<byte> (4));
  // Scratch holding the SubBytes output before ShiftRows permutes it.
  vector<vector<byte>> newstate(4,vector<byte> (4));
  byte roundconstant;
  bigint i;
  bigint j;
  bigint r;
  bigint bitpos;
  // Key words 0..3 are the key itself.
  for (j = 0;j < 4;++j)
    for (i = 0;i < 4;++i)
      for (bitpos = 0;bitpos < 8;++bitpos)
        expanded.at(i).at(j).push_back(k.at((j*4+i)*8+bitpos));
  roundconstant = initialroundconstant;
  for (j = 4;j < 44;++j) {
    vector<byte> temp(4);
    // Ordinary word: temp is the previous word unchanged.
    // Every fourth word: temp is the previous word rotated up by one byte,
    // passed through the S-box byte by byte, with the round constant XORed
    // into byte 0. The round constant is then doubled for the next use.
    if (j % 4)
      for (i = 0;i < 4;++i) temp.at(i) = expanded.at(i).at(j - 1);
    else {
      for (i = 0;i < 4;++i) temp.at(i) = byte_sub(expanded.at((i + 1) % 4).at(j - 1));
      temp.at(0) = byte_xor(temp.at(0),roundconstant);
      roundconstant = xtime(roundconstant);
    }
    // w[j] = temp XOR w[j-4].
    for (i = 0;i < 4;++i)
      expanded.at(i).at(j) = byte_xor(temp.at(i),expanded.at(i).at(j - 4));
  }
  // Load the plaintext into the state (column-major, as above).
  for (j = 0;j < 4;++j)
    for (i = 0;i < 4;++i)
      for (bitpos = 0;bitpos < 8;++bitpos)
        state.at(i).at(j).push_back(in.at((j*4+i)*8+bitpos));
  // Initial AddRoundKey with key words 0..3.
  for (j = 0;j < 4;++j)
    for (i = 0;i < 4;++i)
      state.at(i).at(j) = byte_xor(state.at(i).at(j),expanded.at(i).at(j));
  // Ten rounds: SubBytes, ShiftRows, MixColumns (skipped in the last round),
  // then AddRoundKey with key words 4r+4 .. 4r+7.
  for (r = 0;r < 10;++r) {
    // SubBytes into the scratch array.
    for (i = 0;i < 4;++i)
      for (j = 0;j < 4;++j)
        newstate.at(i).at(j) = byte_sub(state.at(i).at(j));
    // ShiftRows: row i is rotated left by i positions.
    for (i = 0;i < 4;++i)
      for (j = 0;j < 4;++j)
        state.at(i).at(j) = newstate.at(i).at((j + i) % 4);
    // MixColumns, rounds 0..8 only. For column (a0,a1,a2,a3), each output is
    // 2*a_i ^ 3*a_{i+1} ^ a_{i+2} ^ a_{i+3} (indices mod 4), computed as
    // xtime(a_i ^ a_{i+1}) ^ a_{i+1} ^ (a_{i+2} ^ a_{i+3}).
    if (r < 9)
      for (j = 0;j < 4;++j) {
        byte a0 = state.at(0).at(j);
        byte a1 = state.at(1).at(j);
        byte a2 = state.at(2).at(j);
        byte a3 = state.at(3).at(j);
        byte a01 = byte_xor(a0,a1);
        byte a12 = byte_xor(a1,a2);
        byte a23 = byte_xor(a2,a3);
        byte a30 = byte_xor(a3,a0);
        state.at(0).at(j) = byte_xor(xtime(a01),byte_xor(a1,a23));
        state.at(1).at(j) = byte_xor(xtime(a12),byte_xor(a2,a30));
        state.at(2).at(j) = byte_xor(xtime(a23),byte_xor(a3,a01));
        state.at(3).at(j) = byte_xor(xtime(a30),byte_xor(a0,a12));
      }
    // AddRoundKey with this round's four key words.
    for (i = 0;i < 4;++i)
      for (j = 0;j < 4;++j)
        state.at(i).at(j) = byte_xor(state.at(i).at(j),expanded.at(i).at(r * 4 + 4 + j));
  }
  // Serialize the state column by column, in the same layout as the input.
  vector<bit> result;
  for (j = 0;j < 4;++j)
    for (i = 0;i < 4;++i)
      for (bitpos = 0;bitpos < 8;++bitpos)
        result.push_back(state.at(i).at(j).at(bitpos));
  return result;
}
vector<bit> aes128_enum(
  const vector<bit> &bits,
  const vector<bigint> ¶ms,
  const vector<bigint> &attackparams
)
{
  bigint K = params.at(0);
  bigint C = params.at(1);
  bigint pos = 0;
  bigint I = attackparams.at(pos++);
  bigint QX = attackparams.at(pos++);
  bigint QUEUE_SIZE = attackparams.at(pos++);
  bigint QF = attackparams.at(pos++); auto PERIOD = QF*QUEUE_SIZE;
  vector<bit> plaintext0(128);
  vector<bit> ciphertext0(C);
  vector<bit> plaintext1(128);
  vector<bit> ciphertext1(C);
  {
    bigint pos = 0;
    for (bigint j = 0;j < 128;++j) plaintext0.at(j) = bits.at(pos++);
    for (bigint j = 0;j < C;++j) ciphertext0.at(j) = bits.at(pos++);
    for (bigint j = 0;j < 128;++j) plaintext1.at(j) = bits.at(pos++);
    for (bigint j = 0;j < C;++j) ciphertext1.at(j) = bits.at(pos++);
    assert(pos == bits.size());
  }
  vector<bit> result(K,bit(1));
  // note that queue is not used if QX == 0
  vector<bit> queue_valid(QUEUE_SIZE);
  vector<vector<bit>> queue(QUEUE_SIZE,vector<bit>(K));
  bigint timer = 0;
  vector<bit> guess(128);
  for (bigint iter = 0;iter < I;++iter) {
    auto guessct0 = encrypt(plaintext0,guess);
    bit mismatch;
    for (bigint j = 0;j < C;++j)
      mismatch |= guessct0.at(j)^ciphertext0.at(j);
    if (QX == 0) {
      auto guessct1 = encrypt(plaintext1,guess);
      for (bigint j = 0;j < C;++j)
        mismatch |= guessct1.at(j)^ciphertext1.at(j);
      for (bigint j = 0;j < K;++j)
        result.at(j) = mismatch.mux(guess.at(j),result.at(j));
    } else {
      vector<bit> guessprefix(K);
      for (bigint j = 0;j < K;++j)
        guessprefix.at(j) = guess.at(j);
      bit match = ~mismatch;
      bit_queue1_insert(queue_valid,match);
      bit_vector_queue_insert(queue,guessprefix,match);
      ++timer;
      if (timer == PERIOD || iter == I-1) {
        timer = 0;
        for (bigint q = 0;q < QUEUE_SIZE;++q) {
          vector<bit> queueguess(128);
          for (bigint j = 0;j < K;++j)
            queueguess.at(j) = queue.at(q).at(j);
          auto guessct1 = encrypt(plaintext1,queueguess);
          bit mismatch1 = ~queue_valid.at(q);
          for (bigint j = 0;j < C;++j)
            mismatch1 |= guessct1.at(j)^ciphertext1.at(j);
          for (bigint j = 0;j < K;++j)
            result.at(j) = mismatch1.mux(queueguess.at(j),result.at(j));
          queue_valid.at(q) = bit(0);
        }
      }
    }
    bit incrementing(1);
    for (bigint j = 0;j < K;++j) {
      bit old = guess.at(j);
      guess.at(j) = old^incrementing;
      incrementing &= old;
    }
  }
  
  return result;
}