// -rw-r--r-- 3186 cryptattacktester-20230614/bit_matrix.cpp raw
#include <cassert>
#include "ram.h"
#include "sorting.h"
#include "bit_vector.h"
#include "index.h"
#include "permutation.h"
#include "bit_matrix.h"
#include "bit_matrix_cost.h"
using namespace std;
// transposed input matrix on left, identity matrix on right
vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &m)
{
bigint rows = m.size();
assert(rows > 0);
bigint cols = m.at(0).size();
assert(cols > 0);
bigint bigcols = cols+rows;
vector<vector<bit>> result;
for (bigint i = 0;i < bigcols;++i) {
vector<bit> column;
if (i < cols)
for (bigint j = 0;j < rows;++j)
column.push_back(m.at(j).at(i));
else
for (bigint j = 0;j < rows;++j)
column.push_back(bit(j == i-cols));
result.push_back(column);
}
return result;
}
vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
vector<bit> result = ram_read(m,indices.at(0));
for (bigint j = 1;j < indices.size();++j)
bit_vector_ixor(result,ram_read(m,indices.at(j)));
return result;
}
// Sum of the columns of m selected by indices, computed via one sort instead
// of one random-access read per index.
//
//   m       -- list of m.size() column vectors.
//   indices -- p = indices.size() bit-vector indices, all assumed to have the
//              width of indices.at(0) (idx_size).
//   returns -- bit_matrix_vector_mul(m, e, 1), where e is built below.
//
// What the code builds:
//   * for each position i of indices: bit_vector_add of
//     bit_vector_from_integer(p-1-i, idx_size, 1) and indices[i] with carry-in 1,
//     then a flag bit 0 inserted at the front;
//   * for each column i of m: bit_vector_from_integer(i, idx_size) with a flag
//     bit 1 inserted at the front.
// After sorting this list, e is the front (flag) bit of the first m.size() entries.
//
// NOTE(review): interpretation below depends on helpers not visible here; confirm.
// If bit_vector_from_integer(j, w, 1) yields the bitwise complement of j, the
// index entries hold indices[i] - (p-1-i). If in addition the front bit is the
// least-significant bit for `sorting` and indices are distinct and sorted in
// decreasing order, each index entry lands exactly at sorted position indices[i].
// Then e is 0 on the selected columns and 1 elsewhere, and the trailing 1 passed
// to bit_matrix_vector_mul presumably complements e. Other index orders would
// then not be supported.
vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
vector<bit> e(0);
bigint idx_size = indices.at(0).size();
vector<vector<bit>> L;
// entries derived from the selected indices (flag bit 0)
for (bigint i = 0; i < indices.size(); i++)
{
bigint j = indices.size()-1-i;
vector<bit> v_j = bit_vector_from_integer(j, idx_size, 1);
vector<bit> idx = indices.at(i);
vector<bit> v(idx_size);
// v = v_j + idx with carry-in 1 (within idx_size bits)
bit_vector_add(v, v_j, idx, bit(1));
v.insert(v.begin(), bit(0));
L.push_back(v);
}
// one entry per column position of m (flag bit 1)
for (bigint i = 0; i < m.size(); i++)
{
vector<bit> v = bit_vector_from_integer(i, idx_size);
v.insert(v.begin(), bit(1));
L.push_back(v);
}
sorting(L);
// the flag bits of the first m.size() sorted entries form the selector e
for (bigint i = 0; i < m.size(); i++)
e.push_back(L.at(i).at(0));
vector<bit> ret = bit_matrix_vector_mul(m, e, 1);
return ret;
}
// Sum of the columns of m selected by indices, dispatching to whichever
// implementation the cost model rates cheaper for this shape.
// m is a non-empty list of columns; ties go to the sorting-based variant.
vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &m, vector<vector<bit>> &indices)
{
  bigint cols = m.size();
  assert(cols > 0);
  bigint rows = m.at(0).size();
  bigint p = indices.size();

  bigint cost_viasorting = bit_matrix_sum_of_cols_viasorting_cost(rows,cols,p);
  bigint cost_straightforward = bit_matrix_sum_of_cols_straightforward_cost(rows,cols,p);

  return (cost_straightforward < cost_viasorting)
    ? bit_matrix_sum_of_cols_straightforward(m,indices)
    : bit_matrix_sum_of_cols_viasorting(m,indices);
}
// Apply a permutation pi to the columns of H (and identically to column_map),
// bring [H | s] to reduced echelon form, then move the pivot columns to the
// front and apply a second permutation pi_N.
//
//   s          -- vector of length rows (= H.at(0).size()); on return it holds
//                 the column appended to H after the echelon step.
//   H          -- list of cols > 0 columns of length rows; modified in place.
//   column_map -- cols entries, permuted in lockstep with the columns of H.
//   returns    -- the success bit reported by bit_matrix_reduced_echelon.
bit bit_matrix_column_randompermutation(
  vector<bit> &s,
  vector<vector<bit>> &H,
  vector<vector<bit>> &column_map
)
{
  bigint cols = H.size();
  assert(cols > 0);
  bigint rows = H.at(0).size();
  assert(s.size() == rows);
  assert(column_map.size() == cols);

  permutation pi(cols);
  vector<vector<bit>> pivots;

  // pi is built for exactly cols entries, so permute H *before* appending s:
  // permuting cols+1 entries would mismatch pi's size or move s into H's
  // columns, and then H.back() below would no longer be s.
  pi.permute(H);
  pi.permute(column_map);

  // Append s as an extra column so the echelon row operations are also applied
  // to it; the cols argument presumably limits pivot search to H's own columns.
  H.push_back(s);
  bit success = bit_matrix_reduced_echelon(pivots,H,cols);
  s = H.back(); H.pop_back();

  // For each row i, exchange column i with column pivots.at(i) (presumably
  // ram_read_write returns the old entry at the pivot index within [i,cols)
  // and writes the given value there), moving pivot columns to the front.
  for (bigint i = 0;i < rows;++i) {
    H.at(i) = ram_read_write(H, i, cols, pivots.at(i), H.at(i));
    column_map.at(i) = ram_read_write(column_map, i, cols, pivots.at(i), column_map.at(i));
  }

  permutation pi_N(cols,rows);
  pi_N.permute(H);
  pi_N.permute(column_map);

  return success;
}