// -rw-r--r-- 5683 cryptattacktester-20231020/isd0.cpp raw
#include <cassert>
#include <vector>
#include "decoding.h"
#include "bit.h"
#include "ram.h"
#include "util.h"
#include "subset.h"
#include "bit_vector.h"
#include "index.h"
#include "bit_matrix.h"
#include "bit_cube.h"
#include "column_swaps.h"
#include "parity.h"
#include "isd0.h"
using namespace std;
// isd0: information-set-decoding search, expressed with bit-level primitives.
//
// Parameters:
//   bits         - serialized decoding instance; decoding_deserialize(bits, params)
//                  splits it into a public key `pk` and a vector `s`.
//   params       - params[0] = N, params[1] = K, params[2] = T
//                  (presumably code length, dimension and error weight).
//   attackparams - ten attack parameters, consumed in this order:
//                  ITERS, RESET, X, YX, P, L, Z, QUEUE_SIZE, QF, FW.
//                  Derived values: Y = X + YX, PERIOD = QF * QUEUE_SIZE.
//
// Returns an N-bit candidate error vector (see the output-assembly comments at
// the end for how it is reordered).
//
// Choices that depend on `bit` values are made with *_mux / ram_write helpers
// rather than C++ branches; only `bigint` parameters drive control flow.
// NOTE(review): presumably so the computation stays a fixed circuit whose cost
// can be counted — confirm against bit.h.
vector<bit> isd0(
  const vector<bit> &bits,
  const vector<bigint> &params,
  const vector<bigint> &attackparams
)
{
	// Problem parameters.
	bigint N = params.at(0);
	bigint K_orig = params.at(1);
	bigint T = params.at(2);

	// Attack parameters, consumed in order.
	bigint pos = 0;
	bigint ITERS = attackparams.at(pos++);
	bigint RESET = attackparams.at(pos++);
	bigint X = attackparams.at(pos++);
	bigint YX = attackparams.at(pos++);
	auto Y = X+YX;
	bigint P = attackparams.at(pos++);
	bigint L = attackparams.at(pos++);
	bigint Z = attackparams.at(pos++);
	bigint QUEUE_SIZE = attackparams.at(pos++);
	bigint QF = attackparams.at(pos++);
	auto PERIOD = QF*QUEUE_SIZE;
	bigint FW = attackparams.at(pos++);

	auto inputs = decoding_deserialize(bits,params);
	auto pk = inputs.first;
	auto s = inputs.second;
	vector<vector<bit>> H = bit_matrix_transpose_and_identity(pk);

	// column_map[i] starts as the nbits(N-1)-bit encoding of i. It is passed
	// alongside H to every column operation below, presumably so the final error
	// vector can be mapped back to the original column order.
	vector<vector<bit>> column_map;
	for (bigint i = 0; i < N; i++)
		column_map.push_back(bit_vector_from_integer(i, nbits(N-1)));

	bit alwayssystematic = 1;
	bigint K = K_orig;
	if (FW) {
		// parity_known is given the parity of T (its lowest bit); K is then
		// reduced by one for the rest of the search.
		alwayssystematic = parity_known(s,H,column_map,bit(T.bit(0)));
		K -= 1;
	}

	// Snapshot of the starting state, restored on every reset.
	vector<vector<bit>> initial_H = H;
	vector<bit> initial_s = s;
	vector<vector<bit>> initial_column_map = column_map;
	bit initial_alwayssystematic = alwayssystematic;

	bigint R = N - K;
	bigint KK = K + L;
	bigint RR = N - KK;
	const bigint idx_bits = nbits(KK-Z-1);

	// Accumulated result; overwritten through the mux helpers whenever a
	// candidate passes its check.
	vector<bit> s_ret(N-K-L);
	vector<vector<bit>> set_ret = bit_matrix(P, idx_bits);
	vector<vector<bit>> map_ret = bit_matrix(N, nbits(N-1));

	bigint untilreset = 0;
	for (bigint iter = 0; iter < ITERS; iter++)
	{
		// if alwayssystematic: H.at(i).at(j) == (i-KK == j-L) for KK <= i < N, 0 <= j < R
		if (untilreset > 0) {
			alwayssystematic &= column_swaps(s, H, column_map, N, K, L, X, Y);
		} else {
			// Restore the initial state on iteration 0 and then every RESET
			// iterations. On every reset after the first, a random column
			// permutation is applied before the rows are randomized.
			untilreset = RESET;
			H = initial_H;
			s = initial_s;
			column_map = initial_column_map;
			if (iter == 0)
				alwayssystematic = initial_alwayssystematic;
			else {
				alwayssystematic = bit_matrix_column_randompermutation(s,H,column_map);
				if (FW) alwayssystematic &= initial_alwayssystematic;
			}
			bit_matrix_randomize_rows(H, s, L);
		}
		--untilreset;

		// Partition s and the first KK-Z rows of H into positions [0, L) and
		// [L, R). Assumes bit_vector_extract takes a half-open [begin, end)
		// range; this matches s_ret having N-K-L = R-L entries. TODO confirm.
		vector<vector<bit>> ss(0);
		ss.push_back(bit_vector_extract(s, 0, L));
		ss.push_back(bit_vector_extract(s, L, R));
		vector<vector<vector<bit>>> Hs(2);
		for (bigint i = 0; i < KK-Z; i++)
		{
			Hs.at(0).push_back(bit_vector_extract(H.at(i), 0, L));
			Hs.at(1).push_back(bit_vector_extract(H.at(i), L, R));
		}

		// Search for a solution. subset() presumably enumerates size-P index
		// sets over the first KK-Z rows together with their sums (TODO confirm
		// against subset.h). It works on the [0, L) part when L > 0, otherwise
		// directly on the [L, R) part.
		vector<vector<bit>> L_sum(0);
		vector<vector<vector<bit>>> L_set(0);
		subset(L_sum, L_set, KK-Z, P, idx_bits, ss.at(L == 0), Hs.at(L == 0));
		if (P > 0 and L > 0)
		{
			// Candidates whose [0, L) sum is zero are pushed into a queue of
			// QUEUE_SIZE entries. The queue is drained every PERIOD candidates
			// and after the last candidate.
			vector<bit> queue_valid(QUEUE_SIZE);
			vector<vector<vector<bit>>> queue_set = bit_cube(QUEUE_SIZE, P, idx_bits);
			bigint timer = 0;
			for (bigint i = 0; i < L_sum.size(); i++)
			{
				timer = (timer + 1) % PERIOD;
				bit zero_check = bit_vector_iszero(L_sum.at(i));
				bit_queue1_insert(queue_valid, zero_check);
				bit_matrix_queue_insert(queue_set, L_set.at(i), zero_check);
				if (i == L_sum.size()-1) // last set
					timer = 0;
				if (timer == 0) // clear the queue periodically
				{
					for (bigint j = 0; j < QUEUE_SIZE; j++)
					{
						// Full [L, R) sum for the queued index set. Accept it if
						// the entry is valid, alwayssystematic holds and
						// (presumably) the weight equals T-P.
						vector<bit> sum = ss.at(1);
						bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs.at(1), queue_set.at(j)));
						bit weight_check = alwayssystematic.andn(bit_vector_hamming_weight_isnot(sum,T-P)) & queue_valid.at(j);

						bit_vector_mux(s_ret, sum, weight_check);
						bit_matrix_mux(set_ret, queue_set.at(j), weight_check);
						bit_matrix_mux(map_ret, column_map, weight_check);

						// clear the queue element
						queue_valid.at(j) = bit(0);
						bit_matrix_clear(queue_set.at(j));
					}
				}
			}
		}
		if (P > 0 and L == 0)
		{
			// No [0, L) filter: each enumerated sum is weight-checked directly.
			for (bigint i = 0; i < L_sum.size(); i++)
			{
				bit weight_check = alwayssystematic.andn(bit_vector_hamming_weight_isnot(L_sum.at(i),T-P));
				bit_vector_mux(s_ret, L_sum.at(i), weight_check);
				bit_matrix_mux(set_ret, L_set.at(i), weight_check);
				bit_matrix_mux(map_ret, column_map, weight_check);
			}
		}
		if (P == 0)
		{
			// No enumeration: s itself is checked. When L > 0, its [0, L)
			// part must also be zero.
			bit zero_check = bit_vector_iszero(ss.at(0));
			bit weight_check = alwayssystematic.andn(bit_vector_hamming_weight_isnot(ss.at(1),T-P));
			bit check = weight_check;

			if (L > 0) check &= zero_check;
			bit_vector_mux(s_ret, ss.at(1), check);
			bit_matrix_mux(map_ret, column_map, check);
		}
	}

	// Assemble the error vector in permuted column order. s_ret fills
	// positions KK..N-1; each of the P recorded indices (within [0, KK-Z))
	// is set to 1.
	vector<bit> e_ret(N);
	vector<bit> e(N);
	for (bigint i = 0; i < RR; i++)
		e.at(i + KK) = s_ret.at(i);
	if (P > 0)
		for (bigint i = 0; i < set_ret.size(); i++)
			ram_write(e, 0, KK-Z, set_ret.at(i), bit(1));
	// Scatter through the recorded column map. This presumably writes
	// e[i] to position map_ret[i], undoing the column permutation.
	for (bigint i = 0; i < N; i++)
		ram_write(e_ret, map_ret.at(i), e.at(i));
	// pk has identity implicitly on left, H has it on right
	// so change convention for output ordering
	vector<bit> e_ret_swap;
	for (bigint i = K_orig;i < N;++i)
		e_ret_swap.push_back(e_ret.at(i));
	for (bigint i = 0;i < K_orig;++i)
		e_ret_swap.push_back(e_ret.at(i));
	return e_ret_swap;
}