Skip to content

Commit

Permalink
Mem align optimized and working
Browse files Browse the repository at this point in the history
  • Loading branch information
hecmas committed Nov 25, 2024
1 parent 42287a2 commit b5d00b1
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 46 deletions.
6 changes: 2 additions & 4 deletions pil/src/pil_helpers/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ trace!(RomRow, RomTrace<F> {
});

trace!(MemRow, MemTrace<F> {
addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F,
addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, increment: F, same_value: F, first_addr_access_is_read: F,
});

trace!(MemAlignRow, MemAlignTrace<F> {
addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, sel_prove: F,
addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, delta_addr: F, sel_prove: F, value: [F; 2],
});

trace!(MemAlignRomRow, MemAlignRomTrace<F> {
Expand Down Expand Up @@ -58,5 +58,3 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace<F> {
trace!(U8AirRow, U8AirTrace<F> {
mul: F,
});


30 changes: 20 additions & 10 deletions state-machines/mem/pil/mem_align.pil
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM
col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R)
col witness reg[CHUNK_NUM]; // Register values, 1 byte each
col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise
col witness step; // Step of memory
col witness step; // Memory step

// 1] Ensure the MemAlign follows the program

Expand Down Expand Up @@ -133,7 +133,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM
}
flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3);

lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]);
// Perform the lookup against the program
expr delta_pc;
col witness delta_addr; // Auxiliary column
delta_pc = pc' - pc;
delta_addr === (addr - 'addr) * (1 - reset);
lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]);

// 2] Assume aligned memory accesses against the Memory component
const expr sel_assume = sel_up_to_down + sel_down_to_up;
Expand All @@ -143,10 +148,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM

// On assume steps, we reconstruct the value from the registers directly
expr assume_val[RC];
for (int i = 0; i < RC; i++) {
assume_val[i] = 0;
for (int j = 0; j < CHUNKS_BY_RC; j++) {
assume_val[i] += reg[j + i * CHUNKS_BY_RC] * 2**(j*8);
for (int rc_index = 0; rc_index < RC; rc_index++) {
assume_val[rc_index] = 0;
int base = 1;
for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) {
assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base;
base *= 256;
}
}

Expand All @@ -157,7 +164,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM

// On prove steps, we reconstruct the value in the correct manner chosen by the selectors
expr prove_val[RC];
for (int rc_index = 0; rc_index < RC; ++rc_index) {
for (int rc_index = 0; rc_index < RC; rc_index++) {
prove_val[rc_index] = 0;
}
for (int _offset = 0; _offset < CHUNK_NUM; _offset++) {
Expand All @@ -166,13 +173,16 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM
int base = 1;
for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) {
_tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base;
base = base * 256;
base *= 256;
}
prove_val[rc_index] += sel[_offset] * _tmp;
}
}

// We prove and assume with the same permutation check but with disjoint and different sign selectors
permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove);
permutation_assumes(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...assume_val], sel: sel_assume);
col witness value[RC]; // Auxiliary columns
for (int i = 0; i < RC; i++) {
value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i];
}
permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume);
}
46 changes: 24 additions & 22 deletions state-machines/mem/src/mem_align_rom_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,22 @@ impl<F: PrimeField> MemAlignRomSM<F> {

pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 {
// Get the table offset
let (table_offset, one_word) = match opcode {
MemOp::OneRead => {
(1, true)
}

MemOp::OneWrite => {
(1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true)
}

MemOp::TwoReads => {
(1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], false)
}

MemOp::TwoWrites => {
(1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1] + TWO_WORD_COMBINATIONS * OP_SIZES[2], false)
}
let (table_offset, one_word) = match opcode {
MemOp::OneRead => (1, true),

MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true),

MemOp::TwoReads => (
1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1],
false,
),

MemOp::TwoWrites => (
1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] +
ONE_WORD_COMBINATIONS * OP_SIZES[1] +
TWO_WORD_COMBINATIONS * OP_SIZES[2],
false,
),
};

// Get the first row index
Expand All @@ -114,7 +114,13 @@ impl<F: PrimeField> MemAlignRomSM<F> {
first_row_idx
}

fn get_first_row_idx(opcode: MemOp, offset: usize, width: usize, table_offset: u64, one_word: bool) -> u64 {
fn get_first_row_idx(
opcode: MemOp,
offset: usize,
width: usize,
table_offset: u64,
one_word: bool,
) -> u64 {
let opcode_idx = opcode as usize;
let op_size = OP_SIZES[opcode_idx];

Expand Down Expand Up @@ -203,11 +209,7 @@ impl<F: PrimeField> MemAlignRomSM<F> {
}
}

info!(
"{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]",
Self::MY_NAME,
self.num_rows,
);
info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,);

let air_instance =
AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer);
Expand Down
87 changes: 77 additions & 10 deletions state-machines/mem/src/mem_align_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID};

use crate::{MemAlignInput, MemAlignRomSM, MemOp};

const RC: usize = 2;
const CHUNK_NUM: usize = 8;
const CHUNKS_BY_RC: usize = CHUNK_NUM / RC;
const CHUNK_BITS: usize = 8;
const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64;
const RC_MASK: u64 = (1 << RC_BITS) - 1;
const OFFSET_MASK: u32 = 0x07;
const OFFSET_BITS: u32 = 3;
const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1;
Expand Down Expand Up @@ -131,13 +135,6 @@ impl<F: PrimeField> MemAlignSM<F> {

#[inline(always)]
pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse {
debug_assert!(
input.mem_values.len() == phase + 1,
"The number of mem_values {} is not equal to phase + 1 {}",
input.mem_values.len(),
phase + 1
);

let addr = input.address;
let width = input.width;

Expand Down Expand Up @@ -192,6 +189,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_read),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand All @@ -204,6 +202,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut value_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_read),
// delta_addr: F::zero(),
offset: F::from_canonical_usize(offset),
width: F::from_canonical_usize(width),
// wr: F::from_bool(false),
Expand All @@ -226,6 +225,15 @@ impl<F: PrimeField> MemAlignSM<F> {
}
}

let mut _value_read = value_read;
let mut _value = value;
for i in 0..RC {
read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK);
value_row.value[i] = F::from_canonical_u64(_value & RC_MASK);
_value_read >>= RC_BITS;
_value >>= RC_BITS;
}

#[rustfmt::skip]
debug_info!(
"\nOne Word Read\n\
Expand Down Expand Up @@ -309,6 +317,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_read),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand All @@ -321,6 +330,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut write_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step + 1),
addr: F::from_canonical_u32(addr_read),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
wr: F::from_bool(true),
Expand All @@ -333,6 +343,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut value_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_read),
// delta_addr: F::zero(),
offset: F::from_canonical_usize(offset),
width: F::from_canonical_usize(width),
wr: F::from_bool(true),
Expand Down Expand Up @@ -365,6 +376,18 @@ impl<F: PrimeField> MemAlignSM<F> {
}
}

let mut _value_read = value_read;
let mut _value_write = value_write;
let mut _value = value;
for i in 0..RC {
read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK);
write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK);
value_row.value[i] = F::from_canonical_u64(_value & RC_MASK);
_value_read >>= RC_BITS;
_value_write >>= RC_BITS;
_value >>= RC_BITS;
}

#[rustfmt::skip]
debug_info!(
"\nOne Word Write\n\
Expand Down Expand Up @@ -454,6 +477,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut first_read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_first_read),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand All @@ -466,6 +490,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut value_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_first_read),
// delta_addr: F::zero(),
offset: F::from_canonical_usize(offset),
width: F::from_canonical_usize(width),
// wr: F::from_bool(false),
Expand All @@ -478,6 +503,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut second_read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_second_read),
delta_addr: F::one(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand Down Expand Up @@ -508,6 +534,20 @@ impl<F: PrimeField> MemAlignSM<F> {
}
}

let mut _value_first_read = value_first_read;
let mut _value = value;
let mut _value_second_read = value_second_read;
for i in 0..RC {
first_read_row.value[i] =
F::from_canonical_u64(_value_first_read & RC_MASK);
value_row.value[i] = F::from_canonical_u64(_value & RC_MASK);
second_read_row.value[i] =
F::from_canonical_u64(_value_second_read & RC_MASK);
_value_first_read >>= RC_BITS;
_value >>= RC_BITS;
_value_second_read >>= RC_BITS;
}

#[rustfmt::skip]
debug_info!(
"\nTwo Words Read\n\
Expand Down Expand Up @@ -677,6 +717,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut first_read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_first_read_write),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand All @@ -689,6 +730,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut first_write_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step + 1),
addr: F::from_canonical_u32(addr_first_read_write),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
wr: F::from_bool(true),
Expand All @@ -701,6 +743,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut value_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_first_read_write),
// delta_addr: F::zero(),
offset: F::from_canonical_usize(offset),
width: F::from_canonical_usize(width),
wr: F::from_bool(true),
Expand All @@ -713,6 +756,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut second_write_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step + 1),
addr: F::from_canonical_u32(addr_second_read_write),
delta_addr: F::one(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
wr: F::from_bool(true),
Expand All @@ -725,6 +769,7 @@ impl<F: PrimeField> MemAlignSM<F> {
let mut second_read_row = MemAlignRow::<F> {
step: F::from_canonical_u64(step),
addr: F::from_canonical_u32(addr_second_read_write),
// delta_addr: F::zero(),
offset: F::from_canonical_u64(DEFAULT_OFFSET),
width: F::from_canonical_u64(DEFAULT_WIDTH),
// wr: F::from_bool(false),
Expand Down Expand Up @@ -777,6 +822,28 @@ impl<F: PrimeField> MemAlignSM<F> {
}
}

let mut _value_first_read = value_first_read;
let mut _value_first_write = value_first_write;
let mut _value = value;
let mut _value_second_write = value_second_write;
let mut _value_second_read = value_second_read;
for i in 0..RC {
first_read_row.value[i] =
F::from_canonical_u64(_value_first_read & RC_MASK);
first_write_row.value[i] =
F::from_canonical_u64(_value_first_write & RC_MASK);
value_row.value[i] = F::from_canonical_u64(_value & RC_MASK);
second_write_row.value[i] =
F::from_canonical_u64(_value_second_write & RC_MASK);
second_read_row.value[i] =
F::from_canonical_u64(_value_second_read & RC_MASK);
_value_first_read >>= RC_BITS;
_value_first_write >>= RC_BITS;
_value >>= RC_BITS;
_value_second_write >>= RC_BITS;
_value_second_read >>= RC_BITS;
}

#[rustfmt::skip]
debug_info!(
"\nTwo Words Write\n\
Expand Down Expand Up @@ -945,16 +1012,16 @@ impl<F: PrimeField> MemAlignSM<F> {
);
}

// Compute the padding multiplicity
// Compute the program multiplicity
let mem_align_rom_sm = self.mem_align_rom_sm.clone();
mem_align_rom_sm.update_padding_row(padding_size as u64);

info!(
"{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]",
Self::MY_NAME,
rows_len,
air_mem_align.num_rows(),
rows_len as f64 / air_mem_align.num_rows() as f64 * 100.0
air_mem_align_rows,
rows_len as f64 / air_mem_align_rows as f64 * 100.0
);

// Add a new Mem Align instance
Expand Down

0 comments on commit b5d00b1

Please sign in to comment.