-rw-r--r-- 5160 cryptattacktester-20231020/bit_matrix.h raw
#ifndef BIT_MATRIX_H
#define BIT_MATRIX_H
#include "bigint.h"
#include "bit.h"
#include "ram.h"
#include "util.h"
#include "random.h"
#include "bit_vector.h"
using namespace std;
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the name, presumably builds the transpose of the
// argument combined with an identity matrix -- confirm against the definition.
vector<vector<bit>> bit_matrix_transpose_and_identity(const vector<vector<bit>> &);
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably applies a random column permutation to the matrix
// arguments and reports success/failure in the returned bit; the role of each
// parameter is not visible here -- confirm against the definition.
bit bit_matrix_column_randompermutation(vector<bit> &,vector<vector<bit>> &,vector<vector<bit>> &);
// Construct a matrix made of n inner vectors, each holding m bits.
// Every bit is value-initialized by vector's sizing constructor.
static inline const vector<vector<bit>> bit_matrix(bigint n, bigint m)
{
        vector<bit> blank(m);
        return vector<vector<bit>> (n, blank);
}
// Pass every inner vector of m, in index order, to bit_vector_clear.
static inline void bit_matrix_clear(vector<vector<bit>> &m)
{
	for (vector<bit> &inner : m)
		bit_vector_clear(inner);
}
// Entry-wise in-place multiplexer: every dest entry is replaced by
// b.mux(dest entry, matching src entry). Which of the two inputs b selects
// is determined by bit::mux.
// The inner length of dest.at(0) is used as the width for every outer index.
static inline void bit_matrix_mux(vector<vector<bit>> &dest,
                             vector<vector<bit>> &src,
                             bit b)
{
	for (bigint outer = 0; outer < dest.size(); outer++)
	{
		for (bigint inner = 0; inner < dest.at(0).size(); inner++)
		{
			bit &current = dest.at(outer).at(inner);
			bit &candidate = src.at(outer).at(inner);
			current = b.mux(current, candidate);
		}
	}
}
// Entry-wise multiplexer with two sources: every dest entry is overwritten
// with b.mux(matching src0 entry, matching src1 entry). Which of the two
// inputs b selects is determined by bit::mux.
// The inner length of dest.at(0) is used as the width for every outer index.
static inline void bit_matrix_mux(vector<vector<bit>> &dest,
                             vector<vector<bit>> &src0,
                             vector<vector<bit>> &src1,
                             bit b)
{
	for (bigint outer = 0; outer < dest.size(); outer++)
	{
		for (bigint inner = 0; inner < dest.at(0).size(); inner++)
		{
			bit &first = src0.at(outer).at(inner);
			bit &second = src1.at(outer).at(inner);
			// The mux is computed before the destination entry is looked up.
			dest.at(outer).at(inner) = b.mux(first, second);
		}
	}
}
// Conditional queue insertion controlled by b.
// Walking from the back of q towards the front, slot k+1 is muxed with slot k
// (via bit_matrix_mux), and finally slot 0 is muxed with m.
// NOTE(review): presumably this shifts the queue by one and puts m in front
// when b selects the source operand -- polarity depends on bit::mux.
static inline void bit_matrix_queue_insert(vector<vector<vector<bit>>> &q,
                                           vector<vector<bit>> &m,
                                           bit b)
{
	// Start at the second-to-last slot; kept in bigint so a queue with fewer
	// than two entries yields a negative start and skips the loop.
	bigint slot = q.size();
	slot -= 2;
	while (slot >= 0)
	{
		bit_matrix_mux(q.at(slot+1), q.at(slot), b);
		slot--;
	}
	bit_matrix_mux(q.at(0), m, b);
}
// In-place row reduction of m with pivot-column search.
//
// m is indexed as m.at(column).at(row): the outer size is taken as the column
// count and the size of m.at(0) as the row count.
// For each step r (up to min(nrows, ncols)) one index vector is appended to
// pivots and one bit to an internal list; the return value is
// bit_vector_and_bits over that list.
// NOTE(review): presumably the returned bit signals that every step found a
// usable pivot, and columns at index >= bound are never chosen as pivots --
// confirm against ram_read / bit_vector_first_one.
static inline bit bit_matrix_reduced_echelon(vector<vector<bit>> &pivots, vector<vector<bit>> &m, bigint bound)
{
        bit t;
        vector<bit> pivot_bits(0);
        bigint ncols = m.size();
        bigint nrows = m.at(0).size();
        for (bigint r = 0; r < min(nrows, ncols); r++)
        {
                // v.at(c) summarizes column r+c restricted to the range (r, nrows)
                // passed to bit_vector_extract; presumably the OR of rows r..nrows-1.
                vector<bit> v(bound - r);
                for (bigint c = 0; c < bound - r; c++)
                        v.at(c) = bit_vector_or_bits(bit_vector_extract(m.at(r+c), r, nrows));
                // idx: presumably the encoded position of the first set entry of v,
                // i.e. the chosen pivot column relative to r -- TODO confirm.
                vector<bit> idx = bit_vector_first_one(v);
                pivots.push_back(idx);
                // Fold later rows into row r. t is re-read on every iteration because
                // row r changes inside the loop; presumably t is the current row-r bit
                // of the column selected by idx, and andn(t) masks the update once it is set.
                for (bigint i = r+1; i < nrows; i++)
                {
                        t = ram_read(m, r, bound, idx, idx.size(), r);
                        for (bigint c = r; c < ncols; c++)
                                m.at(c).at(r) ^= m.at(c).at(i).andn(t);
                }
                // u: presumably the whole column selected by idx; its row-r entry is
                // recorded as this step's pivot bit.
                vector<bit> u = ram_read(m, r, bound, idx, idx.size());
                pivot_bits.push_back(u.at(r));
                // Add row r into every other row i, gated by u.at(i), for columns >= r.
                for (bigint i = 0; i < nrows; i++)
                {
                        if (i == r) continue;
                        for (bigint c = r; c < ncols; c++)
                                m.at(c).at(i) ^= (u.at(i) & m.at(c).at(r));
                }
        }
        return bit_vector_and_bits(pivot_bits);
}
// XOR-accumulate the inner vectors of m, each masked by the matching bit of v.
//
// For every index k, the term m.at(k) & v.at(k) is formed entry-wise; with
// flip set, m.at(k).andn(v.at(k)) is used instead. The first term initializes
// the result and later terms are folded in with bit_vector_ixor.
// Requires v.size() == m.size() (asserted); m must have at least one entry
// because the result length is taken from m.at(0).
static inline vector<bit> bit_matrix_vector_mul(vector<vector<bit>> &m, vector<bit> &v, bool flip = 0)
{
        assert(v.size() == m.size());
        vector<bit> ret(m.at(0).size());
        for (bigint k = 0; k < v.size(); k++)
        {
                vector<bit> term(m.at(0).size());
                bit scalar = v.at(k);
                if (flip)
                {
                        for (bigint pos = 0; pos < term.size(); ++pos)
                                term.at(pos) = m.at(k).at(pos).andn(scalar);
                }
                else
                {
                        for (bigint pos = 0; pos < term.size(); ++pos)
                                term.at(pos) = m.at(k).at(pos) & scalar;
                }
                if (k == 0)
                {
                        // First term: plain copy, no XOR needed.
                        ret = term;
                        continue;
                }
                bit_vector_ixor(ret, term);
        }
        return ret;
}
// Return the transpose of m: result.at(a).at(b) == m.at(b).at(a).
// The result has m.at(0).size() outer entries of length m.size(), so m must
// have at least one entry.
static inline vector<vector<bit>> bit_matrix_transpose(vector<vector<bit>> &m)
{
	vector<bit> blank(m.size());
	vector<vector<bit>> ret(m.at(0).size(), blank);
	for (bigint out = 0; out < ret.size(); out++)
	{
		vector<bit> &line = ret.at(out);
		for (bigint in = 0; in < ret.at(0).size(); in++)
			line.at(in) = m.at(in).at(out);
	}
	return ret;
}
// Conditionally swap two equally-shaped matrices entry by entry.
//
// Each pair of matching entries is passed to c.cswap; whether a swap happens
// for a given value of c is determined by bit::cswap.
// Both matrices must have the same outer size and the same inner size
// (asserted). Empty matrices are accepted and left untouched.
static inline void bit_matrix_cswap(const bit c, vector<vector<bit>> &m0, vector<vector<bit>> &m1)
{
	assert(m0.size() == m1.size());
	// Nothing to swap for an empty matrix. Returning here also keeps the
	// unchecked [0] accesses below from touching an empty vector, which would
	// be undefined behaviour.
	if (m0.empty())
		return;
	assert(m0[0].size() == m1[0].size());
	for (bigint i = 0; i < m0.size(); i++)
		for (bigint j = 0; j < m0[0].size(); j++)
			c.cswap(m0.at(i).at(j), m1.at(i).at(j));
}
// Randomly mix the first L rows of m into the other rows, applying the same
// operations to the vector s.
//
// m is indexed as m.at(column).at(row). For each source row below L and each
// different target row, a fresh random_bool() value decides whether the source
// row is XORed into the target row (across all columns) and whether
// s.at(source) is XORed into s.at(target).
static inline void bit_matrix_randomize_rows(vector<vector<bit>> &m, vector<bit> &s, bigint L)
{
        for (bigint from = 0; from < L; ++from)
        {
                for (bigint to = 0; to < m.at(0).size(); ++to)
                {
                        if (to == from) continue;
                        // One random draw per (from, to) pair, shared by m and s.
                        bit coin = random_bool();
                        // XXX: can also do coin as a per-circuit constant, saving about half of the operations
                        for (bigint col = 0; col < m.size(); ++col)
                        {
                                vector<bit> &column = m.at(col);
                                column.at(to) ^= coin & column.at(from);
                        }
                        s.at(to) ^= coin & s.at(from);
                }
        }
}
// Declarations only; the definitions are not part of this header.
// NOTE(review): judging by the names, these presumably compute a sum of
// selected columns of a matrix, with a direct variant, a sorting-based variant,
// and a front end that picks between them. The role of each argument is not
// visible here -- confirm against the definitions.
vector<bit> bit_matrix_sum_of_cols_straightforward(vector<vector<bit>> &, vector<vector<bit>> &);
vector<bit> bit_matrix_sum_of_cols_viasorting(vector<vector<bit>> &, vector<vector<bit>> &);
vector<bit> bit_matrix_sum_of_cols(vector<vector<bit>> &, vector<vector<bit>> &);
#endif