// -rw-r--r-- 6554 cryptattacktester-20231020/isd2_cost.cpp raw
#include <cassert>
#include "ram_cost.h"
#include "bit_cost.h"
#include "bit_vector_cost.h"
#include "bit_matrix_cost.h"
#include "subset_cost.h"
#include "index_cost.h"
#include "sorting_cost.h"
#include "parity_cost.h"
#include "isd2_cost.h"
using namespace std;
// Cost model for the "isd2" information-set-decoding attack.
//
// The function walks through the operations of the matching attack
// implementation (the trailing comments quote the statements being costed)
// and accumulates the cost reported by the *_cost helpers for each of them.
//
// params:       read as params.at(0..2) = {N, K, W}; presumably code length,
//               dimension and error weight (NOTE(review): confirm with caller).
// attackparams: read in this order: ITERS, RESET, X, YX, PIJ, PI, L0, L1,
//               CHECKPI, CHECKSUM, D, Z, QU0, QF0, WI0, QU1, QF1, WI1, FW.
// Returns the accumulated cost as a bigint.
// Both vectors are accessed with .at(), so a too-short vector throws
// std::out_of_range.
bigint isd2_cost(const vector<bigint> &params,const vector<bigint> &attackparams)
{
  bigint N = params.at(0);
  bigint K = params.at(1);
  bigint W = params.at(2);
  bigint pos = 0;
  bigint ITERS = attackparams.at(pos++);
  bigint RESET = attackparams.at(pos++);
  bigint X = attackparams.at(pos++);
  bigint YX = attackparams.at(pos++); auto Y = X+YX; // Y is given as an offset from X
  bigint PIJ = attackparams.at(pos++);
  bigint PI = attackparams.at(pos++);
  bigint L0 = attackparams.at(pos++);
  bigint L1 = attackparams.at(pos++);
  bigint CHECKPI = attackparams.at(pos++);
  bigint CHECKSUM = attackparams.at(pos++);
  bigint D = attackparams.at(pos++);
  bigint Z = attackparams.at(pos++);
  bigint QU0 = attackparams.at(pos++);
  bigint QF0 = attackparams.at(pos++); auto PE0 = QF0*QU0; // pool entries handled per level-0 queue clear
  bigint WI0 = attackparams.at(pos++);
  bigint QU1 = attackparams.at(pos++);
  bigint QF1 = attackparams.at(pos++); auto PE1 = QF1*QU1; // pool entries handled per root queue clear
  bigint WI1 = attackparams.at(pos++);
  bigint FW = attackparams.at(pos++);

  // With FW set, a one-time parity_known_cost is charged (added at the very
  // end) and the remaining computation works with K reduced by one.
  bigint fwcost = 0;
  if (FW) {
    fwcost = parity_known_cost(N,K);
    --K;
  }

  bigint L = L0+L1;
  bigint R = N - K;
  bigint KK = K + L;
  bigint RR = N - KK;
  // The KK-Z columns are split into two halves; "right" takes the extra
  // column when KK-Z is odd.
  bigint left = (KK-Z)/2;
  bigint right = KK-Z-left;
  bigint idx_bits = nbits(right-1);
  bigint result = 0;

  // ---- level-0 merge: sort both lists, then compare entries within a window ----
  bigint listsize0 = binomial(left,PIJ);
  bigint listsize1 = binomial(right,PIJ);
  bigint listsize = listsize0+listsize1;
  result += 2*sorting_cost(listsize,L0+1,L1+PIJ*idx_bits); // sorting(L_01, L_sum, L_set, L0);
  WI0 = min(WI0,bigint(listsize-1));
  // Number of (i, i+offset) pairs for offset = 1..WI0:
  // sum_{o=1}^{WI0} (listsize - o) = (2*listsize - WI0 - 1) * WI0 / 2.
  bigint pool = (2*listsize-WI0-1)*WI0/2;
  bigint persum = 0;
  persum += 1; // bit check = L_01.at(i) ^ L_01.at(i+offset);
  persum += 1+bit_vector_compare_cost(L0); // check = check.andn(bit_vector_compare(bit_vector_extract(L_sum.at(i+0), 0, L0), bit_vector_extract(L_sum.at(i+offset), 0, L0)));
  persum += bit_queue1_insert_cost(QU0); // bit_queue1_insert(queue_valid, check);
  persum += L1; // v = bit_vector_xor(v0, v1);
  persum += QU0*L1*bit_mux_cost; // bit_vector_queue_insert(queue_sum, v, check);
  persum += QU0*2*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check);
  result += 2*pool*persum;

  // ---- root merge over the entries flushed from the level-0 queues ----
  bigint queue_clears = (pool+PE0-1)/PE0; // ceil(pool / PE0)
  bigint rootlistsize = 2*queue_clears*QU0;
  WI1 = min(WI1,bigint(rootlistsize-1));
  result += sorting_cost(rootlistsize,L1+2,2*PIJ*idx_bits); // sorting(L_root_01, L_root_sum, L_root_set, L_root_valid);
  // Same closed-form pair count as "pool" above, for window WI1.
  bigint rootpool = (2*rootlistsize-WI1-1)*WI1/2;
  bigint perrootsum = 0;
  perrootsum += 1; // check = L_root_01.at(i) ^ L_root_01.at(i+1);
  perrootsum += 2; // check &= L_root_valid.at(i) & L_root_valid.at(i+1);
  perrootsum += 1+bit_vector_compare_cost(L1); // check.andn(bit_vector_compare(L_root_sum.at(i+0), L_root_sum.at(i+offset)));
  if (CHECKPI)
    perrootsum += 2*(1+set_size_check_cost(2*PIJ,idx_bits,PI)); // check &= set_size_check(set_check, PI);
  perrootsum += bit_queue1_insert_cost(QU1); // bit_queue1_insert(queue_valid, check);
  perrootsum += QU1*4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_queue_insert(queue_set, set, check);
  result += rootpool*perrootsum;

  // ---- per root-queue entry: recompute the syndrome sum and test its weight ----
  bigint postrootqueue = 0;
  postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,left,PIJ)); // for b=0, b=2: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p));
  postrootqueue += 2*(R-L+bit_matrix_sum_of_cols_cost(R-L,right,PIJ)); // for b=1, b=3: bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p));
  postrootqueue += bit_vector_hamming_weight_cost(R-L); // bit_vector_hamming_weight(sum);
  if (CHECKSUM == 0) {
    postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L),nbits(W-PI*2)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_sum, tp4));
  } else {
    postrootqueue += 2*set_size_cost(2*PIJ,idx_bits); // weight_list.push_back(set_size(set));
    postrootqueue += bit_vector_add_cost(nbits(2*PIJ),nbits(2*PIJ)); // bit_vector_add(w_tmp, weight_list.at(0), weight_list.at(1));
    postrootqueue += bit_vector_add_cost(nbits(R-L),nbits(4*PIJ)); // bit_vector_add(w_final, weight_list.at(2), w_tmp);
    postrootqueue += 2+bit_vector_integer_compare_cost(nbits(R-L+4*PIJ),nbits(W)); // check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_final, tp4));
  }
  postrootqueue += (R-L)*bit_mux_cost; // bit_vector_mux(s_ret, sum, check_w);
  postrootqueue += 4*PIJ*idx_bits*bit_mux_cost; // bit_matrix_mux(set_ret, queue_set.at(j), check_w);
  postrootqueue += N*nbits(N-1)*bit_mux_cost; // bit_matrix_mux(map_ret, column_map, check_w);
  bigint root_queue_clears = (rootpool+PE1-1)/PE1; // ceil(rootpool / PE1)
  result += root_queue_clears*QU1*postrootqueue;

  // Everything above is performed D times per iteration.
  result *= D;

  // ---- once per iteration: build the base lists and do the column swaps ----
  result += subset_cost(left,PIJ,L); // subset(L0_sum,    L0_set, Hs01.at(0).size(), PIJ, idx_bits, zz, Hs01.at(0));
  result += subset_cost(right,PIJ,L); // subset(L1_sum[0], L1_set, Hs01.at(1).size(), PIJ, idx_bits, zz, Hs01.at(1));
  result += listsize1*L; // bit_vector_xor(L1_sum[0].at(i), s01));
  result += (D-1)*2*listsize1; // L1_sum[t].at(i).at(flip_idx) = ~L1_sum[t].at(i).at(flip_idx);
  bigint column_swaps = column_swaps_cost(N,K,L,X,Y); // column_swaps(s, H, column_map, N, K, L, X, Y);
  result += column_swaps;
  result += 1; // alwayssystematic &= swapssucceeded;
  result *= ITERS;

  // ---- resets: every RESET iterations the matrix is re-randomized ----
  // A reset replaces the column swaps and the systematic-flag update of that
  // iteration, so those costs are subtracted back out of perreset.
  bigint perresetexceptfirst = 0;
  perresetexceptfirst += bit_matrix_column_randompermutation_cost(N,K);
  if (FW) perresetexceptfirst += 1; // alwayssystematic &= initial_alwayssystematic;
  bigint perreset = perresetexceptfirst;
  perreset -= column_swaps; // skipped on reset
  perreset -= 1; // skipped on reset
  perreset += 2*L*(N-K-1)*(N+1); // bit_matrix_randomize_rows
  result += perreset*(ITERS/RESET);
  result -= perresetexceptfirst; // skipped on iter == 0

  // ---- one-time output reconstruction ----
  result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices0, (KK-Z)/2);
  result += indices_to_vector_cost(left,PIJ); // indices_to_vector(indices1, (KK-Z)/2);
  result += left; // bit_vector_xor(v0, v1);
  result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices2, (KK-Z+1)/2);
  result += indices_to_vector_cost(right,PIJ); // indices_to_vector(indices3, (KK-Z+1)/2);
  result += right; // bit_vector_xor(v2, v3);
  result += N*ram_write_cost(N,nbits(N-1),1); // ram_write(e_ret, map_ret.at(i), e.at(i));
  result += fwcost;
  return result;
}