-rw-r--r-- 2790 cryptattacktester-20231020/ram_cost.cpp raw
#include <map>
#include <tuple>
#include "bit_cost.h"
#include "ram_cost.h"
using namespace std;
// File-local memoization tables for the recursive cost functions in this file.
// Keys are the function arguments: (N, ibits, length) for reads, and
// (N, ibits, length, top) for writes and combined read-writes.
// NOTE(review): plain std::map with no locking -- not safe if these cost
// functions are ever called concurrently.
static map<tuple<bigint,bigint,bigint>,bigint> ram_read_cost_cache;
static map<tuple<bigint,bigint,bigint,bool>,bigint> ram_write_cost_cache;
static map<tuple<bigint,bigint,bigint,bool>,bigint> ram_read_write_cost_cache;
// Cost of reading one entry out of an N-entry RAM, modelled as a recursive
// split into a power-of-two lower part and the remainder, combined by muxes.
//
// N      : number of entries; N <= 1 costs nothing (nothing to select).
// ibits  : number of index bits; if 2^ibits cannot reach the upper part,
//          only the lower power-of-two part is costed.
// length : multiplier on bit_mux_cost per split -- presumably the bit width
//          of one entry (TODO confirm against the circuit code).
// Returns the total cost in units of bit_mux_cost. Split-case results are
// memoized in ram_read_cost_cache.
bigint ram_read_cost(bigint N,bigint ibits,bigint length)
{
  if (N <= 1) return 0;

  tuple<bigint,bigint,bigint> key = make_tuple(N,ibits,length);
  // One tree lookup instead of count() followed by operator[].
  auto cached = ram_read_cost_cache.find(key);
  if (cached != ram_read_cost_cache.end())
    return cached->second;

  // split = largest power of two strictly below N; splitpos = log2(split).
  bigint splitpos = 0;
  bigint split = 1;
  while (2*split < N) {
    ++splitpos;
    split *= 2;
  }

  // 2^ibits <= split: no index can reach entries at or beyond `split`.
  if (ibits <= splitpos)
    return ram_read_cost(split,ibits,length);

  bigint result = 0;
  result += ram_read_cost(split,splitpos,length);   // lower power-of-two part
  result += ram_read_cost(N-split,splitpos,length); // remaining N-split entries
  result += length*bit_mux_cost; // result.push_back(isplit.mux(x0,x1));
  ram_read_cost_cache[key] = result;
  return result;
}
// Cost of writing one entry into an N-entry RAM, using the same recursive
// power-of-two split as ram_read_cost.
//
// N      : number of entries; N <= 0 costs nothing.
// ibits  : number of index bits; if 2^ibits cannot reach the upper part,
//          only the lower power-of-two part is costed.
// length : multiplier on bit_mux_cost at each single-entry leaf --
//          presumably the bit width of one entry (TODO confirm).
// top    : true only for the outermost call; recursive calls pass false.
//          At the top, a single entry is free and a split costs 1;
//          below the top, a single entry costs length muxes and a split
//          costs 2 (see the inline comments for the mirrored operations).
// Split-case results are memoized in ram_write_cost_cache.
bigint ram_write_cost(bigint N,bigint ibits,bigint length,bool top)
{
  if (N <= 0) return 0;
  if (N == 1) {
    if (top)
      return 0;
    else
      return length*bit_mux_cost; // b.mux(data.at(r), x.at(L).at(r));
  }

  tuple<bigint,bigint,bigint,bool> key = make_tuple(N,ibits,length,top);
  // One tree lookup instead of count() followed by operator[].
  auto cached = ram_write_cost_cache.find(key);
  if (cached != ram_write_cost_cache.end())
    return cached->second;

  // split = largest power of two strictly below N; splitpos = log2(split).
  bigint splitpos = 0;
  bigint split = 1;
  while (2*split < N) {
    ++splitpos;
    split *= 2;
  }

  // 2^ibits <= split: no index can reach entries at or beyond `split`.
  if (ibits <= splitpos)
    return ram_write_cost(split,ibits,length,top);

  bigint result = 0;
  if (top)
    result += 1; // ~isplit // XXX: suppress ~
  else
    result += 2; // b|isplit ... b.orn(isplit)
  result += ram_write_cost(split,splitpos,length,false);   // lower power-of-two part
  result += ram_write_cost(N-split,splitpos,length,false); // remaining N-split entries
  ram_write_cost_cache[key] = result;
  return result;
}
// Cost of a combined read-and-write of one entry in an N-entry RAM: the write
// structure of ram_write_cost plus one length-wide mux per split to assemble
// the value read.
//
// N      : number of entries; N <= 0 costs nothing.
// ibits  : number of index bits; if 2^ibits cannot reach the upper part,
//          only the lower power-of-two part is costed.
// length : multiplier on bit_mux_cost at leaves and splits -- presumably the
//          bit width of one entry (TODO confirm).
// top    : true only for the outermost call; recursive calls pass false.
//          At the top, a single entry is free and a split's control costs 1;
//          below the top, a single entry costs length muxes and a split's
//          control costs 2.
// Split-case results are memoized in ram_read_write_cost_cache.
bigint ram_read_write_cost(bigint N,bigint ibits,bigint length,bool top)
{
  if (N <= 0) return 0;
  if (N == 1) {
    if (top)
      return 0;
    else
      return length*bit_mux_cost; // bit_vector_mux(x.at(L), data, x.at(L), b);
  }

  tuple<bigint,bigint,bigint,bool> key = make_tuple(N,ibits,length,top);
  // One tree lookup instead of count() followed by operator[].
  auto cached = ram_read_write_cost_cache.find(key);
  if (cached != ram_read_write_cost_cache.end())
    return cached->second;

  // split = largest power of two strictly below N; splitpos = log2(split).
  bigint splitpos = 0;
  bigint split = 1;
  while (2*split < N) {
    ++splitpos;
    split *= 2;
  }

  // 2^ibits <= split: no index can reach entries at or beyond `split`.
  if (ibits <= splitpos)
    return ram_read_write_cost(split,ibits,length,top);

  bigint result = 0;
  if (top)
    result += 1; // ~isplit // XXX: suppress ~
  else
    result += 2; // b|isplit ... b.orn(isplit)
  result += ram_read_write_cost(split,splitpos,length,false);   // lower power-of-two part
  result += ram_read_write_cost(N-split,splitpos,length,false); // remaining N-split entries
  result += length*bit_mux_cost; // result.push_back(isplit.mux(x0,x1));
  ram_read_write_cost_cache[key] = result;
  return result;
}