-rw-r--r-- 6784 cryptattacktester-20231020/isd1.cpp raw
#include <cassert>
#include <vector>
#include "decoding.h"
#include "permutation.h"
#include "bit.h"
#include "ram.h"
#include "util.h"
#include "sorting.h"
#include "subset.h"
#include "bit_vector.h"
#include "index.h"
#include "bit_matrix.h"
#include "bit_cube.h"
#include "column_swaps.h"
#include "parity.h"
#include "isd1.h"
using namespace std;
// Reorder two parallel vectors with a single permutation object so that
// entry i of A and entry i of B stay paired afterwards.
// Precondition (asserted): A and B have the same length.
// NOTE(review): permutation(n) is presumably a random permutation of n
// elements — see permutation.h to confirm.
template<class AT,class BT>
static void shuffle(vector<AT> &A,vector<BT> &B)
{
  bigint count = A.size();
  assert(count == B.size());

  permutation order(count);
  order.permute(A);  // same object for both vectors keeps them aligned
  order.permute(B);
}
// isd1: information-set-decoding attack expressed over `bit` values.
//
// bits          serialized decoding instance; decoding_deserialize(bits,params)
//               yields the pair (pk, s).
// params        {N, K, T}: read as code length, dimension and target weight.
// attackparams  read in exactly this order:
//   ITERS       number of outer iterations.
//   RESET       the initial (H, s, column_map) is restored every RESET
//               iterations; in between, column_swaps(..., X, Y) is applied
//               with Y = X + YX.
//   X, YX       parameters forwarded to column_swaps.
//   PI          number of columns taken from each of the two halves.
//   L           number of syndrome rows used in the collision search.
//   Z           the last Z of the first K+L columns belong to neither half.
//   QUEUE_SIZE, QF
//               candidates are buffered in a queue of QUEUE_SIZE slots that
//               is processed every QF*QUEUE_SIZE insertions and after the
//               final insertion of each iteration.
//   WINDOW      each sorted entry is paired with its next WINDOW neighbours.
//   FW          if nonzero, parity_known() is folded into the success flag
//               and the working dimension K is reduced by one.
//
// Returns a length-N vector assembled from the registers s_ret / set_ret /
// map_ret, reordered so original positions K_orig..N-1 come first, followed
// by 0..K_orig-1. No success flag is returned; the vector is only meaningful
// if some candidate passed the final check.
vector<bit> isd1(
  const vector<bit> &bits,
  const vector<bigint> &params,
  const vector<bigint> &attackparams
)
{
  bigint N = params.at(0);
  bigint K_orig = params.at(1);
  bigint T = params.at(2);

  bigint pos = 0;
  bigint ITERS = attackparams.at(pos++);
  bigint RESET = attackparams.at(pos++);
  bigint X = attackparams.at(pos++);
  bigint YX = attackparams.at(pos++); auto Y = X+YX;
  bigint PI = attackparams.at(pos++);
  bigint L = attackparams.at(pos++);
  bigint Z = attackparams.at(pos++);
  bigint QUEUE_SIZE = attackparams.at(pos++);
  bigint QF = attackparams.at(pos++); auto PERIOD = QF*QUEUE_SIZE;
  bigint WINDOW = attackparams.at(pos++);
  bigint FW = attackparams.at(pos++);

  auto inputs = decoding_deserialize(bits,params);
  auto pk = inputs.first;
  auto s = inputs.second;

  vector<vector<bit>> H = bit_matrix_transpose_and_identity(pk);

  // column_map.at(i) starts as the nbits(N-1)-bit encoding of i. It is passed
  // to every routine that reorders H, presumably so it tracks the original
  // index of each column.
  vector<vector<bit>> column_map;
  for (bigint i = 0; i < N; i++)
    column_map.push_back(bit_vector_from_integer(i, nbits(N-1)));

  bit alwayssystematic = 1;
  bigint K = K_orig;
  if (FW) {
    alwayssystematic = parity_known(s,H,column_map,bit(T.bit(0)));
    K -= 1;
  }

  // Snapshot of the starting state, restored every RESET iterations.
  vector<vector<bit>> initial_H = H;
  vector<bit> initial_s = s;
  vector<vector<bit>> initial_column_map = column_map;
  bit initial_alwayssystematic = alwayssystematic;

  bigint R = N - K;
  bigint KK = K + L;
  bigint RR = N - KK;
  // Halves hold (KK-Z)/2 and (KK-Z+1)/2 columns; size indices for the larger.
  const bigint idx_bits = nbits((KK-Z+1)/2-1);

  // Result registers, overwritten via mux whenever a candidate passes.
  vector<bit> s_ret(N-K-L);
  vector<vector<bit>> set_ret = bit_matrix(PI*2, idx_bits);
  vector<vector<bit>> map_ret = bit_matrix(N, nbits(N-1));

  bigint untilreset = 0;
  for (bigint iter = 0; iter < ITERS; iter++)
  {
    // if alwayssystematic: H.at(i).at(j) == (i-KK == j-L) for KK <= i < N, 0 <= j < R
    if (untilreset > 0) {
      alwayssystematic &= column_swaps(s, H, column_map, N, K, L, X, Y);
    } else {
      untilreset = RESET;
      H = initial_H;
      s = initial_s;
      column_map = initial_column_map;
      if (iter == 0)
        alwayssystematic = initial_alwayssystematic;
      else {
        alwayssystematic = bit_matrix_column_randompermutation(s,H,column_map);
        if (FW) alwayssystematic &= initial_alwayssystematic;
      }
      bit_matrix_randomize_rows(H, s, L);
    }
    --untilreset;

    // Partition s and the first KK-Z columns of H. Rows [0,L) feed the
    // collision search; rows [L,R) are checked afterwards.
    vector<bit> s0 = bit_vector_extract(s, 0, L);
    vector<bit> s1 = bit_vector_extract(s, L, R);
    vector<vector<vector<bit>>> Hs0(2);
    vector<vector<vector<bit>>> Hs1(2);
    for (bigint i = 0; i < KK-Z; i++)
    {
      Hs0.at( (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), 0, L));
      Hs1.at( (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), L, R));
    }

    // Search for a solution. Build one list of PI-column subset sums per
    // half, seeded with zz (all zero) and s0 respectively. Tag each entry
    // with its half in L_01, then sort all three lists jointly.
    vector<bit> L_01(0);
    vector<vector<bit>> L_sum(0);
    vector<vector<vector<bit>>> L_set(0);
    vector<bit> zz(L);
    subset(L_sum, L_set, Hs0.at(0).size(), PI, idx_bits, zz, Hs0.at(0));
    for (bigint i = 0; i < L_sum.size(); i++)
      L_01.push_back(bit(0));
    subset(L_sum, L_set, Hs0.at(1).size(), PI, idx_bits, s0, Hs0.at(1));
    for (bigint i = L_01.size(); i < L_sum.size(); i++)
      L_01.push_back(bit(1));

    sorting(L_01, L_sum, L_set);

    // Pair each sorted entry with its next WINDOW neighbours. check is set
    // when the halves differ AND NOT bit_vector_compare(...). That presumably
    // means "sums are equal", if compare returns 1 on mismatch — confirm.
    vector<bit> todo_check;
    vector<vector<vector<bit>>> todo_set;
    // `i + 1 < size` rather than `i < size - 1`: size() is unsigned, so the
    // latter wraps to SIZE_MAX when L_sum is empty.
    for (bigint i = 0; i + 1 < L_sum.size(); i++) {
      for (bigint offset = 1;offset <= WINDOW;++offset) {
        if (i+offset >= L_sum.size()) continue;
        bit check = L_01.at(i) ^ L_01.at(i+offset);
        check = check.andn(bit_vector_compare(L_sum.at(i), L_sum.at(i+offset)));
        vector<vector<bit>> set(0);
        for (bigint j = 0; j < PI; j++) set.push_back(L_set.at(i+0).at(j));
        for (bigint j = 0; j < PI; j++) set.push_back(L_set.at(i+offset).at(j));
        todo_check.push_back(check);
        todo_set.push_back(set);
      }
    }
    // Reorder the candidates; the same permutation is applied to both vectors.
    shuffle(todo_check,todo_set);

    vector<bit> queue_valid(QUEUE_SIZE);
    vector<vector<vector<bit>>> queue_set = bit_cube(QUEUE_SIZE, PI*2, idx_bits);
    bigint timer = 0;
    for (bigint j = 0;j < todo_set.size();++j) {
      auto check = todo_check.at(j);
      auto set = todo_set.at(j);
      bit_queue1_insert(queue_valid, check);
      bit_matrix_queue_insert(queue_set, set, check);

      // Process the queue every PERIOD insertions, and always after the last one.
      timer = (timer + 1) % PERIOD;
      if (j == todo_set.size()-1) timer = 0;
      if (timer == 0)
      {
        for (bigint q = 0; q < QUEUE_SIZE; q++)
        {
          // Add the 2*PI selected columns (rows [L,R)) onto s1.
          vector<bit> sum = s1;
          for (bigint b = 0; b < 2; b++)
          {
            vector<vector<bit>> set_p(0);
            for (bigint p = PI*b; p < PI*(b+1); p++)
              set_p.push_back(queue_set.at(q).at(p));
            bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs1.at(b), set_p));
          }
          // Pass iff systematic form held, the remaining part has weight
          // exactly T-2*PI, and the queue slot was occupied.
          bit check_w = alwayssystematic.andn(bit_vector_hamming_weight_isnot(sum,T-PI*2)) & queue_valid.at(q);

          // store solution
          bit_vector_mux(s_ret, sum, check_w);
          bit_matrix_mux(set_ret, queue_set.at(q), check_w);
          bit_matrix_mux(map_ret, column_map, check_w);

          // clear the queue slot
          queue_valid.at(q) = bit(0);
          bit_matrix_clear(queue_set.at(q));
        }
      }
    }
  }

  // Assemble the error vector in the permuted column order:
  // - positions [KK,N) come from s_ret;
  // - a 1 is written (presumably at offset lo + index) for each selected
  //   column of each half.
  vector<bit> e_ret(N);
  vector<bit> e(N);
  for (bigint i = 0; i < RR; i++)
    e.at(i + KK) = s_ret.at(i);
  for (bigint i = 0; i < PI; i++)
    ram_write(e, 0, (KK-Z)/2, set_ret.at(i), bit(1));
  for (bigint i = PI; i < PI*2; i++)
    ram_write(e, (KK-Z)/2, KK-Z, set_ret.at(i), bit(1));
  // Scatter e through the recorded column map back to original positions.
  for (bigint i = 0; i < N; i++)
    ram_write(e_ret, map_ret.at(i), e.at(i));

  // pk has identity implicitly on left, H has it on right
  // so change convention for output ordering
  vector<bit> e_ret_swap;
  for (bigint i = K_orig;i < N;++i)
    e_ret_swap.push_back(e_ret.at(i));
  for (bigint i = 0;i < K_orig;++i)
    e_ret_swap.push_back(e_ret.at(i));
  return e_ret_swap;
}