-rw-r--r-- 7527 cryptattacktester-20231020/ram.cpp raw
#include <cassert>
#include "ram.h"
#include "bit_vector.h"
using namespace std;
// input: vector x of equal-length bit vectors x[0],...,x[N-1]
// input: integer L between 0 and N-1
// input: integer H between L+1 and N
// input: bit vector i specifying I = i[0]+2*i[1]+...+2^(ibits-1)*i[ibits-1]
// output: x[L+J] for some J with 0 <= J < H-L
// satisfying J=I if 0 <= I < H-L; no constraints if I >= H-L
const vector<bit> ram_read(
const vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits)
{
if (H <= L) return vector<bit>{};
if (H == L+1) return x.at(L);
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
// now H-L <= split+split
// and split == 2^splitpos
// conventional RAM circuit:
// use i[0:splitpos] to look up entry in x[L:L+split]
// result0 = x[L+(I mod split)]
// use i[0:splitpos] to look up entry in x[L+split:H], called result1
// result1 = x[L+split+(I mod split)] if L+split+(I mod split) < H
// multiplex according to i[splitpos]
// why this works:
// assume I < H-L; then I < split+split
// case 0: i[splitpos] is 0
// then I < split
// so result0 = x[L+I], and we'll select result0
// case 1: i[splitpos] is 1
// then split <= I < H-L <= split+split
// so split+(I mod split) = I <= H-L
// so result1 = x[L+I], and we'll select result1
if (ibits <= splitpos) {
// different case: definitely want result from x[L:L+split], no multiplexing
return ram_read(x,L,L+split,i,ibits);
}
vector<bit> result0 = ram_read(x,L,L+split,i,splitpos);
vector<bit> result1 = ram_read(x,L+split,H,i,splitpos);
assert(result0.size() == result1.size());
bit isplit = i.at(splitpos);
vector<bit> result{};
for (bigint r = 0;r < result0.size();++r) {
bit x0 = result0.at(r);
bit x1 = result1.at(r);
result.push_back(isplit.mux(x0,x1));
}
return result;
}
const vector<bit> ram_read(
const vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i)
{
return ram_read(x, L, H, i, i.size());
}
const vector<bit> ram_read(
const vector<std::vector<bit>> &x,
const vector<bit> &i)
{
return ram_read(x, 0, x.size(), i);
}
// same as ram_read above but only returns the jth bit
const bit ram_read(
const vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits,
bigint j)
{
if (H <= L) return bit(0);
if (H == L+1) return x.at(L).at(j);
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
if (ibits <= splitpos) {
return ram_read(x,L,L+split,i,ibits,j);
}
bit result0 = ram_read(x,L,L+split,i,splitpos,j);
bit result1 = ram_read(x,L+split,H,i,splitpos,j);
bit isplit = i.at(splitpos);
bit result = isplit.mux(result0,result1);
return result;
}
// same as ram_read above but x is a vector of bits
const bit ram_read(
const vector<bit> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits)
{
if (H <= L) return bit(0);
if (H == L+1) return x.at(L);
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
if (ibits <= splitpos) {
return ram_read(x,L,L+split,i,ibits);
}
bit result0 = ram_read(x,L,L+split,i,splitpos);
bit result1 = ram_read(x,L+split,H,i,splitpos);
bit isplit = i.at(splitpos);
bit result = isplit.mux(result0,result1);
return result;
}
const bit ram_read(
const vector<bit> &x,
const vector<bit> &i)
{
return ram_read(x, 0, x.size(), i, i.size());
}
// input: vector x of equal-length bit vectors x[0],...,x[N-1]
// input: integer L between 0 and N-1
// input: integer H between L+1 and N
// input: bit vector i specifying I = i[0]+2*i[1]+...+2^(ibits-1)*i[ibits-1]
// input: bit vector data of length the same as any x[*]
// input: bit b with default value 0; only used for recursive calls.
// write data to x[L+J] for some J with 0 <= J < H-L
// satisfying J=I if 0 <= I < H-L; no constraints if I >= H-L
void ram_write(
vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits,
const vector<bit> &data,
bit b, bool top)
{
assert (x.at(0).size() == data.size());
if (H <= L) return;
if (H == L+1) {
if (top)
x.at(L) = data;
else
for (bigint r = 0;r < data.size();++r)
x.at(L).at(r) = b.mux(data.at(r), x.at(L).at(r));
return;
}
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
if (ibits <= splitpos) {
return ram_write(x,L,L+split,i,ibits,data,b,top);
}
bit isplit = i.at(splitpos);
ram_write(x,L,L+split,i,splitpos,data, top ? isplit : (b | isplit), 0);
ram_write(x,L+split,H,i,splitpos,data, top ? ~isplit : b.orn(isplit), 0);
}
void ram_write(
vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
const vector<bit> &data)
{
ram_write(x, L, H, i, i.size(), data);
}
void ram_write(
vector<std::vector<bit>> &x,
const vector<bit> &i,
const vector<bit> &data)
{
ram_write(x, 0, x.size(), i, data);
}
const vector<bit> ram_read_write(
vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits,
vector<bit> &data,
bit b, bool top)
{
assert (x.at(0).size() == data.size());
if (H <= L) return vector<bit>{};
if (H == L+1) {
vector<bit> v = x.at(L);
if (top)
x.at(L) = data;
else
bit_vector_mux(x.at(L), data, x.at(L), b);
return v;
}
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
if (ibits <= splitpos) {
return ram_read_write(x,L,L+split,i,ibits,data,b, top);
}
bit isplit = i.at(splitpos);
vector<bit> result0 = ram_read_write(x,L,L+split,i,splitpos,data, top ? isplit : (b | isplit), 0);
vector<bit> result1 = ram_read_write(x,L+split,H,i,splitpos,data, top ? ~isplit : b.orn(isplit), 0);
assert(result0.size() == result1.size());
vector<bit> result{};
for (bigint r = 0;r < result0.size();++r) {
bit x0 = result0.at(r);
bit x1 = result1.at(r);
result.push_back(isplit.mux(x0,x1));
}
return result;
}
const vector<bit> ram_read_write(
vector<std::vector<bit>> &x,
bigint L,
bigint H,
const vector<bit> &i,
vector<bit> &data)
{
return ram_read_write(x, L, H, i, i.size(), data);
}
const vector<bit> ram_read_write(
vector<std::vector<bit>> &x,
const vector<bit> &i,
vector<bit> &data)
{
return ram_read_write(x, 0, x.size(), i, data);
}
void ram_write(
vector<bit> &x,
bigint L,
bigint H,
const vector<bit> &i,
bigint ibits,
const bit data,
bit b,
bool top)
{
if (H <= L) return;
if (H == L+1)
{
x.at(L) = b.mux(data, x.at(L));
return;
}
bigint splitpos = 0;
bigint split = 1;
while (L+split < H-split) {
splitpos += 1;
split *= 2;
}
if (ibits <= splitpos) {
return ram_write(x,L,L+split,i,ibits,data,b,top);
}
bit isplit = i.at(splitpos);
ram_write(x,L,L+split,i,splitpos,data, top ? isplit : (b | isplit), 0);
ram_write(x,L+split,H,i,splitpos,data, top ? ~isplit : b.orn(isplit), 0);
return;
}
void ram_write(
vector<bit> &x,
bigint L,
bigint H,
const vector<bit> &i,
const bit data)
{
ram_write(x, L, H, i, i.size(), data);
}
void ram_write(
vector<bit> &x,
const vector<bit> &i,
const bit data)
{
ram_write(x, 0, x.size(), i, data);
}