-rw-r--r-- 5160 cryptattacktester-20230614/bit_matrix.h raw
#ifndef BIT_MATRIX_H
#define BIT_MATRIX_H
#include "bigint.h"
#include "bit.h"
#include "ram.h"
#include "util.h"
#include "random.h"
#include "bit_vector.h"
using namespace std;
// Implemented outside this header. NOTE(review): judging only by the name,
// this presumably returns the transpose of its argument combined with an
// identity block -- confirm against the definition.
vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &);
// Implemented outside this header; takes a bit vector and two matrices by
// non-const reference and returns a bit. NOTE(review): the name suggests a
// random column permutation applied to the arguments; the meaning of the
// returned bit is not visible here -- confirm against the definition.
bit bit_matrix_column_randompermutation(vector<bit> &,vector<vector<bit>> &,vector<vector<bit>> &);
// Build a matrix as n inner vectors, each holding m default-constructed bits.
//
// The result is returned as a plain (non-const) value: a const by-value
// return type prevents callers from moving out of the temporary, so an
// assignment such as `x = bit_matrix(a, b);` would deep-copy every bit.
static inline vector<vector<bit>> bit_matrix(bigint n, bigint m)
{
  return vector<vector<bit>>(n, vector<bit>(m));
}
static inline void bit_matrix_clear(vector<vector<bit>> &m)
{
for (bigint i = 0; i < m.size(); i++)
bit_vector_clear(m.at(i));
}
static inline void bit_matrix_mux(vector<vector<bit>> &dest,
vector<vector<bit>> &src,
bit b)
{
for (bigint i = 0; i < dest.size(); i++)
for (bigint j = 0; j < dest.at(0).size(); j++)
dest.at(i).at(j) = b.mux(dest.at(i).at(j), src.at(i).at(j));
}
static inline void bit_matrix_mux(vector<vector<bit>> &dest,
vector<vector<bit>> &src0,
vector<vector<bit>> &src1,
bit b)
{
for (bigint i = 0; i < dest.size(); i++)
for (bigint j = 0; j < dest.at(0).size(); j++)
dest.at(i).at(j) = b.mux(src0.at(i).at(j), src1.at(i).at(j));
}
static inline void bit_matrix_queue_insert(vector<vector<vector<bit>>> &q,
vector<vector<bit>> &m,
bit b)
{
bigint i = q.size();
i -= 2;
for (; i >= 0; i--)
bit_matrix_mux(q.at(i+1), q.at(i), b);
bit_matrix_mux(q.at(0), m, b);
}
// In-place elimination of m over GF(2), written as bit operations: XOR adds
// rows and AND/andn condition those additions. It works toward the reduced
// echelon form named by the function.
//
// Layout: m.size() is used as the column count and m.at(0).size() as the row
// count, so m.at(c).at(i) is the entry in row i of column c. At step r rows
// are combined only over columns c >= r.
//
// Control flow depends only on the sizes of m and on bound, never on bit
// values. The pivot column is located through bit_vector_first_one and read
// through ram_read rather than by a data-dependent index.
//
// Parameters:
//   pivots - one vector is appended per step r: the result of
//            bit_vector_first_one over that step's candidate flags.
//            NOTE(review): presumably an encoding of the chosen pivot column
//            offset relative to r -- confirm against bit_vector_first_one.
//   m      - the matrix, modified in place.
//   bound  - at step r the candidate pivot columns are r .. bound-1
//            (m.at(r+c) for c < bound - r), so bound must not exceed m.size()
//            or .at() throws. NOTE(review): presumably callers also keep
//            bound >= min(nrows, ncols) so that bound - r stays positive --
//            confirm.
// Returns: the AND of the per-step pivot bits. Presumably this is 1 iff every
//          step found a pivot -- confirm.
static inline bit bit_matrix_reduced_echelon(vector<vector<bit>> &pivots, vector<vector<bit>> &m, bigint bound)
{
bit t;
vector<bit> pivot_bits(0);
bigint ncols = m.size();
bigint nrows = m.at(0).size();
for (bigint r = 0; r < min(nrows, ncols); r++)
{
// v[c] flags candidate column r+c: the OR of its entries from row r onward.
// NOTE(review): this assumes bit_vector_extract(x, lo, hi) yields x[lo..hi)
// -- confirm.
vector<bit> v(bound - r);
for (bigint c = 0; c < bound - r; c++)
v.at(c) = bit_vector_or_bits(bit_vector_extract(m.at(r+c), r, nrows));
vector<bit> idx = bit_vector_first_one(v);
pivots.push_back(idx);
// Make row r carry a 1 in the selected column, if any lower row has one.
// t is presumably row r's entry in the column selected by idx (ram_read
// semantics are defined elsewhere -- confirm). When t is 0, row i is XORed
// into row r, assuming a.andn(t) means a & ~t. t is re-read on every
// iteration because row r changes inside this loop.
for (bigint i = r+1; i < nrows; i++)
{
t = ram_read(m, r, bound, idx, idx.size(), r);
for (bigint c = r; c < ncols; c++)
m.at(c).at(r) ^= m.at(c).at(i).andn(t);
}
// u is presumably the whole selected column. u[r] is recorded as this
// step's pivot bit.
vector<bit> u = ram_read(m, r, bound, idx, idx.size());
pivot_bits.push_back(u.at(r));
// Clear the selected column in every other row: row r is XORed into each
// row i with u[i] set. u was captured before this loop, and row r itself is
// skipped, so the pivot row stays fixed while the others are updated.
for (bigint i = 0; i < nrows; i++)
{
if (i == r) continue;
for (bigint c = r; c < ncols; c++)
m.at(c).at(i) ^= (u.at(i) & m.at(c).at(r));
}
}
return bit_vector_and_bits(pivot_bits);
}
static inline vector<bit> bit_matrix_vector_mul(vector<vector<bit>> &m, vector<bit> &v, bool flip = 0)
{
assert(v.size() == m.size());
vector<bit> ret(m.at(0).size());
for (bigint i = 0; i < v.size(); i++)
{
vector<bit> w(m.at(0).size());
bit vi = v.at(i);
for (bigint j = 0;j < w.size();++j)
if (flip)
w.at(j) = m.at(i).at(j).andn(vi);
else
w.at(j) = m.at(i).at(j) & vi;
if (i == 0)
ret = w;
else
bit_vector_ixor(ret, w);
}
return ret;
}
static inline vector<vector<bit>> bit_matrix_transpose(vector<vector<bit>> &m)
{
vector<vector<bit>> ret(m.at(0).size(), vector<bit> (m.size()));
for (bigint i = 0; i < ret.size(); i++)
for (bigint j = 0; j < ret.at(0).size(); j++)
ret.at(i).at(j) = m.at(j).at(i);
return ret;
}
static inline void bit_matrix_cswap(const bit c, vector<vector<bit>> &m0, vector<vector<bit>> &m1)
{
assert(m0.size() == m1.size());
assert(m0[0].size() == m1[0].size());
for (bigint i = 0; i < m0.size(); i++)
for (bigint j = 0; j < m0[0].size(); j++)
c.cswap(m0.at(i).at(j), m1.at(i).at(j));
}
static inline void bit_matrix_randomize_rows(vector<vector<bit>> &m, vector<bit> &s, bigint L)
{
for (bigint r = 0; r < L; ++r)
for (bigint i = 0; i < m.at(0).size(); ++i)
{
if (i == r) continue;
bit b = random_bool();
// XXX: can also do b as a per-circuit constant, saving about half of the operations
for (bigint c = 0; c < m.size(); ++c)
m.at(c).at(i) ^= b & m.at(c).at(r);
s.at(i) ^= b & s.at(r);
}
}
// The three functions below are declared here and implemented outside this
// header. All share one signature: two matrices by non-const reference in,
// a bit vector out.
// NOTE(review): the names suggest two strategies ("straightforward" and
// "viasorting") for computing a sum of columns, with bit_matrix_sum_of_cols
// presumably choosing between them -- confirm against the definitions.
vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &, vector<vector<bit>> &);
vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &, vector<vector<bit>> &);
vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &, vector<vector<bit>> &);
#endif