-rw-r--r-- 5109 cryptattacktester-20230614/subsettest.cpp raw
#include <cassert>
#include <iostream>
#include "random.h"
#include "index.h"
#include "subset.h"
#include "subset_cost.h"
#include "bit_matrix.h"
#include "bit_matrix_cost.h"
#include "sorting_cost.h"
using namespace std;
int main()
{
for (bigint p = 0;p <= 5;++p) {
for (bigint m = p;m <= 10;++m) {
if (m == 0) continue;
bigint expected = binomial(bigint(m),bigint(p));
cout << "subsettest " << p << " out of " << m << " expected " << expected << "\n" << flush;
for (bigint wordsize = 0;wordsize < 50;++wordsize) {
if (p > 0)
cout << "subsettest " << p << " out of " << m
<< " wordsize " << wordsize
<< " straightforward " << bit_matrix_sum_of_cols_straightforward_cost(wordsize,m,p)
<< " viasorting " << bit_matrix_sum_of_cols_viasorting_cost(wordsize,m,p)
<< " default " << bit_matrix_sum_of_cols_cost(wordsize,m,p)
<< "\n" << flush;
vector<bit> s;
for (bigint j = 0;j < wordsize;++j)
s.push_back(bit(random_bool()));
vector<vector<bit>> H;
for (bigint i = 0;i < m;++i) {
vector<bit> Hi;
for (bigint j = 0;j < wordsize;++j)
Hi.push_back(bit(random_bool()));
H.push_back(Hi);
}
vector<vector<bit>> L_sum;
vector<vector<vector<bit>>> L_set;
bigint idx_bits = nbits(m-1);
bigint ops = -bit::ops();
subset(L_sum,L_set,m,p,idx_bits,s,H);
assert(ops+bit::ops() == subset_cost(m,p,wordsize));
assert(L_sum.size() == expected);
assert(L_set.size() == expected);
vector<bool> sint;
for (bigint j = 0;j < wordsize;++j)
sint.push_back(s.at(j).value());
vector<vector<bool>> Hint;
for (bigint i = 0;i < m;++i) {
vector<bool> Hinti;
for (bigint j = 0;j < wordsize;++j)
Hinti.push_back(H.at(i).at(j).value());
Hint.push_back(Hinti);
}
vector<vector<bigint>> Lint;
for (bigint j = 0;j < L_set.size();++j) {
assert(L_set.at(j).size() == p);
assert(L_sum.at(j).size() == wordsize);
vector<bigint> Lintj;
for (bigint u = 0;u < p;++u)
Lintj.push_back(index_value(L_set.at(j).at(u)));
for (bigint u = 1;u < p;++u)
assert(Lintj.at(u) < Lintj.at(u-1));
for (bigint w = 0;w < wordsize;++w) {
bool check = L_sum.at(j).at(w).value();
check ^= sint.at(w);
for (bigint u = 0;u < p;++u)
check ^= Hint.at(Lintj.at(u)).at(w);
assert(!check);
}
if (wordsize > 0 && p > 0) {
bigint sumops = -bit::ops();
auto HLj = bit_matrix_sum_of_cols_straightforward(H,L_set.at(j));
assert(sumops+bit::ops() == bit_matrix_sum_of_cols_straightforward_cost(wordsize,m,p));
for (bigint w = 0;w < wordsize;++w) {
bool check = L_sum.at(j).at(w).value();
check ^= sint.at(w);
check ^= HLj.at(w).value();
assert(!check);
}
}
if (wordsize > 0 && p > 0) {
bigint sumops = -bit::ops();
auto HLj = bit_matrix_sum_of_cols_viasorting(H,L_set.at(j));
assert(sumops+bit::ops() == bit_matrix_sum_of_cols_viasorting_cost(wordsize,m,p));
for (bigint w = 0;w < wordsize;++w) {
bool check = L_sum.at(j).at(w).value();
check ^= sint.at(w);
check ^= HLj.at(w).value();
assert(!check);
}
}
if (wordsize > 0 && p > 0) {
bigint sumops = -bit::ops();
auto HLj = bit_matrix_sum_of_cols(H,L_set.at(j));
assert(sumops+bit::ops() == bit_matrix_sum_of_cols_cost(wordsize,m,p));
for (bigint w = 0;w < wordsize;++w) {
bool check = L_sum.at(j).at(w).value();
check ^= sint.at(w);
check ^= HLj.at(w).value();
assert(!check);
}
}
Lint.push_back(Lintj);
}
for (bigint j = 1;j < L_set.size();++j) {
bool ok = 0;
for (bigint u = 0;u < L_set.at(j).size();++u) {
assert(Lint.at(j).at(u) >= Lint.at(j-1).at(u));
if (Lint.at(j).at(u) > Lint.at(j-1).at(u)) {
ok = 1;
break;
}
}
assert(ok);
}
cout << flush;
}
}
}
for (bigint p = 8;p <= 128;p *= 2) {
for (bigint m = p;m <= 8192;++m) {
if (m&(m-1)) continue;
for (bigint wordsize = 1;wordsize <= 1024;wordsize += wordsize)
cout << "subsettest " << p << " out of " << m
<< " wordsize " << wordsize
<< " straightforward " << bit_matrix_sum_of_cols_straightforward_cost(wordsize,m,p)
<< " viasorting " << bit_matrix_sum_of_cols_viasorting_cost(wordsize,m,p)
<< " default " << bit_matrix_sum_of_cols_cost(wordsize,m,p)
<< "\n" << flush;
}
}
return 0;
}