// File: cryptattacktester-20231020/isd2.cpp (16584 bytes)
#include <cassert>
#include <vector>
#include <random>
#include "decoding.h"
#include "bit.h"
#include "ram.h"
#include "util.h"
#include "subset.h"
#include "bit_vector.h"
#include "index.h"
#include "bit_matrix.h"
#include "bit_cube.h"
#include "column_swaps.h"
#include "parity.h"
#include "sorting.h"
#include "isd0.h"
using namespace std;
/*
Let the partial weights of a solution be (w0, w1, 0, w2)
CHECKPI = 1: ensures that w0 = w1 = PI
CHECKPI = 0: do not check w0 and w1
CHECKSUM = 1: ensures that w0 + w1 + w2 = T
CHECKSUM = 0: ensures that w2 = T - 2PI
PI == PIJ*2 | CHECKPI | CHECKSUM | partial weights of solutions
---------------------------------------------------------
      Y     |    1    |    1     | (2PIJ, 2PIJ, 0, T - 4PIJ)
---------------------------------------------------------
      Y     |    1    |    0     | (2PIJ, 2PIJ, 0, T - 4PIJ)
---------------------------------------------------------
      Y     |    0    |    1     | (2PIJ - 2x, 2PIJ - 2y, 0, T - 4PIJ + 2x + 2y)
---------------------------------------------------------
      Y     |    0    |    0     | (2PIJ - 2x, 2PIJ - 2y, 0, T - 4PIJ) <= can be useful if we know that there is no solution of weight < T !!!
---------------------------------------------------------
      N     |    1    |    1     | (PI, PI, 0, T - 2PI)
---------------------------------------------------------
      N     |    1    |    0     | (PI, PI, 0, T - 2PI)
---------------------------------------------------------
      N     |    0    |    1     | (2PIJ - 2x, 2PIJ - 2y, 0, T - 4PIJ + 2x + 2y) <= PI is ignored in this case
---------------------------------------------------------
      N     |    0    |    0     | (2PIJ - 2x, 2PIJ - 2y, 0, T - 2PI) <= totally useless?
Since CHECKSUM = 1 is more expensive than CHECKSUM = 0, the useful cases are
1. (Y, 1, 0)
2. (N, 1, 0)
3. (Y, 0, 1) = (N, 0, 1)
4. (Y, 0, 0)
*/
// Reorders two parallel vectors with one shared permutation object, so that
// entries which were at the same index in A and B still line up afterwards.
// Both vectors must have the same length (checked by assert).
// NOTE(review): the permutation is presumably drawn at random by the
// `permutation` constructor -- confirm against its declaration.
template<class AT,class BT>
static void shuffle(vector<AT> &A,vector<BT> &B)
{
  bigint len = A.size();
  assert(len == B.size());

  permutation perm(len);
  perm.permute(A);
  perm.permute(B);
}
// Three-vector variant: A, B and C are reordered by one shared permutation
// object, keeping the index correspondence between them intact.
// All three vectors must have the same length (checked by assert).
// NOTE(review): the permutation is presumably drawn at random by the
// `permutation` constructor -- confirm against its declaration.
template<class AT,class BT,class CT>
static void shuffle(vector<AT> &A,vector<BT> &B,vector<CT> &C)
{
  bigint len = A.size();
  assert(len == B.size());
  assert(len == C.size());

  permutation perm(len);
  perm.permute(A);
  perm.permute(B);
  perm.permute(C);
}
// Four-vector variant: A, B, C and D are reordered by one shared permutation
// object, keeping the index correspondence between them intact.
// All four vectors must have the same length (checked by assert).
// NOTE(review): the permutation is presumably drawn at random by the
// `permutation` constructor -- confirm against its declaration.
template<class AT,class BT,class CT,class DT>
static void shuffle(vector<AT> &A,vector<BT> &B,vector<CT> &C,vector<DT> &D)
{
  bigint len = A.size();
  assert(len == B.size());
  assert(len == C.size());
  assert(len == D.size());

  permutation perm(len);
  perm.permute(A);
  perm.permute(B);
  perm.permute(C);
  perm.permute(D);
}
/*
isd2: builds, out of `bit` operations, a search for a vector e of length N
that is returned to the caller (presumably a weight-T solution of the
decoding problem encoded in `bits`; this is an information-set-decoding
style attack -- confirm against decoding.h / the attack description).

Inputs:
  bits         serialized problem instance, decoded by decoding_deserialize
               into a public-key matrix `pk` and a syndrome `s`
  params       params.at(0) = N, params.at(1) = K, params.at(2) = T
  attackparams read strictly in this order:
               ITERS, RESET, X, YX (Y = X+YX), PIJ, PI, L0, L1, CHECKPI,
               CHECKSUM, D, Z, QU0, QF0 (PE0 = QF0*QU0), WI0,
               QU1, QF1 (PE1 = QF1*QU1), WI1, FW

Requirements checked by assert: PI <= 2*PIJ, PI even, 1 <= D <= 2^L0.

Structure of each of the ITERS iterations:
  1. update (s, H, column_map): column_swaps while `untilreset` > 0,
     otherwise restart from the initial state (with a column permutation
     for iter > 0) and call bit_matrix_randomize_rows
  2. split the first KK-Z rows of H into two halves, and each row / the
     syndrome into an L-bit part and an (R-L)-bit part
  3. for each of D tree variants (variant d > 0 flips one bit position,
     chosen via gray_idx, in all L1_sum entries), and for t in {0,1}:
     merge the two subset lists, sort on the first L0 bits, and collect
     candidate neighbour pairs (window WI0) through a queue of size QU0
     into the root list
  4. sort the root list, collect candidate neighbour pairs (window WI1)
     through a queue of size QU1, and run the final weight test on each
     queue slot; a passing slot overwrites (s_ret, set_ret, map_ret)

The result is assembled from the last stored (s_ret, set_ret, map_ret).
If nothing passed the final test these keep their initial contents.

The statement order of all `bit` computations is deliberately left as is.
NOTE(review): the framework presumably accounts for the cost of each bit
operation, so reordering could desynchronize a separate cost formula.
*/
vector<bit> isd2(
  const vector<bit> &bits,
  const vector<bigint> &params,
  const vector<bigint> &attackparams
)
{
  // problem parameters
  bigint N = params.at(0);
  bigint K_orig = params.at(1);
  bigint T = params.at(2);

  // attack parameters; the order of these reads defines the attackparams layout
  bigint pos = 0;
  bigint ITERS = attackparams.at(pos++);
  bigint RESET = attackparams.at(pos++);
  bigint X = attackparams.at(pos++);
  bigint YX = attackparams.at(pos++); auto Y = X+YX;
  bigint PIJ = attackparams.at(pos++);
  bigint PI = attackparams.at(pos++);
  bigint L0 = attackparams.at(pos++);
  bigint L1 = attackparams.at(pos++);
  bigint CHECKPI = attackparams.at(pos++);
  bigint CHECKSUM = attackparams.at(pos++);
  bigint D = attackparams.at(pos++);
  bigint Z = attackparams.at(pos++);
  bigint QU0 = attackparams.at(pos++);
  bigint QF0 = attackparams.at(pos++); auto PE0 = QF0*QU0;
  bigint WI0 = attackparams.at(pos++);
  bigint QU1 = attackparams.at(pos++);
  bigint QF1 = attackparams.at(pos++); auto PE1 = QF1*QU1;
  bigint WI1 = attackparams.at(pos++);
  bigint FW = attackparams.at(pos++);
  bigint L = L0 + L1;

  assert(PI <= 2*PIJ);
  assert(PI%2 == 0);
  assert(D >= 1);
  assert(!((D-1)>>L0)); // D <= 2^L0

  auto inputs = decoding_deserialize(bits,params);
  auto pk = inputs.first;
  auto s = inputs.second;

  vector<vector<bit>> H = bit_matrix_transpose_and_identity(pk);

  // column_map.at(i) starts out as the binary encoding of i and is passed
  // along to every routine that rearranges H
  vector<vector<bit>> column_map;
  for (bigint i = 0; i < N; i++)
    column_map.push_back(bit_vector_from_integer(i, nbits(N-1)));

  bit alwayssystematic = 1;
  bigint K = K_orig;
  if (FW) {
    // NOTE(review): FW presumably means "fixed/known weight parity";
    // parity_known receives bit 0 of T and K is reduced by one afterwards
    alwayssystematic = parity_known(s,H,column_map,bit(T.bit(0)));
    K -= 1;
  }

  // snapshot used whenever the iteration loop restarts from scratch
  vector<vector<bit>> initial_H = H;
  vector<bit> initial_s = s;
  vector<vector<bit>> initial_column_map = column_map;
  bit initial_alwayssystematic = alwayssystematic;

  bigint R = N - K;
  bigint KK = K + L;

  // width of one stored index; (KK-Z+1)/2 is the size of the larger half
  const bigint idx_bits = nbits((KK-Z+1)/2-1);

  // stored solution: remaining syndrome part, the 4*PIJ selected indices,
  // and the column map that was current when the solution was found
  vector<bit> s_ret(N-K-L);
  vector<vector<bit>> set_ret = bit_matrix(PIJ*4, idx_bits);
  vector<vector<bit>> map_ret = bit_matrix(N, nbits(N-1));

  bigint untilreset = 0;

  for (bigint iter = 0; iter < ITERS; iter++)
  {
    // if alwayssystematic: H.at(i).at(j) == (i-KK == j-L) for KK <= i < N, 0 <= j < R
    if (untilreset > 0) {
      alwayssystematic &= column_swaps(s, H, column_map, N, K, L, X, Y);
    } else {
      untilreset = RESET;
      H = initial_H;
      s = initial_s;
      column_map = initial_column_map;
      if (iter == 0)
        alwayssystematic = initial_alwayssystematic;
      else {
        alwayssystematic = bit_matrix_column_randompermutation(s,H,column_map);
        if (FW) alwayssystematic &= initial_alwayssystematic;
      }
      bit_matrix_randomize_rows(H, s, L);
    }
    --untilreset;

    // partitioning s and H
    vector<bit> s01 = bit_vector_extract(s, 0,  L);
    vector<bit> s2  = bit_vector_extract(s, L,  R);

    // rows 0 .. (KK-Z)/2-1 go to half 0, rows (KK-Z)/2 .. KK-Z-1 to half 1
    vector<vector<vector<bit>>> Hs01(2);
    vector<vector<vector<bit>>> Hs2(2);
    for (bigint i = 0; i < KK-Z; i++)
    {
      Hs01.at( (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), 0, L));
      Hs2.at(  (i < (KK-Z)/2) ? 0 : 1 ).push_back(bit_vector_extract(H.at(i), L, R));
    }

    // search for solution
    bigint flip_idx; // assigned at d > 0, t == 0 before its first use
    vector<bigint> q_gray(0);
    vector<vector<bit>> L0_sum(0), L1_sum[2];
    vector<vector<vector<bit>>> L0_set(0), L1_set(0);
    bigint lens[2] = {0,0};

    for (bigint d = 0; d < D; d++) // randomizing search tree
    {
      vector<bit> L_root_01(0);
      vector<bit> L_root_valid(0);
      vector<vector<bit>> L_root_sum(0);
      vector<vector<vector<bit>>> L_root_set(0);

      for (bigint t = 0; t < 2; t++)
      {
        vector<bit> L_01(0);
        vector<vector<bit>> L_sum(0);
        vector<vector<vector<bit>>> L_set(0);

        vector<bit> zz(L);

        if (d == 0 and t == 0)
        {
          // the subset lists are built once per iteration and reused for all d, t
          subset(L0_sum,    L0_set, Hs01.at(0).size(), PIJ, idx_bits, zz, Hs01.at(0));
          subset(L1_sum[0], L1_set, Hs01.at(1).size(), PIJ, idx_bits, zz, Hs01.at(1));

          lens[0] = L0_sum.size();
          lens[1] = L1_sum[0].size();

          // L1_sum[1] is L1_sum[0] with the L-bit syndrome part xored in
          for (bigint i = 0; i < lens[1]; i++)
            L1_sum[1].push_back(bit_vector_xor(L1_sum[0].at(i), s01));
        }

        if (d > 0) // making use of gray code
        {
          // one new position per d (picked at t == 0), flipped in both L1_sum[0] and L1_sum[1]
          if (t == 0)
            flip_idx = gray_idx(q_gray);

          for (bigint i = 0; i < lens[1]; i++)
            L1_sum[t].at(i).at(flip_idx) = ~L1_sum[t].at(i).at(flip_idx);
        }

        // merged list: entries of list 0 are tagged 0, entries of list 1 are tagged 1
        for (bigint i = 0; i < lens[0]; i++)
        {
          L_01.push_back(bit(0));
          L_sum.push_back(L0_sum.at(i));
          L_set.push_back(L0_set.at(i));
        }

        for (bigint i = 0; i < lens[1]; i++)
        {
          L_01.push_back(bit(1));
          L_sum.push_back(L1_sum[t].at(i));
          L_set.push_back(L1_set.at(i));
        }

        shuffle(L_01, L_sum, L_set);
        sorting(L_01, L_sum, L_set, L0);

        // candidate pairs: each entry with its next WI0 neighbours in sorted order

        vector<bit> todo_check;
        vector<vector<bit>> todo_sum;
        vector<vector<vector<bit>>> todo_set;

        // written as i+1 < size() so that an empty list gives zero iterations
        // (size()-1 would wrap around in unsigned arithmetic)
        for (bigint i = 0; i+1 < L_sum.size(); i++)
        {
          for (bigint offset = 1;offset <= WI0;++offset) {
            if (i+offset >= L_sum.size()) continue;

            // pair is wanted if the two entries carry different tags and
            // (presumably: compare reports a difference) agree on the first L0 bits
            bit check = L_01.at(i) ^ L_01.at(i+offset);
            check = check.andn(bit_vector_compare(bit_vector_extract(L_sum.at(i+0), 0, L0),
                                                  bit_vector_extract(L_sum.at(i+offset), 0, L0)));

            // remaining L1 bits of the combined sum
            vector<bit> v0 = bit_vector_extract(L_sum.at(i+0), L0, L);
            vector<bit> v1 = bit_vector_extract(L_sum.at(i+offset), L0, L);
            vector<bit> v = bit_vector_xor(v0, v1);

            // 2*PIJ indices: those of entry i followed by those of entry i+offset
            vector<vector<bit>> set(0);
            for (bigint j = 0; j < PIJ; j++) set.push_back(L_set.at(i+0).at(j));
            for (bigint j = 0; j < PIJ; j++) set.push_back(L_set.at(i+offset).at(j));

            todo_check.push_back(check);
            todo_sum.push_back(v);
            todo_set.push_back(set);
          }
        }

        shuffle(todo_check,todo_sum,todo_set);

        vector<bit> queue_valid(QU0);
        vector<vector<bit>> queue_sum(QU0, vector<bit>(L1));
        vector<vector<vector<bit>>> queue_set = bit_cube(QU0, PIJ*2, idx_bits);

        // the queue is emptied every PE0 insertions and after the last insertion
        bigint timer = 0;
        for (bigint z = 0;z < todo_check.size();++z) {
          timer = (timer + 1) % PE0;
          if (z == todo_check.size()-1)
            timer = 0;

          auto check = todo_check.at(z);
          auto sum = todo_sum.at(z);
          auto set = todo_set.at(z);

          bit_queue1_insert(queue_valid, check);
          bit_vector_queue_insert(queue_sum, sum, check);
          bit_matrix_queue_insert(queue_set, set, check);

          // processing elements in the queue

          if (timer == 0) //
          {
            // all QU0 slots move to the root list, valid or not; the valid bit travels along
            for (bigint j = 0; j < QU0; j++)
            {
              L_root_01.push_back(bit(t));
              L_root_valid.push_back(queue_valid.at(j));

              L_root_sum.push_back(queue_sum.at(j));
              L_root_set.push_back(queue_set.at(j));

              // clear the queue elements

              queue_valid.at(j) = bit(0);
              bit_vector_clear(queue_sum.at(j));
              bit_matrix_clear(queue_set.at(j));
            }
          }
        }
      } // t

      shuffle(L_root_01, L_root_sum, L_root_set, L_root_valid);
      sorting(L_root_01, L_root_sum, L_root_set, L_root_valid);

      // root level: pair each entry with its next WI1 neighbours in sorted order
      vector<bit> todo_check;
      vector<vector<vector<bit>>> todo_set;

      // written as i+1 < size() so that an empty list gives zero iterations
      // (size()-1 would wrap around in unsigned arithmetic)
      for (bigint i = 0; i+1 < L_root_sum.size(); i++)
      {
        for (bigint offset = 1;offset <= WI1;++offset) {
          if (i+offset >= L_root_sum.size()) continue;

          bit check;

          // different t tags, both entries valid, and sums compared in full
          check = L_root_01.at(i) ^ L_root_01.at(i+offset);
          check &= L_root_valid.at(i) & L_root_valid.at(i+offset);
          check = check.andn(bit_vector_compare(L_root_sum.at(i+0), L_root_sum.at(i+offset)));

          // do weight check if CHECKPI
          if (CHECKPI)
          {
            // k selects one PIJ-sized index block of each entry; the two
            // blocks together are tested against PI by set_size_check
            for (bigint k = 0; k < 2; k++)
            {
              vector<vector<bit>> set_check(0);
              for (bigint j = PIJ*k; j < PIJ*(k+1); j++) set_check.push_back(L_root_set.at(i+0).at(j));
              for (bigint j = PIJ*k; j < PIJ*(k+1); j++) set_check.push_back(L_root_set.at(i+offset).at(j));
              check &= set_size_check(set_check, PI);
            }
          }

          // 4*PIJ indices: those of entry i followed by those of entry i+offset
          vector<vector<bit>> set(0);
          for (bigint j = 0; j < PIJ*2; j++) set.push_back(L_root_set.at(i+0).at(j));
          for (bigint j = 0; j < PIJ*2; j++) set.push_back(L_root_set.at(i+offset).at(j));

          todo_check.push_back(check);
          todo_set.push_back(set);
        }
      }

      shuffle(todo_check,todo_set);

      vector<bit> queue_valid(QU1);
      vector<vector<vector<bit>>> queue_set = bit_cube(QU1, PIJ*4, idx_bits);

      // the queue is emptied every PE1 insertions and after the last insertion
      bigint timer = 0;
      for (bigint z = 0;z < todo_check.size();++z) {
        timer = (timer + 1) % PE1;
        if (z == todo_check.size()-1)
          timer = 0;

        // conditionally pushing pairs into the queue
        auto check = todo_check.at(z);
        auto set = todo_set.at(z);

        bit_queue1_insert(queue_valid, check);
        bit_matrix_queue_insert(queue_set, set, check);

        // processing elements in the queue
        if (timer == 0)
        {
          for (bigint j = 0; j < QU1; j++)
          {
            // sum = s2 xor the selected columns; index blocks 0 and 2
            // refer to half 0, blocks 1 and 3 to half 1 (b & 1)
            vector<bit> sum = s2;
            for (bigint b = 0; b < 4; b++)
            {
              vector<vector<bit>> set_p(0);
              for (bigint p = PIJ*b; p < PIJ*(b+1); p++)
                set_p.push_back(queue_set.at(j).at(p));
              bit_vector_ixor(sum, bit_matrix_sum_of_cols(Hs2.at(b & 1), set_p));
            }

            // final check
            const vector<bit> tp4 = bit_vector_from_integer((CHECKSUM) ? T : T-PI*2);
            vector<bit> w_sum = bit_vector_hamming_weight(sum);
            bit check_w = alwayssystematic;

            if (CHECKSUM == 0) // make sure that the solution has partial weights (*, *, 0, T-PI*2)
              check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_sum, tp4));
            else // make sure that the solution has partial weights (w0, w1, 0, w2) with w0 + w1 + w2 = T
            {
              // weight_list = { size of half-0 index set, size of half-1 index set, weight of sum }
              vector<vector<bit>> weight_list(0);
              for (bigint b = 0; b < 2; b++)
              {
                vector<vector<bit>> set(0);
                for (bigint p = PIJ*(b+0); p < PIJ*(b+1); p++) set.push_back(queue_set.at(j).at(p));
                for (bigint p = PIJ*(b+2); p < PIJ*(b+3); p++) set.push_back(queue_set.at(j).at(p));
                weight_list.push_back(set_size(set));
              }
              weight_list.push_back(w_sum);

              vector<bit> w_tmp(nbits(PIJ*4)), w_final(nbits(PIJ*4 + R-L));
              bit_vector_add(w_tmp, weight_list.at(0), weight_list.at(1));
              bit_vector_add(w_final, weight_list.at(2), w_tmp);
              check_w &= queue_valid.at(j).andn(bit_vector_integer_compare(w_final, tp4));
            }

            // store solution

            bit_vector_mux(s_ret, sum, check_w);
            bit_matrix_mux(set_ret, queue_set.at(j), check_w);
            bit_matrix_mux(map_ret, column_map, check_w);

            // clear the queue elements

            queue_valid.at(j) = bit(0);
            bit_matrix_clear(queue_set.at(j));
          }
        }
      }
    }
  } // iter

  // rebuild the length-N vector from the stored indices:
  // blocks 0 and 2 of set_ret index half 0 (size (KK-Z)/2),
  // blocks 1 and 3 index half 1 (size (KK-Z+1)/2)
  vector<bit> e_ret(N);
  vector<bit> e(0);
  vector<vector<bit>> indices0;
  vector<vector<bit>> indices1;
  vector<vector<bit>> indices2;
  vector<vector<bit>> indices3;

  for (bigint i = PIJ*0; i < PIJ*1; i++) indices0.push_back(set_ret.at(i));
  vector<bit> v0 = indices_to_vector(indices0, (KK-Z)/2);
  for (bigint i = PIJ*2; i < PIJ*3; i++) indices1.push_back(set_ret.at(i));
  vector<bit> v1 = indices_to_vector(indices1, (KK-Z)/2);
  vector<bit> v01 = bit_vector_xor(v0, v1);

  for (bigint i = PIJ*1; i < PIJ*2; i++) indices2.push_back(set_ret.at(i));
  vector<bit> v2 = indices_to_vector(indices2, (KK-Z+1)/2);
  for (bigint i = PIJ*3; i < PIJ*4; i++) indices3.push_back(set_ret.at(i));
  vector<bit> v3 = indices_to_vector(indices3, (KK-Z+1)/2);
  vector<bit> v23 = bit_vector_xor(v2, v3);

  // layout of e: half 0 | half 1 | Z zero positions | stored syndrome remainder
  for (bigint i = 0; i < v01.size(); i++) e.push_back(v01.at(i));
  for (bigint i = 0; i < v23.size(); i++) e.push_back(v23.at(i));
  for (bigint i = 0; i < Z; i++) e.push_back(bit(0));
  for (bigint i = 0; i < s_ret.size(); i++) e.push_back(s_ret.at(i));
  assert(e.size() == N);

  // entry i of e is written to the position named by the stored column map
  for (bigint i = 0; i < N; i++)
    ram_write(e_ret, map_ret.at(i), e.at(i));

  // pk has identity implicitly on left, H has it on right
  // so change convention for output ordering
  vector<bit> e_ret_swap;
  for (bigint i = K_orig;i < N;++i)
    e_ret_swap.push_back(e_ret.at(i));
  for (bigint i = 0;i < K_orig;++i)
    e_ret_swap.push_back(e_ret.at(i));

  return e_ret_swap;
}