// -rw-r--r-- 2718 cryptattacktester-20231020/echelontest.cpp raw
#include <cassert>
#include <iostream>
#include "random.h"
#include "index.h"
#include "bit_matrix.h"
#include "bit_matrix_cost.h"
using namespace std;
int main()
{
  for (bigint rows = 1;rows < 20;++rows)
    for (bigint cols = rows;cols < 20;++cols)
      for (bigint bound = rows;bound <= cols;++bound) {
        cout << "echelontest rows " << rows << " cols " << cols << " bound " << bound << "\n" << flush;
        for (bigint loop = 0;loop < 10;++loop) {
          vector<vector<bit>> M;
          for (bigint c = 0;c < cols;++c) {
            vector<bit> Mc;
            for (bigint r = 0;r < rows;++r)
              Mc.push_back(bit(random_bool()));
            M.push_back(Mc);
          }
          vector<vector<bit>> Morig = M;
          vector<vector<bit>> pivots;
  
          bigint ops = bit::ops();
          bit ok = bit_matrix_reduced_echelon(pivots,M,bound);
          assert(bit::ops()-ops == bit_matrix_reduced_echelon_cost(rows,cols,bound));
  
          vector<vector<bool>> Mcheck;
          for (bigint c = 0;c < cols;++c) {
            vector<bool> Mcheckc;
            for (bigint r = 0;r < rows;++r)
              Mcheckc.push_back(Morig.at(c).at(r).value());
            Mcheck.push_back(Mcheckc);
          }
  
          vector<bigint> pivotcheck;
          for (bigint c = 0;c < cols;++c) {
            bigint r = pivotcheck.size();
            if (r >= rows) break;
            for (bigint r2 = r;r2 < rows;++r2)
              if (Mcheck.at(c).at(r2)) {
                for (bigint c2 = 0;c2 < cols;++c2) {
                  bool t = Mcheck.at(c2).at(r);
                  Mcheck.at(c2).at(r) = Mcheck.at(c2).at(r2);
                  Mcheck.at(c2).at(r2) = t;
                }
                for (bigint r3 = 0;r3 < rows;++r3)
                  if (r3 != r)
                    if (Mcheck.at(c).at(r3))
                      for (bigint c2 = 0;c2 < cols;++c2)
                        Mcheck.at(c2).at(r3) = Mcheck.at(c2).at(r3) ^ Mcheck.at(c2).at(r);
                pivotcheck.push_back(c);
                break;
              }
          }
  
          bool okint = ok.value();
  
          vector<bigint> pivotsint;
          for (bigint r = 0;r < rows;++r)
            pivotsint.push_back(r+index_value(pivots.at(r)));
  
          if (okint) {
            for (bigint c = 0;c < cols;++c)
              for (bigint r = 0;r < rows;++r)
                assert(M.at(c).at(r).value() == Mcheck.at(c).at(r));
            for (bigint r = 0;r < rows;++r)
              assert(pivotsint.at(r) == pivotcheck.at(r));
          }
          else {
            assert(pivotcheck.size() < rows || pivotcheck.at(pivotcheck.size()-1) >= bound);
          }
        }
      }
  return 0;
}