-rw-r--r-- 4243 cryptattacktester-20231020/bit_matrix_cost.cpp raw
#include <cassert>
#include <map>
#include <tuple>
#include "bit_cost.h"
#include "bit_vector_cost.h"
#include "bit_matrix_cost.h"
#include "sorting_cost.h"
#include "ram_cost.h"
using namespace std;
// Memo table for bit_matrix_reduced_echelon_cost, keyed by (rows, columns, bound).
// NOTE(review): this is unsynchronized file-static state; concurrent callers
// would race on it.
static map<tuple<bigint,bigint,bigint>,bigint> bit_matrix_reduced_echelon_cost_cache;

// Modeled bit-operation cost of the reduced-echelon circuit on a matrix with
// `rows` rows and `columns` columns. At pivot step r, the pivot search vector v
// has bound-r entries, so `bound` limits which columns can supply pivots.
// Results are memoized per (rows, columns, bound).
//
// Preconditions (asserted): rows <= bound <= columns.
// Returns 0 when rows == 0: there are no pivots, and the final AND-combination
// of pivot bits must not contribute a negative term.
bigint bit_matrix_reduced_echelon_cost(bigint rows,bigint columns,bigint bound)
{
  assert(rows <= bound);
  assert(bound <= columns);
  tuple<bigint,bigint,bigint> key = make_tuple(rows,columns,bound);
  // A single lookup replaces the count() + operator[] pair.
  auto cached = bit_matrix_reduced_echelon_cost_cache.find(key);
  if (cached != bit_matrix_reduced_echelon_cost_cache.end())
    return cached->second;
  bigint result = 0;
  for (bigint r = 0;r < rows;++r) {
    // m.size() is columns
    // m.at(0).size() is rows
    // v.size() is bound-r
    // NOTE(review): column_swaps_cost sizes the pivot index as
    // nbits(v.size()-1) == nbits(bound-r-1); confirm which width the circuit uses.
    bigint idxbits = nbits(bound-r);
    result += (bound-r)*(rows-1-r); // bit_vector_or_bits(bit_vector_extract(m.at(r+c), r, nrows))
    result += bit_vector_first_one_cost(bound-r); // bit_vector_first_one(v);
    result += (rows-r-1)*ram_read_cost(bound-r,idxbits,1); // t = ram_read(m, r, bound, idx, idx.size(), r);
    result += (rows-r-1)*(columns-r)*2; // m.at(c).at(r) ^= m.at(c).at(i).andn(t);
    result += ram_read_cost(bound-r,idxbits,rows); // vector<bit> u = ram_read(m, r, bound, idx, idx.size());
    result += (rows-1)*(columns-r)*2; // m.at(c).at(i) ^= (u.at(i) & m.at(c).at(r));
  }
  // bit_vector_and_bits(pivot_bits): rows-1 ANDs, presumably over one bit per row.
  // With rows == 0 there is nothing to combine, so nothing is added.
  if (rows > 0)
    result += rows-1;
  bit_matrix_reduced_echelon_cost_cache.emplace(key,result);
  return result;
}
// Modeled cost of the column-swap step, with R = N-K and KK = K+L.
// The cost is the sum of three parts:
//  - a reduced-echelon pass with X rows, KK+X+1 columns and pivot bound Y;
//  - for each of the X pivots, one ram_read_write on H (N-K bits wide) and one
//    on column_map (nbits(N-1) bits wide), addressed among Y-i entries;
//  - the AND/XOR update terms for H and s quoted on the lines below.
bigint column_swaps_cost(bigint N,bigint K,bigint L,bigint X,bigint Y)
{
  bigint R = N - K;
  bigint KK = K + L;
  bigint result = 0;
  result += bit_matrix_reduced_echelon_cost(X,KK+X+1,Y); // bit success = bit_matrix_reduced_echelon(pivots, m, Y);
  for (bigint i = 0;i < X;++i) {
    // index size: vector<bit> v(bound - r); ... vector<bit> idx(nbits(v.size()-1));
    // where bound is Y
    result += ram_read_write_cost(Y-i,nbits(Y-i-1),N-K); // ram_read_write(H,    i, Y, pivots.at(i), H.at(KK + i));
    result += ram_read_write_cost(Y-i,nbits(Y-i-1),nbits(N-1)); // ram_read_write(column_map, i, Y, pivots.at(i), column_map.at(KK + i));
  }
  result += X*KK*(R-X)*2; // H.at(i).at(j) ^= H.at(i).at(x + L) & H.at(x + KK).at(j);
  result += X*(R-X)*2; // s.at(j) ^= s.at(x + L) & H.at(x + KK).at(j);
  return result;
}
// Modeled cost of randomly permuting the columns of an (N-K)-row matrix.
// The cost has two parts:
//  - one reduced-echelon pass over N+1 columns with pivot bound N;
//  - for each of the N-K pivots, a ram_read_write on H (N-K bits wide) and one
//    on column_map (nbits(N-1) bits wide). Pivot i is addressed among N-i entries.
bigint bit_matrix_column_randompermutation_cost(bigint N,bigint K)
{
  bigint pivot_count = N-K;
  bigint total = bit_matrix_reduced_echelon_cost(pivot_count,N+1,N);
  for (bigint i = 0;i < pivot_count;++i) {
    bigint candidates = N-i;
    bigint addr_bits = nbits(candidates-1);
    total += ram_read_write_cost(candidates,addr_bits,pivot_count); // ram_read_write on H
    total += ram_read_write_cost(candidates,addr_bits,nbits(N-1));  // ram_read_write on column_map
  }
  return total;
}
// Modeled cost of multiplying a bit matrix by a bit vector.
// The cost counts one AND per (row, col) entry and rows*(cols-1) XORs to
// accumulate the masked vectors into the result.
bigint bit_matrix_vector_mul_cost(bigint rows,bigint cols)
{
  bigint and_part = rows*cols*bit_and_cost;     // w.at(j) = m.at(i).at(j) & v.at(i);
  bigint xor_part = rows*(cols-1)*bit_xor_cost; // bit_vector_ixor(ret, w);
  return and_part + xor_part;
}
// Direct cost model for summing p columns of a rows x cols bit matrix.
// The cost is p RAM reads of a rows-bit column (nbits(cols-1) index bits each),
// plus (p-1)*rows for the XOR accumulation. That term counts one unit per XOR,
// not bit_xor_cost.
// Requires p > 0 (asserted).
bigint bit_matrix_sum_of_cols_straightforward_cost(bigint rows,bigint cols,bigint p)
{
  assert(p > 0);
  bigint column_reads = p*ram_read_cost(cols,nbits(cols-1),rows);
  bigint accumulation = (p-1)*rows; // ixor
  return column_reads + accumulation;
}
// Modeled cost of converting p column indices into a cols-entry vector.
// The result is presumably an indicator vector consumed by
// bit_matrix_vector_mul_cost.
// The cost has two parts:
//  - per index, 5 full adders per index bit, with idx_size = nbits(cols-1),
//    or 0 when cols == 0;
//  - sorting cols+p entries with idx_size+1 as sorting_cost's second argument.
bigint indices_to_vector_cost(bigint cols,bigint p)
{
  bigint idx_size = cols ? nbits(cols-1) : bigint(0);
  bigint result = 0;
  // p iterations of idx_size*5 each, summed in closed form instead of an O(p)
  // loop. The guard keeps p <= 0 contributing nothing, as the loop did.
  // XXX: can speed up circuit for subtracting p-1-i
  if (p > 0)
    result += p*idx_size*5; // 5 full adders per index bit, per index
  result += sorting_cost(cols+p,idx_size+1); // sorting(L)
  return result;
}
// Sorting-based cost model for summing p columns of a rows x cols bit matrix.
// The p indices are first expanded into a cols-entry vector
// (indices_to_vector_cost), then the matrix is multiplied by that vector.
bigint bit_matrix_sum_of_cols_viasorting_cost(bigint rows,bigint cols,bigint p)
{
  bigint selector = indices_to_vector_cost(cols,p);
  bigint product = bit_matrix_vector_mul_cost(rows,cols);
  return selector + product;
}
// Cost of summing p columns of a rows x cols bit matrix, taking the cheaper of
// the sorting-based and straightforward strategies.
// On a tie, the sorting-based figure is returned; the value is the same either way.
bigint bit_matrix_sum_of_cols_cost(bigint rows,bigint cols,bigint p)
{
  bigint via_sorting = bit_matrix_sum_of_cols_viasorting_cost(rows,cols,p);
  bigint direct = bit_matrix_sum_of_cols_straightforward_cost(rows,cols,p);
  return (direct < via_sorting) ? direct : via_sorting;
}