diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 40dfc7fd..e9631826 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -12,11 +12,11 @@ trace!(RomRow, RomTrace { }); trace!(MemRow, MemTrace { - addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, same_value: F, first_addr_access_is_read: F, + addr: F, step: F, sel: F, wr: F, value: [F; 2], addr_changes: F, increment: F, same_value: F, first_addr_access_is_read: F, }); trace!(MemAlignRow, MemAlignTrace { - addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, sel_prove: F, + addr: F, offset: F, width: F, wr: F, pc: F, reset: F, sel_up_to_down: F, sel_down_to_up: F, reg: [F; 8], sel: [F; 8], step: F, delta_addr: F, sel_prove: F, value: [F; 2], }); trace!(MemAlignRomRow, MemAlignRomTrace { @@ -58,5 +58,3 @@ trace!(SpecifiedRangesRow, SpecifiedRangesTrace { trace!(U8AirRow, U8AirTrace { mul: F, }); - - diff --git a/state-machines/mem/pil/mem_align.pil b/state-machines/mem/pil/mem_align.pil index 0c3cce00..8a23ab2a 100644 --- a/state-machines/mem/pil/mem_align.pil +++ b/state-machines/mem/pil/mem_align.pil @@ -100,7 +100,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM col witness sel_down_to_up; // 1 if the next value is the previous value (e.g. W -> R) col witness reg[CHUNK_NUM]; // Register values, 1 byte each col witness sel[CHUNK_NUM]; // Selectors, 1 if the value is used, 0 otherwise - col witness step; // Step of memory + col witness step; // Memory step // 1] Ensure the MemAlign follows the program @@ -133,7 +133,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM } flags += wr * 2**CHUNK_NUM + reset * 2**(CHUNK_NUM + 1) + sel_up_to_down * 2**(CHUNK_NUM + 2) + sel_down_to_up * 2**(CHUNK_NUM + 3); - lookup_assumes(MEM_ALIGN_ROM_ID, [pc, pc'-pc, (addr-'addr)*(1-reset), offset, width, flags]); + // Perform the lookup against the program + expr delta_pc; + col witness delta_addr; // Auxiliary column + delta_pc = pc' - pc; + delta_addr === (addr - 'addr) * (1 - reset); + lookup_assumes(MEM_ALIGN_ROM_ID, [pc, delta_pc, delta_addr, offset, width, flags]); // 2] Assume aligned memory accesses against the Memory component const expr sel_assume = sel_up_to_down + sel_down_to_up; @@ -143,10 +148,12 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // On assume steps, we reconstruct the value from the registers directly expr assume_val[RC]; - for (int i = 0; i < RC; i++) { - assume_val[i] = 0; - for (int j = 0; j < CHUNKS_BY_RC; j++) { - assume_val[i] += reg[j + i * CHUNKS_BY_RC] * 2**(j*8); + for (int rc_index = 0; rc_index < RC; rc_index++) { + assume_val[rc_index] = 0; + int base = 1; + for (int _offset = 0; _offset < CHUNKS_BY_RC; _offset++) { + assume_val[rc_index] += reg[_offset + rc_index * CHUNKS_BY_RC] * base; + base *= 256; } } @@ -157,7 +164,7 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM // On prove steps, we reconstruct the value in the correct manner chosen by the selectors expr prove_val[RC]; - for (int rc_index = 0; rc_index < RC; ++rc_index) { + for (int rc_index = 0; rc_index < RC; rc_index++) { prove_val[rc_index] = 0; } for (int _offset = 0; _offset < CHUNK_NUM; _offset++) { @@ -166,13 +173,16 @@ airtemplate MemAlign(const int N = 2**10, const int RC = 2, const int CHUNK_NUM int base = 1; for (int ichunk = 0; ichunk < CHUNKS_BY_RC; ichunk++) { _tmp += reg[(_offset + rc_index * CHUNKS_BY_RC + ichunk) % CHUNK_NUM] * base; - base = base * 256; + base *= 256; } prove_val[rc_index] += sel[_offset] * _tmp; } } // We prove and assume with the same permutation check but with disjoint and different sign selectors - permutation_proves(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...prove_val], sel: sel_prove); - permutation_assumes(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...assume_val], sel: sel_assume); + col witness value[RC]; // Auxiliary columns + for (int i = 0; i < RC; i++) { + value[i] === sel_prove * prove_val[i] + sel_assume * assume_val[i]; + } + permutation(MEMORY_ID, cols: [wr * (MEMORY_STORE_OP - MEMORY_LOAD_OP) + MEMORY_LOAD_OP, addr * CHUNK_NUM + offset, step, width, ...value], sel: sel_prove - sel_assume); } \ No newline at end of file diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index 61170135..df6081e9 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -79,22 +79,22 @@ impl MemAlignRomSM { pub fn calculate_next_pc(&self, opcode: MemOp, offset: usize, width: usize) -> u64 { // Get the table offset - let (table_offset, one_word) = match opcode { - MemOp::OneRead => { - (1, true) - } - - MemOp::OneWrite => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true) - } - - MemOp::TwoReads => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], false) - } - - MemOp::TwoWrites => { - (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1] + TWO_WORD_COMBINATIONS * OP_SIZES[2], false) - } + let (table_offset, one_word) = match opcode { + MemOp::OneRead => (1, true), + + MemOp::OneWrite => (1 + ONE_WORD_COMBINATIONS * OP_SIZES[0], true), + + MemOp::TwoReads => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + ONE_WORD_COMBINATIONS * OP_SIZES[1], + false, + ), + + MemOp::TwoWrites => ( + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], + false, + ), }; // Get the first row index @@ -114,7 +114,13 @@ impl MemAlignRomSM { first_row_idx } - fn get_first_row_idx(opcode: MemOp, offset: usize, width: usize, table_offset: u64, one_word: bool) -> u64 { + fn get_first_row_idx( + opcode: MemOp, + offset: usize, + width: usize, + table_offset: u64, + one_word: bool, + ) -> u64 { let opcode_idx = opcode as usize; let op_size = OP_SIZES[opcode_idx]; @@ -203,11 +209,7 @@ impl MemAlignRomSM { } } - info!( - "{}: ··· Creating Mem Align ROM instance [{} rows filled 100%]", - Self::MY_NAME, - self.num_rows, - ); + info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,); let air_instance = AirInstance::new(sctx, ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0], None, prover_buffer); diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index eafbd80d..40e1e148 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -17,8 +17,12 @@ use zisk_pil::{MemAlignRow, MemAlignTrace, MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; use crate::{MemAlignInput, MemAlignRomSM, MemOp}; +const RC: usize = 2; const CHUNK_NUM: usize = 8; +const CHUNKS_BY_RC: usize = CHUNK_NUM / RC; const CHUNK_BITS: usize = 8; +const RC_BITS: u64 = (CHUNKS_BY_RC * CHUNK_BITS) as u64; +const RC_MASK: u64 = (1 << RC_BITS) - 1; const OFFSET_MASK: u32 = 0x07; const OFFSET_BITS: u32 = 3; const CHUNK_BITS_MASK: u64 = (1 << CHUNK_BITS) - 1; @@ -131,13 +135,6 @@ impl MemAlignSM { #[inline(always)] pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { - debug_assert!( - input.mem_values.len() == phase + 1, - "The number of mem_values {} is not equal to phase + 1 {}", - input.mem_values.len(), - phase + 1 - ); - let addr = input.address; let width = input.width; @@ -192,6 +189,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -204,6 +202,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -226,6 +225,15 @@ impl MemAlignSM { } } + let mut _value_read = value_read; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nOne Word Read\n\ @@ -309,6 +317,7 @@ impl MemAlignSM { let mut read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -321,6 +330,7 @@ impl MemAlignSM { let mut write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -333,6 +343,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), wr: F::from_bool(true), @@ -365,6 +376,18 @@ impl MemAlignSM { } } + let mut _value_read = value_read; + let mut _value_write = value_write; + let mut _value = value; + for i in 0..RC { + read_row.value[i] = F::from_canonical_u64(_value_read & RC_MASK); + write_row.value[i] = F::from_canonical_u64(_value_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + _value_read >>= RC_BITS; + _value_write >>= RC_BITS; + _value >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nOne Word Write\n\ @@ -454,6 +477,7 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -466,6 +490,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), // wr: F::from_bool(false), @@ -478,6 +503,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -508,6 +534,20 @@ impl MemAlignSM { } } + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nTwo Words Read\n\ @@ -677,6 +717,7 @@ impl MemAlignSM { let mut first_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -689,6 +730,7 @@ impl MemAlignSM { let mut first_write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -701,6 +743,7 @@ impl MemAlignSM { let mut value_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_usize(offset), width: F::from_canonical_usize(width), wr: F::from_bool(true), @@ -713,6 +756,7 @@ impl MemAlignSM { let mut second_write_row = MemAlignRow:: { step: F::from_canonical_u64(step + 1), addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), wr: F::from_bool(true), @@ -725,6 +769,7 @@ impl MemAlignSM { let mut second_read_row = MemAlignRow:: { step: F::from_canonical_u64(step), addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), offset: F::from_canonical_u64(DEFAULT_OFFSET), width: F::from_canonical_u64(DEFAULT_WIDTH), // wr: F::from_bool(false), @@ -777,6 +822,28 @@ impl MemAlignSM { } } + let mut _value_first_read = value_first_read; + let mut _value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = + F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = + F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = + F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + #[rustfmt::skip] debug_info!( "\nTwo Words Write\n\ @@ -945,7 +1012,7 @@ impl MemAlignSM { ); } - // Compute the padding multiplicity + // Compute the program multiplicity let mem_align_rom_sm = self.mem_align_rom_sm.clone(); mem_align_rom_sm.update_padding_row(padding_size as u64); @@ -953,8 +1020,8 @@ impl MemAlignSM { "{}: ··· Creating Mem Align instance [{} / {} rows filled {:.2}%]", Self::MY_NAME, rows_len, - air_mem_align.num_rows(), - rows_len as f64 / air_mem_align.num_rows() as f64 * 100.0 + air_mem_align_rows, + rows_len as f64 / air_mem_align_rows as f64 * 100.0 ); // Add a new Mem Align instance