-rw-r--r-- 8888 cryptattacktester-20230614/aes128_enum.cpp raw
#include <cassert>
#include "bit_vector.h"
#include "aes128_enum.h"
using namespace std;
// A "byte" is a vector of bits; the helpers in this file access indices 0..7.
typedef vector<bit> byte;
// Position-wise XOR of two bytes: output index n holds c[n] ^ d[n] for
// n = 0..7. Both inputs must have at least 8 entries (vector::at throws
// otherwise).
static byte byte_xor(byte c,byte d)
{
  byte sum;
  for (int n = 0;n != 8;++n)
    sum.push_back(c.at(n)^d.at(n));
  return sum;
}
// Eight-bit constant with only index 1 set. Under the index convention used
// by xtime (index n = coefficient of x^n) this is the element x.
// Not referenced anywhere else in the visible source.
static vector<bit> two{0,1,0,0,0,0,0,0};
// Shift a byte up by one bit position and fold the bit that falls off the
// top (index 7) back into positions 0, 1, 3 and 4. With index n read as the
// coefficient of x^n, this is multiplication by x modulo x^8+x^4+x^3+x+1.
static byte xtime(byte c)
{
  // Read the eight input bits in ascending index order.
  vector<bit> in;
  for (int n = 0;n < 8;++n)
    in.push_back(c.at(n));

  // The three feedback XORs, computed before the output is assembled.
  bit overflow = in.at(7);
  bit fold1 = in.at(0)^overflow;
  bit fold3 = in.at(2)^overflow;
  bit fold4 = in.at(3)^overflow;

  byte out;
  out.push_back(overflow);
  out.push_back(fold1);
  out.push_back(in.at(1));
  out.push_back(fold3);
  out.push_back(fold4);
  out.push_back(in.at(4));
  out.push_back(in.at(5));
  out.push_back(in.at(6));
  return out;
}
// Substitute one byte using a fixed straight-line circuit over its eight
// bits. The circuit uses only XOR, AND (32 of them) and the bit::xnor member.
// Input and output are both 8-entry bit vectors.
// NOTE(review): the signal names (U, y, t, z, tc, S) match the Boyar-Peralta
// AES S-box circuit; presumably this is that circuit -- confirm against the
// published gate list before editing any line.
static byte byte_sub(byte c)
{
// Relabel the input in reverse index order: U0 is c[7], U7 is c[0].
bit U0 = c.at(7);
bit U1 = c.at(6);
bit U2 = c.at(5);
bit U3 = c.at(4);
bit U4 = c.at(3);
bit U5 = c.at(2);
bit U6 = c.at(1);
bit U7 = c.at(0);
// Input layer: XOR-only combinations of the inputs (y1..y21, t0, t1).
bit y14 = U3 ^ U5;
bit y13 = U0 ^ U6;
bit y9 = U0 ^ U3;
bit y8 = U0 ^ U5;
bit t0 = U1 ^ U2;
bit y1 = t0 ^ U7;
bit y4 = y1 ^ U3;
bit y12 = y13 ^ y14;
bit y2 = y1 ^ U0;
bit y5 = y1 ^ U6;
bit y3 = y5 ^ y8;
bit t1 = U4 ^ y12;
bit y15 = t1 ^ U5;
bit y20 = t1 ^ U1;
bit y6 = y15 ^ U7;
bit y10 = y15 ^ t0;
bit y11 = y20 ^ y9;
bit y7 = U7 ^ y11;
bit y17 = y10 ^ y11;
bit y19 = y10 ^ y8;
bit y16 = t0 ^ y11;
bit y21 = y13 ^ y16;
bit y18 = U0 ^ y16;
// Middle section: AND gates interleaved with XORs (t2..t45).
bit t2 = y12 & y15;
bit t3 = y3 & y6;
bit t4 = t3 ^ t2;
bit t5 = y4 & U7;
bit t6 = t5 ^ t2;
bit t7 = y13 & y16;
bit t8 = y5 & y1;
bit t9 = t8 ^ t7;
bit t10 = y2 & y7;
bit t11 = t10 ^ t7;
bit t12 = y9 & y11;
bit t13 = y14 & y17;
bit t14 = t13 ^ t12;
bit t15 = y8 & y10;
bit t16 = t15 ^ t12;
bit t17 = t4 ^ y20;
bit t18 = t6 ^ t16;
bit t19 = t9 ^ t14;
bit t20 = t11 ^ t16;
bit t21 = t17 ^ t14;
bit t22 = t18 ^ y19;
bit t23 = t19 ^ y21;
bit t24 = t20 ^ y18;
bit t25 = t21 ^ t22;
bit t26 = t21 & t23;
bit t27 = t24 ^ t26;
bit t28 = t25 & t27;
bit t29 = t28 ^ t22;
bit t30 = t23 ^ t24;
bit t31 = t22 ^ t26;
bit t32 = t31 & t30;
bit t33 = t32 ^ t24;
bit t34 = t23 ^ t33;
bit t35 = t27 ^ t33;
bit t36 = t24 & t35;
bit t37 = t36 ^ t34;
bit t38 = t27 ^ t36;
bit t39 = t29 & t38;
bit t40 = t25 ^ t39;
bit t41 = t40 ^ t37;
bit t42 = t29 ^ t33;
bit t43 = t29 ^ t40;
bit t44 = t33 ^ t37;
bit t45 = t42 ^ t41;
// Eighteen AND gates pairing middle-section signals with input-layer signals.
bit z0 = t44 & y15;
bit z1 = t37 & y6;
bit z2 = t33 & U7;
bit z3 = t43 & y16;
bit z4 = t40 & y1;
bit z5 = t29 & y7;
bit z6 = t42 & y11;
bit z7 = t45 & y17;
bit z8 = t41 & y10;
bit z9 = t44 & y12;
bit z10 = t37 & y3;
bit z11 = t33 & y4;
bit z12 = t43 & y13;
bit z13 = t40 & y5;
bit z14 = t29 & y2;
bit z15 = t42 & y9;
bit z16 = t45 & y14;
bit z17 = t41 & y8;
// Output layer: XOR / xnor combinations of the z-values giving S0..S7.
// xnor is a member of the project's bit type; presumably the complement
// of XOR -- confirm in the bit class.
bit tc1 = z15 ^ z16;
bit tc2 = z10 ^ tc1;
bit tc3 = z9 ^ tc2;
bit tc4 = z0 ^ z2;
bit tc5 = z1 ^ z0;
bit tc6 = z3 ^ z4;
bit tc7 = z12 ^ tc4;
bit tc8 = z7 ^ tc6;
bit tc9 = z8 ^ tc7;
bit tc10 = tc8 ^ tc9;
bit tc11 = tc6 ^ tc5;
bit tc12 = z3 ^ z5;
bit tc13 = z13 ^ tc1;
bit tc14 = tc4 ^ tc12;
bit S3 = tc3 ^ tc11;
bit tc16 = z6 ^ tc8;
bit tc17 = z14 ^ tc10;
bit tc18 = tc13 ^ tc14;
bit S7 = z12.xnor(tc18);
bit tc20 = z15 ^ tc16;
bit tc21 = tc2 ^ z11;
bit S0 = tc3 ^ tc16;
bit S6 = tc10.xnor(tc18);
bit S4 = tc14 ^ S3;
bit S1 = S3.xnor(tc16);
bit tc26 = tc17 ^ tc20;
bit S2 = tc26.xnor(z17);
bit S5 = tc21 ^ tc17;
// Emit S7 first: result[0] = S7 ... result[7] = S0, mirroring the reversed
// labeling applied to the input above.
byte result;
result.push_back(S7);
result.push_back(S6);
result.push_back(S5);
result.push_back(S4);
result.push_back(S3);
result.push_back(S2);
result.push_back(S1);
result.push_back(S0);
return result;
}
// Starting value of the key-schedule round constant in encrypt(): only
// index 0 is set. encrypt() advances it with xtime after each use.
static vector<bit> initialroundconstant{1,0,0,0,0,0,0,0};
// Bit-level AES-128 encryption of one block.
//
// in: 128 plaintext bits; byte number (col*4+row) occupies bits
//     (col*4+row)*8 .. (col*4+row)*8+7.
// k:  128 key bits in the same layout.
// Returns the 128 ciphertext bits in the same layout.
// vector::at throws if either input is shorter than 128 bits.
static vector<bit> encrypt(const vector<bit> &in,const vector<bit> &k)
{
  // Every table is indexed [row][column]; the key schedule has 44 columns.
  vector<vector<byte>> schedule(4,vector<byte> (44));
  vector<vector<byte>> block(4,vector<byte> (4));
  vector<vector<byte>> substituted(4,vector<byte> (4));

  // Schedule columns 0..3 are the key itself.
  for (int col = 0;col < 4;++col)
    for (int row = 0;row < 4;++row)
      for (int b = 0;b < 8;++b)
        schedule.at(row).at(col).push_back(k.at((col*4+row)*8+b));

  // Key expansion: each new column is the previous column (rotated,
  // substituted and offset by the round constant on every fourth column)
  // XORed with the column four positions back.
  byte rcon = initialroundconstant;
  for (int col = 4;col < 44;++col) {
    vector<byte> word(4);
    if (col % 4 != 0) {
      for (int row = 0;row < 4;++row)
        word.at(row) = schedule.at(row).at(col - 1);
    } else {
      for (int row = 0;row < 4;++row)
        word.at(row) = byte_sub(schedule.at((row + 1) % 4).at(col - 1));
      word.at(0) = byte_xor(word.at(0),rcon);
      rcon = xtime(rcon);
    }
    for (int row = 0;row < 4;++row)
      schedule.at(row).at(col) = byte_xor(word.at(row),schedule.at(row).at(col - 4));
  }

  // Load the plaintext into the state.
  for (int col = 0;col < 4;++col)
    for (int row = 0;row < 4;++row)
      for (int b = 0;b < 8;++b)
        block.at(row).at(col).push_back(in.at((col*4+row)*8+b));

  // Initial round-key addition (schedule columns 0..3).
  for (int col = 0;col < 4;++col)
    for (int row = 0;row < 4;++row)
      block.at(row).at(col) = byte_xor(block.at(row).at(col),schedule.at(row).at(col));

  for (int round = 1;round <= 10;++round) {
    // Byte substitution on every state cell.
    for (int row = 0;row < 4;++row)
      for (int col = 0;col < 4;++col)
        substituted.at(row).at(col) = byte_sub(block.at(row).at(col));

    // Row r is rotated left by r columns.
    for (int row = 0;row < 4;++row)
      for (int col = 0;col < 4;++col)
        block.at(row).at(col) = substituted.at(row).at((col + row) % 4);

    // Column mixing is skipped in the last round.
    if (round != 10) {
      for (int col = 0;col < 4;++col) {
        byte s0 = block.at(0).at(col);
        byte s1 = block.at(1).at(col);
        byte s2 = block.at(2).at(col);
        byte s3 = block.at(3).at(col);
        byte s01 = byte_xor(s0,s1);
        byte s12 = byte_xor(s1,s2);
        byte s23 = byte_xor(s2,s3);
        byte s30 = byte_xor(s3,s0);
        // Row r becomes xtime(s_r ^ s_{r+1}) ^ s_{r+1} ^ s_{r+2} ^ s_{r+3}.
        block.at(0).at(col) = byte_xor(xtime(s01),byte_xor(s1,s23));
        block.at(1).at(col) = byte_xor(xtime(s12),byte_xor(s2,s30));
        block.at(2).at(col) = byte_xor(xtime(s23),byte_xor(s3,s01));
        block.at(3).at(col) = byte_xor(xtime(s30),byte_xor(s0,s12));
      }
    }

    // Add the round key held in schedule columns 4*round .. 4*round+3.
    for (int row = 0;row < 4;++row)
      for (int col = 0;col < 4;++col)
        block.at(row).at(col) = byte_xor(block.at(row).at(col),schedule.at(row).at(round * 4 + col));
  }

  // Serialize the state in the same byte order used for the input.
  vector<bit> ciphertext;
  for (int col = 0;col < 4;++col)
    for (int row = 0;row < 4;++row)
      for (int b = 0;b < 8;++b)
        ciphertext.push_back(block.at(row).at(col).at(b));
  return ciphertext;
}
// Enumerate key guesses for AES-128 and report one consistent with two
// plaintext/ciphertext pairs, expressed entirely as operations on bits.
//
// bits:         concatenation of plaintext0 (128 bits), ciphertext0 (C bits),
//               plaintext1 (128 bits) and ciphertext1 (C bits); the total
//               length is checked with assert.
// params:       params.at(0) = K, the number of low key positions that are
//               enumerated and returned; params.at(1) = C, the number of
//               ciphertext bits compared for each pair.
// attackparams: in order I (number of guesses tried), QX (0 selects the
//               variant without a queue), QUEUE_SIZE, and QF (the queue is
//               drained every QF*QUEUE_SIZE iterations and on the last one).
//
// Returns K bits. The result starts as all ones and is updated through
// bit::mux for each candidate; presumably mux keeps the old result when the
// mismatch flag is set and takes the guess otherwise -- confirm in the bit
// class.
vector<bit> aes128_enum(
  const vector<bit> &bits,
  const vector<bigint> &params,
  const vector<bigint> &attackparams
)
{
  bigint K = params.at(0);
  bigint C = params.at(1);

  bigint pos = 0;
  bigint I = attackparams.at(pos++);
  bigint QX = attackparams.at(pos++);
  bigint QUEUE_SIZE = attackparams.at(pos++);
  bigint QF = attackparams.at(pos++);
  auto PERIOD = QF*QUEUE_SIZE;

  // Split the flat input into the two known plaintext/ciphertext pairs.
  vector<bit> plaintext0(128);
  vector<bit> ciphertext0(C);
  vector<bit> plaintext1(128);
  vector<bit> ciphertext1(C);
  {
    bigint pos = 0;
    for (bigint j = 0;j < 128;++j) plaintext0.at(j) = bits.at(pos++);
    for (bigint j = 0;j < C;++j) ciphertext0.at(j) = bits.at(pos++);
    for (bigint j = 0;j < 128;++j) plaintext1.at(j) = bits.at(pos++);
    for (bigint j = 0;j < C;++j) ciphertext1.at(j) = bits.at(pos++);
    assert(pos == bits.size());
  }

  vector<bit> result(K,bit(1));

  // note that queue is not used if QX == 0
  vector<bit> queue_valid(QUEUE_SIZE);
  vector<vector<bit>> queue(QUEUE_SIZE,vector<bit>(K));
  bigint timer = 0;

  // Current key guess. Only positions 0..K-1 are ever modified below; the
  // remaining positions keep their default-constructed value.
  vector<bit> guess(128);

  for (bigint iter = 0;iter < I;++iter) {
    // First filter: compare C bits of the encryption of plaintext0.
    // mismatch starts default-constructed and accumulates differences by OR
    // (this relies on a default-constructed bit being 0).
    auto guessct0 = encrypt(plaintext0,guess);
    bit mismatch;
    for (bigint j = 0;j < C;++j)
      mismatch |= guessct0.at(j)^ciphertext0.at(j);

    if (QX == 0) {
      // No queue: check the second pair immediately for every guess.
      auto guessct1 = encrypt(plaintext1,guess);
      for (bigint j = 0;j < C;++j)
        mismatch |= guessct1.at(j)^ciphertext1.at(j);
      for (bigint j = 0;j < K;++j)
        result.at(j) = mismatch.mux(guess.at(j),result.at(j));
    } else {
      // Queue variant: hand the first K guess bits and the match flag to the
      // queue helpers (presumably they store the guess only when match is
      // set -- confirm in bit_vector.h), and defer the second encryption.
      vector<bit> guessprefix(K);
      for (bigint j = 0;j < K;++j)
        guessprefix.at(j) = guess.at(j);
      bit match = ~mismatch;
      bit_queue1_insert(queue_valid,match);
      bit_vector_queue_insert(queue,guessprefix,match);

      ++timer;
      if (timer == PERIOD || iter == I-1) {
        // Drain: test every queue slot against the second pair, treating
        // slots whose valid flag is clear as mismatches, then clear the flag.
        timer = 0;
        for (bigint q = 0;q < QUEUE_SIZE;++q) {
          vector<bit> queueguess(128);
          for (bigint j = 0;j < K;++j)
            queueguess.at(j) = queue.at(q).at(j);
          auto guessct1 = encrypt(plaintext1,queueguess);
          bit mismatch1 = ~queue_valid.at(q);
          for (bigint j = 0;j < C;++j)
            mismatch1 |= guessct1.at(j)^ciphertext1.at(j);
          for (bigint j = 0;j < K;++j)
            result.at(j) = mismatch1.mux(queueguess.at(j),result.at(j));
          queue_valid.at(q) = bit(0);
        }
      }
    }

    // Ripple-carry increment of the low K guess bits (index 0 is the least
    // significant position): flip each bit while the carry is still set.
    bit incrementing(1);
    for (bigint j = 0;j < K;++j) {
      bit old = guess.at(j);
      guess.at(j) = old^incrementing;
      incrementing &= old;
    }
  }
  return result;
}