-rw-r--r-- 5033 cryptattacktester-20231020/circuitprob.cpp raw
#include <iostream>
#include "problem.h"
#include "attack.h"
#include "collision_prob.h"
#include "random.h"
using namespace std;
// File-local cache of generated problem instances, reused across calls to
// attack_handle. It is rebuilt whenever the problem generator, the problem's
// parameter names, or the parameter values change.
namespace {
bool cacheinit = false;          // true once the fields below are populated
problem Ecached;                 // problem the cache was built for
vector<bigint> Pcached;          // parameter values the cache was built for
vector<vector<bool>> publist;    // publist[t]: public input of instance t
vector<vector<bool>> seclist;    // seclist[t]: expected secret output of instance t
bigfloat numinputs;              // E.numinputs(P) for the cached problem
bigfloat numoutputs;             // E.numoutputs(P) for the cached problem
}
// Empirically checks attack A's predicted success probability on problem E.
//
// E and P select the problem and its parameter values; A and Q select the
// attack and its parameter values. Results are reported on one stdout line
// starting with "circuitprob". The line contains:
//   - the problem/attack parameters;
//   - the predicted cost and the predicted probability (prob);
//   - the adjusted prediction prob2 from collision_lastmatch_prob;
//   - the number of trials and the bit-sliced operation count;
//   - the observed success rate and ratio2 = observed/prob2.
// An ALERT marker is appended when ratio2 lies outside [0.9, 1.1].
//
// The line ends early with " skipping" when:
//   - the predicted cost exceeds maxcost;
//   - probfactor is nonzero and ceil(prob*probfactor) <= 1;
//   - prob2 is not positive (no finite trial count can be chosen).
//
// Tunables (defaults below) go through selection_constrain on
// attack_selection, presumably so the caller's selection can override them:
//   maxcost          largest predicted cost that is actually simulated
//   maxnonbatchcost  below this cost, also evaluate the circuit one instance
//                    per call and assert it agrees with the bit-sliced run
//   trialfactor      1/X^2 for a target relative deviation X of the success
//                    count (see the trial-count derivation below)
//   probfactor       scale for the low-probability skip; 0 disables it
//
// Always returns 1.
int attack_handle(const problem &E,const vector<bigint> &P,const attack &A,const vector<bigint> &Q)
{
  bigint maxcost = 1073741824;
  bigint maxnonbatchcost = 1024;
  bigint trialfactor = 1000;
  bigint probfactor = 10000;
  selection_constrain(attack_selection,"maxcost",maxcost,maxcost);
  selection_constrain(attack_selection,"maxnonbatchcost",maxnonbatchcost,maxnonbatchcost);
  selection_constrain(attack_selection,"trialfactor",trialfactor,trialfactor);
  selection_constrain(attack_selection,"probfactor",probfactor,probfactor);

  // Header: problem and attack with their parameters.
  cout << "circuitprob";
  cout << " problem=" << E.name;
  for (bigint j = 0;j < P.size();++j)
    cout << (j ? ',' : ' ') << E.paramnames.at(j) << "=" << P.at(j);
  cout << " attack=";
  cout << A.name;
  for (bigint j = 0;j < Q.size();++j)
    cout << (j ? ',' : ' ') << A.paramnames.at(j) << "=" << Q.at(j);

  bigint predictedcost = A.cost(P,Q);
  cout << " cost " << predictedcost;
  if (predictedcost > maxcost) {
    cout << " skipping\n" << flush;
    return 1;
  }

  bigfloat predictedprob = A.prob(P,Q);
  cout << " prob " << predictedprob;
  if (probfactor)
    if (ceil_as_bigint(predictedprob*bigfloat(probfactor)) <= 1) {
      cout << " skipping\n" << flush;
      return 1;
    }

  bit::clear_all();

  // Drop the instance cache if it was built for a different generator,
  // different parameter names, or different parameter values.
  if (cacheinit) {
    if (E.psgen != Ecached.psgen) cacheinit = 0;
    if (E.paramnames != Ecached.paramnames) cacheinit = 0;
    if (P != Pcached) cacheinit = 0;
  }
  if (!cacheinit) {
    Ecached = E;
    Pcached = P;
    publist.clear();
    seclist.clear();
    numinputs = E.numinputs(P);
    numoutputs = E.numoutputs(P);
    cacheinit = 1;
  }

  bigfloat predictedprob2 = collision_lastmatch_prob(predictedprob*numinputs,numinputs,numoutputs);
  cout << " prob2 " << predictedprob2;

  // The trial count below divides by prob2, and so does ratio2 at the end.
  // A zero (or NaN) prediction can reach here when probfactor == 0 disables
  // the earlier skip; it would otherwise saturate the trial count at 10^18.
  if (!(predictedprob2 > 0.0)) {
    cout << " skipping\n" << flush;
    return 1;
  }

  // The circuit receives its own copies of the parameter vectors.
  vector<bigint> Pbigint(P);
  vector<bigint> Qbigint(Q);

  // Number of trials. With success probability p = prob2 and T trials:
  //   successes average T*p, with standard deviation sqrt(T*p*(1-p)).
  // Requiring deviation/successes <= X gives
  //   T*p*(1-p) <= X^2 * T^2 * p^2,  i.e.  T >= (1-p) / (X^2 * p),
  // where trialfactor plays the role of 1/X^2.
  // For p > 0.5, a flat trialfactor trials is used instead.
  // The count is clamped to [1, 10^18].
  bigint trials;
  if (predictedprob2 > 0.5)
    trials = trialfactor;
  else {
    bigfloat floattrials = trialfactor*(1-predictedprob2)/predictedprob2;
    trials = ceil_as_bigint(floattrials);
  }
  if (trials < 1) trials = 1;
  if (trials > "1000000000000000000") trials = "1000000000000000000";

  // Extend the cache to `trials` instances; instance t is generated right
  // after random_seed(t).
  while (publist.size() < trials) {
    random_seed(publist.size());
    auto ps = E.psgen(P);
    publist.push_back(ps.first);
    seclist.push_back(ps.second);
  }

  // Cheap attacks: evaluate the circuit one instance per call, as a
  // cross-check for the bit-sliced evaluation below.
  bool checknonbatch = (predictedcost < maxnonbatchcost);
  bigint nonbatchsuccesses = 0;
  if (checknonbatch) {
    for (bigint t = 0;t < trials;++t) {
      const auto &pub = publist.at(t);
      const auto &sec = seclist.at(t);
      vector<bit> pubbit;
      pubbit.reserve(pub.size());
      for (bool b : pub)
        pubbit.push_back(bit(b));
      random_seed();
      vector<bit> attackoutput = A.circuit(pubbit,Pbigint,Qbigint);
      bool success = true;
      for (bigint j = 0;j < sec.size();++j)
        if (sec.at(j) != attackoutput.at(j).value())
          success = false;
      nonbatchsuccesses += success;
    }
  }

  // Always: bit-sliced evaluation, packing up to bit_slicing instances into
  // each bit. Lane t of input bit j holds bit j of instance batch+t; in the
  // final partial batch, unused lanes stay 0.
  bigint successes = 0;
  for (bigint batch = 0;batch < trials;batch += bit_slicing) {
    bigint jbound = publist.at(batch).size();
    vector<bit> pubbit;
    for (bigint j = 0;j < jbound;++j) {
      bitset<bit_slicing> pubj;
      for (bigint t = 0;t < bit_slicing && batch+t < trials;++t)
        pubj[t] = publist.at(batch+t).at(j);
      pubbit.push_back(bit(pubj));
    }
    random_seed();
    vector<bit> attackoutput = A.circuit(pubbit,Pbigint,Qbigint);
    for (bigint t = 0;t < bit_slicing && batch+t < trials;++t) {
      const auto &sec = seclist.at(batch+t);
      bool success = true;
      for (bigint j = 0;j < sec.size();++j)
        if (sec.at(j) != attackoutput.at(j).value_vector()[t])
          success = false;
      successes += success;
    }
  }
  if (checknonbatch) assert(successes == nonbatchsuccesses);

  bigfloat observed = bigfloat(successes)/bigfloat(trials);
  bigfloat ratio2 = observed/predictedprob2;
  cout << " trials " << trials;
  cout << " slicedops " << bit::ops();
  cout << " succ " << observed;
  cout << " ratio2 " << ratio2;
  // \033[1;31m ... \033[0m: ANSI SGR bold red, then reset.
  if (ratio2 > 1.1) cout << " \033[1;31mALERT\033[0m";
  if (ratio2 < 0.9) cout << " \033[1;31mALERT\033[0m";
  cout << '\n' << flush;
  return 1;
}