-rw-r--r-- 4243 cryptattacktester-20230614/bit_matrix_cost.cpp raw
#include <cassert>
#include <map>
#include "bit_cost.h"
#include "bit_vector_cost.h"
#include "bit_matrix_cost.h"
#include "sorting_cost.h"
#include "ram_cost.h"
using namespace std;
// Memo table for bit_matrix_reduced_echelon_cost, keyed on (rows, columns, bound).
static map<tuple<bigint,bigint,bigint>,bigint> bit_matrix_reduced_echelon_cost_cache;

// Gate count of the bit_matrix_reduced_echelon circuit on a matrix with
// `rows` rows and `columns` columns, where pivot columns are searched among
// the first `bound` columns. Requires rows <= bound <= columns.
// Each (rows, columns, bound) triple is computed once and then served from
// bit_matrix_reduced_echelon_cost_cache.
bigint bit_matrix_reduced_echelon_cost(bigint rows,bigint columns,bigint bound)
{
  assert(rows <= bound);
  assert(bound <= columns);

  tuple<bigint,bigint,bigint> key = make_tuple(rows,columns,bound);
  auto hit = bit_matrix_reduced_echelon_cost_cache.find(key);
  if (hit != bit_matrix_reduced_echelon_cost_cache.end())
    return hit->second;

  bigint total = 0;
  for (bigint r = 0;r < rows;++r) {
    // Circuit shapes at step r: m.size() is columns, m.at(0).size() is rows,
    // and the candidate-pivot vector v has bound-r entries.
    bigint candidates = bound-r;
    bigint rows_below = rows-1-r;   // rows strictly after row r
    bigint live_cols = columns-r;   // columns still touched by the row updates
    // NOTE(review): the index width here is nbits(bound-r), while the comment in
    // column_swaps_cost describes idx as nbits(v.size()-1) bits; confirm which
    // width ram_read actually receives in bit_matrix_reduced_echelon.
    bigint idxbits = nbits(candidates);

    // v.at(c) = bit_vector_or_bits(bit_vector_extract(m.at(r+c), r, nrows))
    total += candidates*rows_below;
    // idx = bit_vector_first_one(v)
    total += bit_vector_first_one_cost(candidates);
    // t = ram_read(m, r, bound, idx, idx.size(), r), once per later row
    total += rows_below*ram_read_cost(candidates,idxbits,1);
    // m.at(c).at(r) ^= m.at(c).at(i).andn(t)
    total += rows_below*live_cols*2;
    // vector<bit> u = ram_read(m, r, bound, idx, idx.size())
    total += ram_read_cost(candidates,idxbits,rows);
    // m.at(c).at(i) ^= (u.at(i) & m.at(c).at(r)), for every row i other than r
    total += (rows-1)*live_cols*2;
  }
  // bit_vector_and_bits(pivot_bits): AND of the rows per-step success bits
  total += rows-1;

  bit_matrix_reduced_echelon_cost_cache[key] = total;
  return total;
}
// Gate count of the column_swaps circuit.
// Parameters (presumably N = code length, K = code dimension, L = extra
// columns carried along, X = number of pivot rows, Y = pivot search bound;
// TODO confirm against column_swaps in bit_matrix.cpp):
//   - reduced echelon form of an X x (K+L+X+1) matrix with bound Y,
//   - for each pivot i < X, two RAM read-writes over Y-i entries: one on H
//     (entries of N-K bits) and one on column_map (entries of nbits(N-1) bits),
//   - AND/XOR elimination updates on H and on s.
bigint column_swaps_cost(bigint N,bigint K,bigint L,bigint X,bigint Y)
{
  bigint R = N - K;   // width of an H entry
  bigint KK = K + L;
  bigint result = 0;
  result += bit_matrix_reduced_echelon_cost(X,KK+X+1,Y); // bit success = bit_matrix_reduced_echelon(pivots, m, Y);
  for (bigint i = 0;i < X;++i) {
    // index size: vector<bit> v(bound - r); ... vector<bit> idx(nbits(v.size()-1));
    // where bound is Y
    result += ram_read_write_cost(Y-i,nbits(Y-i-1),R); // ram_read_write(H, i, Y, pivots.at(i), H.at(KK + i));
    result += ram_read_write_cost(Y-i,nbits(Y-i-1),nbits(N-1)); // ram_read_write(column_map, i, Y, pivots.at(i), column_map.at(KK + i));
  }
  result += X*KK*(R-X)*2; // H.at(i).at(j) ^= H.at(i).at(x + L) & H.at(x + KK).at(j);
  result += X*(R-X)*2; // s.at(j) ^= s.at(x + L) & H.at(x + KK).at(j);
  return result;
}
// Gate count of bit_matrix_column_randompermutation for an N-K row matrix:
// a reduced echelon form of an (N-K) x (N+1) matrix with pivot bound N,
// followed by, for each pivot i < N-K, two RAM read-writes over N-i entries
// (one on H with N-K-bit entries, one on column_map with nbits(N-1)-bit entries).
bigint bit_matrix_column_randompermutation_cost(bigint N,bigint K)
{
  bigint pivots = N-K;
  bigint map_width = nbits(N-1);

  // bit success = bit_matrix_reduced_echelon(pivots, m, Y);
  bigint cost = bit_matrix_reduced_echelon_cost(pivots,N+1,N);

  for (bigint i = 0;i < pivots;++i) {
    bigint entries = N-i;
    bigint index_width = nbits(N-i-1);
    cost += ram_read_write_cost(entries,index_width,pivots);    // ram_read_write(H, i, cols, pivots.at(i), H.at(i));
    cost += ram_read_write_cost(entries,index_width,map_width); // column_map.at(i) = ram_read_write(column_map, i, cols, pivots.at(i), column_map.at(i));
  }
  return cost;
}
// Gate count of multiplying a bit matrix by a bit vector: one AND per matrix
// bit (w.at(j) = m.at(i).at(j) & v.at(i)) and rows*(cols-1) XORs to
// accumulate the products (bit_vector_ixor(ret, w)).
bigint bit_matrix_vector_mul_cost(bigint rows,bigint cols)
{
  bigint and_gates = rows*cols*bit_and_cost;
  bigint xor_gates = rows*(cols-1)*bit_xor_cost;
  return and_gates+xor_gates;
}
// Gate count of summing p columns of a matrix stored as cols entries of
// rows bits, done directly: p RAM reads of one rows-bit column each, plus
// (p-1) in-place XORs of rows bits. Requires p > 0.
bigint bit_matrix_sum_of_cols_straightforward_cost(bigint rows,bigint cols,bigint p)
{
  assert(p > 0);
  bigint reads = p*ram_read_cost(cols,nbits(cols-1),rows);
  bigint xors = (p-1)*rows; // ixor
  return reads+xors;
}
// Gate count of converting p column indices (each nbits(cols-1) bits wide,
// or 0 bits when cols is 0) into a cols-entry vector: 5 full adders per index
// bit for each of the p indices, then a sort of cols+p entries of
// idx_size+1 bits.
bigint indices_to_vector_cost(bigint cols,bigint p)
{
  bigint idx_size = 0;
  if (cols)
    idx_size = nbits(cols-1);

  bigint cost = 0;
  // Every index currently costs the same 5 full adders per bit, so the
  // per-index sum collapses to p copies (nothing when p <= 0).
  // XXX: can speed up circuit for subtracting p-1-i, which would make the
  // per-index term depend on i.
  if (p > 0)
    cost += p*idx_size*5;

  cost += sorting_cost(cols+p,idx_size+1); // sorting(L)
  return cost;
}
// Gate count of summing p columns via sorting: turn the p indices into a
// cols-entry vector (indices_to_vector_cost), then multiply the rows x cols
// matrix by that vector (bit_matrix_vector_mul_cost).
bigint bit_matrix_sum_of_cols_viasorting_cost(bigint rows,bigint cols,bigint p)
{
  bigint selector_cost = indices_to_vector_cost(cols,p);
  bigint multiply_cost = bit_matrix_vector_mul_cost(rows,cols);
  return selector_cost+multiply_cost;
}
// Gate count of summing p columns of a rows x cols bit matrix: the cheaper of
// bit_matrix_sum_of_cols_viasorting_cost and
// bit_matrix_sum_of_cols_straightforward_cost. Ties return the sorting cost.
bigint bit_matrix_sum_of_cols_cost(bigint rows,bigint cols,bigint p)
{
  bigint sorting = bit_matrix_sum_of_cols_viasorting_cost(rows,cols,p);
  bigint direct = bit_matrix_sum_of_cols_straightforward_cost(rows,cols,p);
  return (direct < sorting) ? direct : sorting;
}