From e0c0b38d1275845976f80215cfe5de09609fa144 Mon Sep 17 00:00:00 2001 From: fractasy Date: Wed, 18 Dec 2024 19:22:59 +0000 Subject: [PATCH 01/10] Collect and consume mem reads in store c. Send memory data bus operation payloads. --- emulator/src/emu.rs | 680 ++++++++++++++++++++++++++++++++-- pil/src/pil_helpers/traces.rs | 14 +- 2 files changed, 649 insertions(+), 45 deletions(-) diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 002a92fa..99c51ec5 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -15,6 +15,59 @@ use zisk_core::{ SRC_IND, SRC_MEM, SRC_STEP, STORE_IND, STORE_MEM, STORE_NONE, }; +struct MemBusHelpers {} + +const MEMORY_LOAD_OP: u64 = 1; +const MEMORY_STORE_OP: u64 = 2; +const MEM_STEP_BASE: u64 = 3; +const MAX_MEM_OPS_BY_MAIN_STEP: u64 = 4; +const MAX_MEM_OPS_BY_STEP_OFFSET: u64 = 5; + +impl MemBusHelpers { + // function mem_load(expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { + // function mem_store(expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) + // { + pub fn mem_load( + addr: u32, + step: u64, + step_offset: u8, + bytes: u8, + mem_values: [u64; 2], + ) -> [u64; 7] { + [ + MEMORY_LOAD_OP, + addr as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + bytes as u64, + mem_values[0], + mem_values[1], + 0, + ] + } + pub fn mem_write( + addr: u32, + step: u64, + step_offset: u8, + bytes: u8, + value: u64, + mem_values: [u64; 2], + ) -> [u64; 7] { + [ + MEMORY_STORE_OP, + addr as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + bytes as u64, + mem_values[0], + mem_values[1], + value, + ] + } +} + /// ZisK emulator structure, containing the ZisK rom, the list of ZisK operations, and the /// execution context pub struct Emu<'a> { @@ -113,7 +166,6 @@ impl<'a> Emu<'a> { &mut self, instruction: &ZiskInst, mem_reads: &mut Vec, - _mem_reads_index: &mut usize, ) { match instruction.a_src { SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, @@ -226,6 +278,86 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'a' register value based on the source specified by the current instruction, + /// using formerly generated memory reads from a previous emulation + #[inline(always)] + pub fn source_a_mem_reads_consume_databus>( + &mut self, + instruction: &ZiskInst, + mem_reads: &[u64], + mem_reads_index: &mut usize, + data_bus: &mut DataBus, + ) { + match instruction.a_src { + SRC_C => self.ctx.inst_ctx.a = self.ctx.inst_ctx.c, + SRC_MEM => { + // Calculate memory address + let mut address = instruction.a_offset_imm0; + if instruction.a_use_sp_imm1 != 0 { + address += self.ctx.inst_ctx.sp; + } + + // If the operation is a register operation, get it from the context registers + if Mem::address_is_register(address) { + self.ctx.inst_ctx.a = self.get_reg(Mem::address_to_register_index(address)); + } + // Otherwise, get it from memory + else if Mem::is_full_aligned(address, 8) { + assert!(*mem_reads_index < mem_reads.len()); + self.ctx.inst_ctx.a = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 0, + 8, + [self.ctx.inst_ctx.a, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, 8); + debug_assert!(required_address_1 != required_address_2); + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_1 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_2 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + self.ctx.inst_ctx.a = + Mem::get_double_not_aligned_data(address, 8, raw_data_1, raw_data_2); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 0, + 8, + [raw_data_1, raw_data_2], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + /*println!( + "Emu::source_a_mem_reads_consume() mem_leads_index={} value={:x}", + *mem_reads_index, self.ctx.inst_ctx.a + );*/ + + // Feed the stats + if self.ctx.do_stats { + self.ctx.stats.on_memory_read(address, 8); + } + } + SRC_IMM => { + self.ctx.inst_ctx.a = instruction.a_offset_imm0 | (instruction.a_use_sp_imm1 << 32) + } + SRC_STEP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.step, + // #[cfg(feature = "sp")] + // SRC_SP => self.ctx.inst_ctx.a = self.ctx.inst_ctx.sp, + _ => panic!( + "Emu::source_a_mem_reads_consume_databus() Invalid a_src={} pc={}", + instruction.a_src, self.ctx.inst_ctx.pc + ), + } + } + /// Calculate the 'b' register value based on the source specified by the current instruction #[inline(always)] pub fn source_b(&mut self, instruction: &ZiskInst) { @@ -288,7 +420,6 @@ impl<'a> Emu<'a> { &mut self, instruction: &ZiskInst, mem_reads: &mut Vec, - _mem_reads_index: &mut usize, ) { match instruction.b_src { SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, @@ -337,6 +468,7 @@ impl<'a> Emu<'a> { // If the operation is a register operation, get it from the context registers if Mem::address_is_register(address) { + assert!(instruction.ind_width == 8); self.ctx.inst_ctx.b = self.get_reg(Mem::address_to_register_index(address)); } // Otherwise, get it from memory @@ -437,6 +569,7 @@ impl<'a> Emu<'a> { // If the operation is a register operation, get it from the context registers if Mem::address_is_register(address) { + assert!(instruction.ind_width == 8); self.ctx.inst_ctx.b = self.get_reg(Mem::address_to_register_index(address)); } // Otherwise, get it from memory @@ -487,6 +620,191 @@ impl<'a> Emu<'a> { } } + /// Calculate the 'b' register value based on the source specified by the current instruction, + /// using formerly generated memory reads from a previous emulation + #[inline(always)] + pub fn source_b_mem_reads_consume_databus>( + &mut self, + instruction: &ZiskInst, + mem_reads: &[u64], + mem_reads_index: &mut usize, + data_bus: &mut DataBus, + ) { + match instruction.b_src { + SRC_C => self.ctx.inst_ctx.b = self.ctx.inst_ctx.c, + SRC_MEM => { + // Calculate memory address + let mut address = instruction.b_offset_imm0; + if instruction.b_use_sp_imm1 != 0 { + address += self.ctx.inst_ctx.sp; + } + + // If the operation is a register operation, get it from the context registers + if Mem::address_is_register(address) { + self.ctx.inst_ctx.b = self.get_reg(Mem::address_to_register_index(address)); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [self.ctx.inst_ctx.b, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + // Otherwise, get it from memory + else if Mem::is_full_aligned(address, 8) { + assert!(*mem_reads_index < mem_reads.len()); + self.ctx.inst_ctx.b = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [self.ctx.inst_ctx.b, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, 8); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + self.ctx.inst_ctx.b = + Mem::get_single_not_aligned_data(address, 8, raw_data); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [self.ctx.inst_ctx.b, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_1 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_2 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + self.ctx.inst_ctx.b = + Mem::get_double_not_aligned_data(address, 8, raw_data_1, raw_data_2); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [raw_data_1, raw_data_2], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + } + /*println!( + "Emu::source_b_mem_reads_consume() mem_leads_index={} value={:x}", + *mem_reads_index, self.ctx.inst_ctx.b + );*/ + + if self.ctx.do_stats { + self.ctx.stats.on_memory_read(address, 8); + } + } + SRC_IMM => { + self.ctx.inst_ctx.b = instruction.b_offset_imm0 | (instruction.b_use_sp_imm1 << 32) + } + SRC_IND => { + // Calculate memory address + let mut address = + (self.ctx.inst_ctx.a as i64 + instruction.b_offset_imm0 as i64) as u64; + if instruction.b_use_sp_imm1 != 0 { + address += self.ctx.inst_ctx.sp; + } + + // If the operation is a register operation, get it from the context registers + if Mem::address_is_register(address) { + assert!(instruction.ind_width == 8); + self.ctx.inst_ctx.b = self.get_reg(Mem::address_to_register_index(address)); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [self.ctx.inst_ctx.b, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + // Otherwise, get it from memory + else if Mem::is_full_aligned(address, instruction.ind_width) { + assert!(*mem_reads_index < mem_reads.len()); + self.ctx.inst_ctx.b = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [self.ctx.inst_ctx.b, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, instruction.ind_width); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + self.ctx.inst_ctx.b = Mem::get_single_not_aligned_data( + address, + instruction.ind_width, + raw_data, + ); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + instruction.ind_width as u8, + [raw_data, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_1 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_2 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + self.ctx.inst_ctx.b = Mem::get_double_not_aligned_data( + address, + instruction.ind_width, + raw_data_1, + raw_data_2, + ); + let payload = MemBusHelpers::mem_load( + address as u32, + self.ctx.inst_ctx.step, + 1, + 8, + [raw_data_1, raw_data_2], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + } + /*println!( + "Emu::source_b_mem_reads_consume() mem_leads_index={} value={:x}", + *mem_reads_index, self.ctx.inst_ctx.b + );*/ + + if self.ctx.do_stats { + self.ctx.stats.on_memory_read(address, instruction.ind_width); + } + } + _ => panic!( + "Emu::source_b_mem_reads_consume_databus() Invalid b_src={} pc={}", + instruction.b_src, self.ctx.inst_ctx.pc + ), + } + } + /// Store the 'c' register value based on the storage specified by the current instruction #[inline(always)] pub fn store_c(&mut self, instruction: &ZiskInst) { @@ -561,56 +879,332 @@ impl<'a> Emu<'a> { /// Store the 'c' register value based on the storage specified by the current instruction and /// log memory access if required #[inline(always)] - pub fn store_c_slice(&mut self, instruction: &ZiskInst) { + pub fn store_c_mem_reads_generate(&mut self, instruction: &ZiskInst, mem_reads: &mut Vec) { match instruction.store { STORE_NONE => {} STORE_MEM => { // Calculate the value - let val: i64 = if instruction.store_ra { + let value: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; - let val = val as u64; + let value: u64 = value as u64; // Calculate the memory address - let mut addr: i64 = instruction.store_offset; + let mut address: i64 = instruction.store_offset; if instruction.store_use_sp { - addr += self.ctx.inst_ctx.sp as i64; + address += self.ctx.inst_ctx.sp as i64; } - debug_assert!(addr >= 0); - let addr = addr as u64; + debug_assert!(address >= 0); + let address = address as u64; // If the operation is a register operation, write it to the context registers - if Mem::address_is_register(addr) { - self.set_reg(Mem::address_to_register_index(addr), val); + if Mem::address_is_register(address) { + self.set_reg(Mem::address_to_register_index(address), value); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if Mem::is_full_aligned(address, 8) { + self.ctx.inst_ctx.mem.write(address, value, 8); + } else { + let mut additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(address, 8); + debug_assert!(!additional_data.is_empty()); + mem_reads.append(&mut additional_data); + + self.ctx.inst_ctx.mem.write(address, value, 8); } } STORE_IND => { // Calculate the value - let val: i64 = if instruction.store_ra { + let value: i64 = if instruction.store_ra { self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 } else { self.ctx.inst_ctx.c as i64 }; - let val = val as u64; + let value = value as u64; // Calculate the memory address - let mut addr = instruction.store_offset; + let mut address = instruction.store_offset; if instruction.store_use_sp { - addr += self.ctx.inst_ctx.sp as i64; + address += self.ctx.inst_ctx.sp as i64; } - addr += self.ctx.inst_ctx.a as i64; - debug_assert!(addr >= 0); - let addr = addr as u64; + address += self.ctx.inst_ctx.a as i64; + debug_assert!(address >= 0); + let address = address as u64; // If the operation is a register operation, write it to the context registers - if (instruction.ind_width == 8) && Mem::address_is_register(addr) { - self.set_reg(Mem::address_to_register_index(addr), val); + if Mem::address_is_register(address) { + assert!(instruction.ind_width == 8); + self.set_reg(Mem::address_to_register_index(address), value); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if Mem::is_full_aligned(address, instruction.ind_width) { + self.ctx.inst_ctx.mem.write(address, value, instruction.ind_width); + } else { + let mut additional_data: Vec; + (self.ctx.inst_ctx.b, additional_data) = + self.ctx.inst_ctx.mem.read_required(address, instruction.ind_width); + debug_assert!(!additional_data.is_empty()); + mem_reads.append(&mut additional_data); + + self.ctx.inst_ctx.mem.write(address, value, instruction.ind_width); + } + } + _ => panic!( + "Emu::store_c_mem_reads_generate() Invalid store={} pc={}", + instruction.store, self.ctx.inst_ctx.pc + ), + } + } + + /// Store the 'c' register value based on the storage specified by the current instruction and + /// log memory access if required + #[inline(always)] + pub fn store_c_mem_reads_consume( + &mut self, + instruction: &ZiskInst, + mem_reads: &[u64], + mem_reads_index: &mut usize, + ) { + match instruction.store { + STORE_NONE => {} + STORE_MEM => { + // Calculate the memory address + let mut address: i64 = instruction.store_offset; + if instruction.store_use_sp { + address += self.ctx.inst_ctx.sp as i64; + } + debug_assert!(address >= 0); + let address = address as u64; + + // If the operation is a register operation, write it to the context registers + if Mem::address_is_register(address) { + // Calculate the value + let value: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let value = value as u64; + + self.set_reg(Mem::address_to_register_index(address), value); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if !Mem::is_full_aligned(address, 8) { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, 8); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + } else { + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + } + } + } + STORE_IND => { + // Calculate the memory address + let mut address = instruction.store_offset; + if instruction.store_use_sp { + address += self.ctx.inst_ctx.sp as i64; + } + address += self.ctx.inst_ctx.a as i64; + debug_assert!(address >= 0); + let address = address as u64; + + // If the operation is a register operation, write it to the context registers + if Mem::address_is_register(address) { + // Calculate the value + let value: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let value = value as u64; + + assert!(instruction.ind_width == 8); + self.set_reg(Mem::address_to_register_index(address), value); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if !Mem::is_full_aligned(address, instruction.ind_width) { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, instruction.ind_width); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + } else { + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + *mem_reads_index += 1; + } + } + } + _ => panic!( + "Emu::store_c_mem_reads_consume() Invalid store={} pc={}", + instruction.store, self.ctx.inst_ctx.pc + ), + } + } + + /// Store the 'c' register value based on the storage specified by the current instruction and + /// log memory access if required + #[inline(always)] + pub fn store_c_mem_reads_consume_databus>( + &mut self, + instruction: &ZiskInst, + mem_reads: &[u64], + mem_reads_index: &mut usize, + data_bus: &mut DataBus, + ) { + match instruction.store { + STORE_NONE => {} + STORE_MEM => { + // Calculate the value + let value: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let value = value as u64; + + // Calculate the memory address + let mut address: i64 = instruction.store_offset; + if instruction.store_use_sp { + address += self.ctx.inst_ctx.sp as i64; + } + debug_assert!(address >= 0); + let address = address as u64; + + // If the operation is a register operation, write it to the context registers + if Mem::address_is_register(address) { + self.set_reg(Mem::address_to_register_index(address), value); + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + 8, + value, + [value, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if !Mem::is_full_aligned(address, 8) { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, 8); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + 8, + value, + [raw_data, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_1 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_2 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + 8, + value, + [raw_data_1, raw_data_2], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + } + } + STORE_IND => { + // Calculate the value + let value: i64 = if instruction.store_ra { + self.ctx.inst_ctx.pc as i64 + instruction.jmp_offset2 + } else { + self.ctx.inst_ctx.c as i64 + }; + let value = value as u64; + + // Calculate the memory address + let mut address = instruction.store_offset; + if instruction.store_use_sp { + address += self.ctx.inst_ctx.sp as i64; + } + address += self.ctx.inst_ctx.a as i64; + debug_assert!(address >= 0); + let address = address as u64; + + // If the operation is a register operation, write it to the context registers + if Mem::address_is_register(address) { + assert!(instruction.ind_width == 8); + self.set_reg(Mem::address_to_register_index(address), value); + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + 8, + value, + [value, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } + // Otherwise, if not aligned, get old raw data from memory, then write it + else if !Mem::is_full_aligned(address, instruction.ind_width) { + let (required_address_1, required_address_2) = + Mem::required_addresses(address, instruction.ind_width); + if required_address_1 == required_address_2 { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + instruction.ind_width as u8, + value, + [raw_data, 0], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } else { + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_1 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + assert!(*mem_reads_index < mem_reads.len()); + let raw_data_2 = mem_reads[*mem_reads_index]; + *mem_reads_index += 1; + + let payload = MemBusHelpers::mem_write( + address as u32, + self.ctx.inst_ctx.step, + 2, + instruction.ind_width as u8, + value, + [raw_data_1, raw_data_2], + ); + data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + } } } _ => panic!( - "Emu::store_c_slice() Invalid store={} pc={}", + "Emu::store_c_mem_reads_consume_databus() Invalid store={} pc={}", instruction.store, self.ctx.inst_ctx.pc ), } @@ -929,37 +1523,28 @@ impl<'a> Emu<'a> { /// Performs one single step of the emulation #[inline(always)] pub fn par_step(&mut self, emu_full_trace_vec: &mut EmuTrace) { - let mut mem_reads_index = emu_full_trace_vec.steps.mem_reads.len(); emu_full_trace_vec.last_state = EmuTraceStart { pc: self.ctx.inst_ctx.pc, sp: self.ctx.inst_ctx.sp, c: self.ctx.inst_ctx.c, step: self.ctx.inst_ctx.step, regs: self.ctx.inst_ctx.regs, - mem_reads_index, + mem_reads_index: emu_full_trace_vec.steps.mem_reads.len(), }; let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); // Build the 'a' register value based on the source specified by the current instruction - self.source_a_mem_reads_generate( - instruction, - &mut emu_full_trace_vec.steps.mem_reads, - &mut mem_reads_index, - ); + self.source_a_mem_reads_generate(instruction, &mut emu_full_trace_vec.steps.mem_reads); // Build the 'b' register value based on the source specified by the current instruction - self.source_b_mem_reads_generate( - instruction, - &mut emu_full_trace_vec.steps.mem_reads, - &mut mem_reads_index, - ); + self.source_b_mem_reads_generate(instruction, &mut emu_full_trace_vec.steps.mem_reads); // Call the operation (instruction.func)(&mut self.ctx.inst_ctx); // Store the 'c' register value based on the storage specified by the current instruction - self.store_c(instruction); + self.store_c_mem_reads_generate(instruction, &mut emu_full_trace_vec.steps.mem_reads); // Set SP, if specified by the current instruction // #[cfg(feature = "sp")] @@ -1019,7 +1604,7 @@ impl<'a> Emu<'a> { self.source_a_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); self.source_b_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + self.store_c_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); let finished = inst_observer.on_instruction(instruction, &self.ctx.inst_ctx); @@ -1043,10 +1628,25 @@ impl<'a> Emu<'a> { data_bus: &mut DataBus, ) -> bool { let instruction = self.rom.get_instruction(self.ctx.inst_ctx.pc); - self.source_a_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); - self.source_b_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); + self.source_a_mem_reads_consume_databus( + instruction, + &trace_step.mem_reads, + mem_reads_index, + data_bus, + ); + self.source_b_mem_reads_consume_databus( + instruction, + &trace_step.mem_reads, + mem_reads_index, + data_bus, + ); (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + self.store_c_mem_reads_consume_databus( + instruction, + &trace_step.mem_reads, + mem_reads_index, + data_bus, + ); let operation_payload = OperationBusData::new_payload(instruction, &self.ctx.inst_ctx); data_bus.write_to_bus(OPERATION_BUS_ID, operation_payload.to_vec()); @@ -1141,7 +1741,7 @@ impl<'a> Emu<'a> { self.source_a_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); self.source_b_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + self.store_c_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); let finished = inst_observer.on_instruction(instruction, &self.ctx.inst_ctx); @@ -1169,7 +1769,7 @@ impl<'a> Emu<'a> { self.source_a_mem_reads_consume(instruction, &trace_step.mem_reads, trace_step_index); self.source_b_mem_reads_consume(instruction, &trace_step.mem_reads, trace_step_index); (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + self.store_c_mem_reads_consume(instruction, &trace_step.mem_reads, trace_step_index); // #[cfg(feature = "sp")] // self.set_sp(instruction); self.set_pc(instruction); diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index 8e05774e..d149312e 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -1,7 +1,8 @@ // WARNING: This file has been autogenerated from the PILOUT file. // Manual modifications are not recommended and may be overwritten. use proofman_common as common; -pub use proofman_macros::{trace, values}; +pub use proofman_macros::trace; +pub use proofman_macros::values; #[allow(dead_code)] type FieldExtension = [F; 3]; @@ -48,22 +49,25 @@ pub const U_8_AIR_AIR_IDS: &[usize] = &[15]; pub const U_16_AIR_AIR_IDS: &[usize] = &[16]; + //PUBLICS -use serde::{Deserialize, Serialize}; +use serde::Deserialize; +use serde::Serialize; #[derive(Default, Debug, Serialize, Deserialize)] pub struct ZiskPublics { - #[serde(default)] + #[serde(default)] pub rom_root: [u64; 4], + } values!(ZiskPublicValues { rom_root: [F; 4], }); - + values!(ZiskProofValues { enable_input_data: FieldExtension, }); - + trace!(MainTrace { a: [F; 2], b: [F; 2], c: [F; 2], flag: F, pc: F, a_src_imm: F, a_src_mem: F, a_offset_imm0: F, a_imm1: F, a_src_step: F, b_src_imm: F, b_src_mem: F, b_offset_imm0: F, b_imm1: F, b_src_ind: F, ind_width: F, is_external_op: F, op: F, store_ra: F, store_mem: F, store_ind: F, store_offset: F, set_pc: F, jmp_offset1: F, jmp_offset2: F, m32: F, addr1: F, __debug_operation_bus_enabled: F, }, 0, 0, 2097152 ); From 5b1179db6d5b7674334b0ec2582b71f393a38df8 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Wed, 18 Dec 2024 11:20:35 +0000 Subject: [PATCH 02/10] add mem-helpers --- state-machines/mem/Cargo.toml | 2 +- state-machines/mem/src/lib.rs | 2 + state-machines/mem/src/mem_bus_helpers.rs | 51 +++++++++++++++++++++++ state-machines/mem/src/mem_constants.rs | 3 ++ 4 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 state-machines/mem/src/mem_bus_helpers.rs diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index b914c5da..dba8b5c8 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -24,4 +24,4 @@ num-bigint = { workspace = true } default = [] no_lib_link = ["proofman-common/no_lib_link"] debug_mem_proxy_engine = [] -debug_mem_align = [] \ No newline at end of file +debug_mem_align = [] diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 3c42869b..b69eea43 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,6 +1,7 @@ mod input_data_sm; mod mem_align_rom_sm; mod mem_align_sm; +mod mem_bus_helpers; mod mem_constants; mod mem_helpers; mod mem_module; @@ -13,6 +14,7 @@ mod rom_data; pub use input_data_sm::*; pub use mem_align_rom_sm::*; pub use mem_align_sm::*; +pub use mem_bus_helpers::*; pub use mem_constants::*; pub use mem_helpers::*; pub use mem_module::*; diff --git a/state-machines/mem/src/mem_bus_helpers.rs b/state-machines/mem/src/mem_bus_helpers.rs new file mode 100644 index 00000000..37adfe8e --- /dev/null +++ b/state-machines/mem/src/mem_bus_helpers.rs @@ -0,0 +1,51 @@ +use crate::{ + MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_OPS_BY_STEP_OFFSET, MEMORY_LOAD_OP, MEMORY_STORE_OP, + MEM_STEP_BASE, +}; + +pub struct MemBusHelpers {} + +impl MemBusHelpers { + // function mem_load(expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) { + // function mem_store(expr addr, expr step, expr step_offset = 0, expr bytes = 8, expr value[]) + // { + pub fn mem_load( + addr: u32, + step: u64, + step_offset: u8, + bytes: u8, + mem_values: [u64; 2], + ) -> [u64; 7] { + [ + MEMORY_LOAD_OP, + addr as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + bytes as u64, + mem_values[0] as u64, + mem_values[1] as u64, + 0, + ] + } + pub fn mem_write( + addr: u32, + step: u64, + step_offset: u8, + bytes: u8, + value: u64, + mem_values: [u64; 2], + ) -> [u64; 7] { + [ + MEMORY_STORE_OP, + addr as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + bytes as u64, + mem_values[0] as u64, + mem_values[1] as u64, + value, + ] + } +} diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 9165edd1..39e885e6 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -7,6 +7,9 @@ pub const MAX_MEM_STEP_OFFSET: u64 = 2; pub const MAX_MEM_OPS_BY_STEP_OFFSET: u64 = 2; pub const MAX_MEM_OPS_BY_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * MAX_MEM_OPS_BY_STEP_OFFSET; +pub const MEMORY_LOAD_OP: u64 = 1; +pub const MEMORY_STORE_OP: u64 = 2; + pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + From 88b7944927cbb7eaaf748b67ca0d5e098a84c847 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 19 Dec 2024 21:50:22 +0000 Subject: [PATCH 03/10] adding mem changes with architecture changes --- Cargo.lock | 1 + state-machines/mem/Cargo.toml | 1 + state-machines/mem/src/lib.rs | 2 + state-machines/mem/src/mem_bus_helpers.rs | 4 +- state-machines/mem/src/mem_constants.rs | 9 +- state-machines/mem/src/mem_counters.rs | 123 ++++++++++++++++++++++ state-machines/mem/src/mem_proxy.rs | 20 +++- 7 files changed, 155 insertions(+), 5 deletions(-) create mode 100644 state-machines/mem/src/mem_counters.rs diff --git a/Cargo.lock b/Cargo.lock index 8813a640..ef8cf224 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2240,6 +2240,7 @@ dependencies = [ "rayon", "sm-common", "witness", + "zisk-common", "zisk-core", "zisk-pil", ] diff --git a/state-machines/mem/Cargo.toml b/state-machines/mem/Cargo.toml index dba8b5c8..1f28f4a7 100644 --- a/state-machines/mem/Cargo.toml +++ b/state-machines/mem/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] sm-common = { path = "../common" } zisk-core = { path = "../../core" } +zisk-common = { path = "../../common" } zisk-pil = { path = "../../pil" } num-traits = "0.2" diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index b69eea43..05491fbc 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -3,6 +3,7 @@ mod mem_align_rom_sm; mod mem_align_sm; mod mem_bus_helpers; mod mem_constants; +mod mem_counters; mod mem_helpers; mod mem_module; mod mem_proxy; @@ -16,6 +17,7 @@ pub use mem_align_rom_sm::*; pub use mem_align_sm::*; pub use mem_bus_helpers::*; pub use mem_constants::*; +pub use mem_counters::*; pub use mem_helpers::*; pub use mem_module::*; pub use mem_proxy::*; diff --git a/state-machines/mem/src/mem_bus_helpers.rs b/state-machines/mem/src/mem_bus_helpers.rs index 37adfe8e..3394c495 100644 --- a/state-machines/mem/src/mem_bus_helpers.rs +++ b/state-machines/mem/src/mem_bus_helpers.rs @@ -17,7 +17,7 @@ impl MemBusHelpers { mem_values: [u64; 2], ) -> [u64; 7] { [ - MEMORY_LOAD_OP, + MEMORY_LOAD_OP as u64, addr as u64, MEM_STEP_BASE + MAX_MEM_OPS_BY_MAIN_STEP * step + @@ -37,7 +37,7 @@ impl MemBusHelpers { mem_values: [u64; 2], ) -> [u64; 7] { [ - MEMORY_STORE_OP, + MEMORY_STORE_OP as u64, addr as u64, MEM_STEP_BASE + MAX_MEM_OPS_BY_MAIN_STEP * step + diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 39e885e6..67f6f391 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -7,8 +7,13 @@ pub const MAX_MEM_STEP_OFFSET: u64 = 2; pub const MAX_MEM_OPS_BY_STEP_OFFSET: u64 = 2; pub const MAX_MEM_OPS_BY_MAIN_STEP: u64 = (MAX_MEM_STEP_OFFSET + 1) * MAX_MEM_OPS_BY_STEP_OFFSET; -pub const MEMORY_LOAD_OP: u64 = 1; -pub const MEMORY_STORE_OP: u64 = 2; +pub const MEMORY_LOAD_OP: u8 = 1; +pub const MEMORY_STORE_OP: u8 = 2; + +pub const MEM_REGS_MASK: u32 = 0xFFFF_FF00; +pub const MEM_REGS_ADDR: u32 = 0xA000_0000; + +pub const MEM_BUS_ID: u16 = 1000; pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + diff --git a/state-machines/mem/src/mem_counters.rs b/state-machines/mem/src/mem_counters.rs new file mode 100644 index 00000000..81954a37 --- /dev/null +++ b/state-machines/mem/src/mem_counters.rs @@ -0,0 +1,123 @@ +use std::collections::HashMap; + +use sm_common::Metrics; +use zisk_common::{BusDevice, BusId}; +use zisk_core::ZiskOperationType; + +use crate::{ + MEMORY_MAX_DIFF, MEMORY_STORE_OP, MEM_BUS_ID, MEM_BYTES_BITS, MEM_REGS_ADDR, MEM_REGS_MASK, +}; + +#[derive(Debug, Clone, Copy, Default)] +pub struct UsesCounter { + pub first_step: u64, + pub last_step: u64, + pub count: u64, +} + +pub struct MemCounters { + registers: [UsesCounter; 32], + addr: HashMap, + mem_align: Vec, + mem_align_rows: u32, +} + +impl MemCounters { + pub fn new() -> Self { + let empty_counter = UsesCounter::default(); + Self { + registers: [empty_counter; 32], + addr: HashMap::new(), + mem_align: Vec::new(), + mem_align_rows: 0, + } + } + pub fn count_extra_internal_reads(previous_step: u64, step: u64) -> u64 { + let diff = step - previous_step; + if diff > MEMORY_MAX_DIFF { + (diff - 1) / MEMORY_MAX_DIFF + } else { + 0 + } + } +} + +impl Metrics for MemCounters { + fn measure(&mut self, _: &BusId, data: &[u64]) -> Vec<(BusId, Vec)> { + let op = data[0] as u8; + let addr = data[1] as u32; + let mut addr_w = addr >> MEM_BYTES_BITS; + let step = data[2]; + let bytes = data[3] as u8; + + if (addr & MEM_REGS_MASK) == MEM_REGS_ADDR { + let reg_index = ((addr >> 3) & 0x1F) as usize; + if self.registers[reg_index].count == 0 { + self.registers[reg_index] = + UsesCounter { first_step: step, last_step: step, count: 1 }; + } else { + self.registers[reg_index].count += + 1 + Self::count_extra_internal_reads(self.registers[reg_index].last_step, step); + self.registers[reg_index].last_step = step; + } + } else { + let aligned = addr & 0x7 == 0 && bytes == 8; + if aligned { + self.addr + .entry(addr_w) + .and_modify(|value| { + value.count += 1; + value.last_step = step; + }) + .or_insert(UsesCounter { first_step: step, last_step: step, count: 1 }); + } else { + // TODO: use mem_align helpers + + let addr_count = + if ((addr + bytes as u32) >> MEM_BYTES_BITS) != addr_w { 2 } else { 1 }; + let ops_by_addr = if op == MEMORY_STORE_OP { 2 } else { 1 }; + + let last_step = step + ops_by_addr - 1; + for index in 0..addr_count { + self.addr + .entry(addr_w + index) + .and_modify(|value| { + value.count += ops_by_addr + + Self::count_extra_internal_reads(value.last_step, step); + value.last_step = last_step; + }) + .or_insert(UsesCounter { first_step: step, last_step, count: ops_by_addr }); + addr_w += 1; + } + let mem_align_op_rows = 1 + addr_count * ops_by_addr as u32; + self.mem_align.push(mem_align_op_rows as u8); + self.mem_align_rows += mem_align_op_rows; + } + } + + vec![] + } + + fn add(&mut self, _other: &dyn Metrics) {} + + fn op_type(&self) -> Vec { + vec![zisk_core::ZiskOperationType::Arith] + } + + fn bus_id(&self) -> Vec { + vec![MEM_BUS_ID] + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +impl BusDevice for MemCounters { + #[inline] + fn process_data(&mut self, bus_id: &BusId, data: &[u64]) -> Vec<(BusId, Vec)> { + self.measure(bus_id, data); + + vec![] + } +} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 00f46a32..e8c3ef3f 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,8 +1,11 @@ use std::sync::Arc; -use crate::{InputDataSM, MemAlignRomSM, MemAlignSM, MemProxyEngine, MemSM, RomDataSM}; +use crate::{ + InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemProxyEngine, MemSM, RomDataSM, +}; use p3_field::PrimeField; use pil_std_lib::Std; +use sm_common::{BusDeviceMetrics, ComponentProvider, Instance, InstanceExpanderCtx, Planner}; use zisk_core::ZiskRequiredMemory; pub struct MemProxy { @@ -42,3 +45,18 @@ impl MemProxy { engine.prove(mem_operations) } } + +impl ComponentProvider for MemProxy { + fn get_counter(&self) -> Box { + Box::new(MemCounters::new()) + // Box::new(MemCounters::new(OPERATION_BUS_ID, vec![zisk_core::ZiskOperationType::Arith])) + } + + fn get_planner(&self) -> Box { + unimplemented!("get_planner for MemProxy"); + } + + fn get_instance(&self, _iectx: InstanceExpanderCtx) -> Box> { + unimplemented!("get_instance for MemProxy"); + } +} From 47d9daffceb44cece96385ec0ca26b9864e84c88 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 19 Dec 2024 22:00:46 +0000 Subject: [PATCH 04/10] update to new interfaces --- emulator/src/emu.rs | 2 +- state-machines/mem/src/mem_counters.rs | 9 ++------- state-machines/mem/src/mem_proxy.rs | 9 +++++++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index e24e0456..1b540a2b 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -1810,7 +1810,7 @@ impl<'a> Emu<'a> { self.source_a_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); self.source_b_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); (instruction.func)(&mut self.ctx.inst_ctx); - self.store_c_slice(instruction); + self.store_c_mem_reads_consume(instruction, &trace_step.mem_reads, mem_reads_index); let operation_payload = OperationBusData::from_instruction(instruction, &self.ctx.inst_ctx); if data_bus.write_to_bus(OPERATION_BUS_ID, operation_payload.to_vec()) { diff --git a/state-machines/mem/src/mem_counters.rs b/state-machines/mem/src/mem_counters.rs index 81954a37..ea64af56 100644 --- a/state-machines/mem/src/mem_counters.rs +++ b/state-machines/mem/src/mem_counters.rs @@ -2,7 +2,6 @@ use std::collections::HashMap; use sm_common::Metrics; use zisk_common::{BusDevice, BusId}; -use zisk_core::ZiskOperationType; use crate::{ MEMORY_MAX_DIFF, MEMORY_STORE_OP, MEM_BUS_ID, MEM_BYTES_BITS, MEM_REGS_ADDR, MEM_REGS_MASK, @@ -100,10 +99,6 @@ impl Metrics for MemCounters { fn add(&mut self, _other: &dyn Metrics) {} - fn op_type(&self) -> Vec { - vec![zisk_core::ZiskOperationType::Arith] - } - fn bus_id(&self) -> Vec { vec![MEM_BUS_ID] } @@ -115,9 +110,9 @@ impl Metrics for MemCounters { impl BusDevice for MemCounters { #[inline] - fn process_data(&mut self, bus_id: &BusId, data: &[u64]) -> Vec<(BusId, Vec)> { + fn process_data(&mut self, bus_id: &BusId, data: &[u64]) -> (bool, Vec<(BusId, Vec)>) { self.measure(bus_id, data); - vec![] + (true, vec![]) } } diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index e8c3ef3f..f6ee23fd 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -5,7 +5,9 @@ use crate::{ }; use p3_field::PrimeField; use pil_std_lib::Std; -use sm_common::{BusDeviceMetrics, ComponentProvider, Instance, InstanceExpanderCtx, Planner}; +use sm_common::{ + BusDeviceInstance, BusDeviceMetrics, ComponentProvider, InstanceExpanderCtx, Planner, +}; use zisk_core::ZiskRequiredMemory; pub struct MemProxy { @@ -56,7 +58,10 @@ impl ComponentProvider for MemProxy { unimplemented!("get_planner for MemProxy"); } - fn get_instance(&self, _iectx: InstanceExpanderCtx) -> Box> { + fn get_instance(&self, iectx: InstanceExpanderCtx) -> Box> { + unimplemented!("get_instance for MemProxy"); + } + fn get_inputs_generator(&self) -> Option>> { unimplemented!("get_instance for MemProxy"); } } From ad0f4374d533628990973a0a0a4fb4262622eda1 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Tue, 7 Jan 2025 10:02:38 +0000 Subject: [PATCH 05/10] implemented counters and planners of memories --- cli/src/commands/install_toolchain.rs | 8 +- core/src/elf2rom.rs | 6 +- core/src/zisk_inst.rs | 26 +- emulator/src/emu.rs | 34 +-- emulator/src/emu_options.rs | 12 +- emulator/src/stats.rs | 28 +- state-machines/arith/src/arith_full.rs | 16 +- state-machines/arith/src/arith_operation.rs | 110 ++++---- .../arith/src/arith_operation_test.rs | 42 +-- .../arith/src/arith_range_table_helpers.rs | 12 +- .../arith/src/arith_table_helpers.rs | 40 +-- state-machines/binary/src/binary_basic.rs | 30 +-- .../binary/src/binary_basic_table.rs | 110 ++++---- state-machines/binary/src/binary_extension.rs | 24 +- .../src/instance_observer/inputs_collector.rs | 4 +- state-machines/mem/src/input_data_sm.rs | 16 +- state-machines/mem/src/lib.rs | 6 + state-machines/mem/src/mem_align_planner.rs | 125 +++++++++ state-machines/mem/src/mem_align_rom_sm.rs | 6 +- state-machines/mem/src/mem_bus_helpers.rs | 8 +- state-machines/mem/src/mem_constants.rs | 6 +- state-machines/mem/src/mem_counters.rs | 57 +++- state-machines/mem/src/mem_helpers.rs | 6 +- state-machines/mem/src/mem_module_planner.rs | 249 ++++++++++++++++++ state-machines/mem/src/mem_planner.rs | 69 +++++ state-machines/mem/src/mem_proxy.rs | 7 +- state-machines/mem/src/mem_sm.rs | 10 +- state-machines/mem/src/rom_data.rs | 16 +- 28 files changed, 794 insertions(+), 289 deletions(-) create mode 100644 state-machines/mem/src/mem_align_planner.rs create mode 100644 state-machines/mem/src/mem_module_planner.rs create mode 100644 state-machines/mem/src/mem_planner.rs diff --git a/cli/src/commands/install_toolchain.rs b/cli/src/commands/install_toolchain.rs index 2b8d6e7e..9ad8799a 100644 --- a/cli/src/commands/install_toolchain.rs +++ b/cli/src/commands/install_toolchain.rs @@ -36,10 +36,10 @@ impl InstallToolchainCmd { if let Ok(entry) = entry { let entry_path = entry.path(); let entry_name = entry_path.file_name().unwrap(); - if entry_path.is_dir() && - entry_name != "bin" && - entry_name != "circuits" && - entry_name != "toolchains" + if entry_path.is_dir() + && entry_name != "bin" + && entry_name != "circuits" + && entry_name != "toolchains" { if let Err(err) = fs::remove_dir_all(&entry_path) { println!("Failed to remove directory {:?}: {}", entry_path, err); diff --git a/core/src/elf2rom.rs b/core/src/elf2rom.rs index 0209aa27..217b7753 100644 --- a/core/src/elf2rom.rs +++ b/core/src/elf2rom.rs @@ -53,9 +53,9 @@ pub fn elf2rom(elf_file: String) -> Result> { // Add init data as a read/write memory section, initialized by code // If the data is a writable memory section, add it to the ROM memory using Zisk // copy instructions - if (section_header.sh_flags & SHF_WRITE as u64) != 0 && - addr >= RAM_ADDR && - addr + data.len() as u64 <= RAM_ADDR + RAM_SIZE + if (section_header.sh_flags & SHF_WRITE as u64) != 0 + && addr >= RAM_ADDR + && addr + data.len() as u64 <= RAM_ADDR + RAM_SIZE { //println! {"elf2rom() new RW from={:x} length={:x}={}", addr, data.len(), //data.len()}; diff --git a/core/src/zisk_inst.rs b/core/src/zisk_inst.rs index be6e575c..989d6d50 100644 --- a/core/src/zisk_inst.rs +++ b/core/src/zisk_inst.rs @@ -228,19 +228,19 @@ impl ZiskInst { /// Constructs a `flags`` bitmap made of combinations of fields of the Zisk instruction. This /// field is used by the PIL to proof some of the operations. pub fn get_flags(&self) -> u64 { - let flags: u64 = 1 | - (((self.a_src == SRC_IMM) as u64) << 1) | - (((self.a_src == SRC_MEM) as u64) << 2) | - (((self.a_src == SRC_STEP) as u64) << 3) | - (((self.b_src == SRC_IMM) as u64) << 4) | - (((self.b_src == SRC_MEM) as u64) << 5) | - ((self.is_external_op as u64) << 6) | - ((self.store_ra as u64) << 7) | - (((self.store == STORE_MEM) as u64) << 8) | - (((self.store == STORE_IND) as u64) << 9) | - ((self.set_pc as u64) << 10) | - ((self.m32 as u64) << 11) | - (((self.b_src == SRC_IND) as u64) << 12); + let flags: u64 = 1 + | (((self.a_src == SRC_IMM) as u64) << 1) + | (((self.a_src == SRC_MEM) as u64) << 2) + | (((self.a_src == SRC_STEP) as u64) << 3) + | (((self.b_src == SRC_IMM) as u64) << 4) + | (((self.b_src == SRC_MEM) as u64) << 5) + | ((self.is_external_op as u64) << 6) + | ((self.store_ra as u64) << 7) + | (((self.store == STORE_MEM) as u64) << 8) + | (((self.store == STORE_IND) as u64) << 9) + | ((self.set_pc as u64) << 10) + | ((self.m32 as u64) << 11) + | (((self.b_src == SRC_IND) as u64) << 12); flags } diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 1b540a2b..2e84d040 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -39,9 +39,9 @@ impl MemBusHelpers { [ MEMORY_LOAD_OP, addr as u64, - MEM_STEP_BASE + - MAX_MEM_OPS_BY_MAIN_STEP * step + - MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, bytes as u64, mem_values[0], mem_values[1], @@ -59,9 +59,9 @@ impl MemBusHelpers { [ MEMORY_STORE_OP, addr as u64, - MEM_STEP_BASE + - MAX_MEM_OPS_BY_MAIN_STEP * step + - MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, bytes as u64, mem_values[0], mem_values[1], @@ -1313,9 +1313,9 @@ impl<'a> Emu<'a> { } // Log emulation step, if requested - if options.print_step.is_some() && - (options.print_step.unwrap() != 0) && - ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) + if options.print_step.is_some() + && (options.print_step.unwrap() != 0) + && ((self.ctx.inst_ctx.step % options.print_step.unwrap()) == 0) { println!("step={}", self.ctx.inst_ctx.step); } @@ -1490,9 +1490,9 @@ impl<'a> Emu<'a> { // Increment step counter self.ctx.inst_ctx.step += 1; - if self.ctx.inst_ctx.end || - ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) == - self.ctx.callback_steps) + if self.ctx.inst_ctx.end + || ((self.ctx.inst_ctx.step - self.ctx.last_callback_step) + == self.ctx.callback_steps) { // In run() we have checked the callback consistency with ctx.do_callback let callback = callback.as_ref().unwrap(); @@ -1870,8 +1870,8 @@ impl<'a> Emu<'a> { let b = [inst_ctx.b & 0xFFFFFFFF, (inst_ctx.b >> 32) & 0xFFFFFFFF]; let c = [inst_ctx.c & 0xFFFFFFFF, (inst_ctx.c >> 32) & 0xFFFFFFFF]; - let addr1 = (inst.b_offset_imm0 as i64 + - if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; + let addr1 = (inst.b_offset_imm0 as i64 + + if inst.b_src == SRC_IND { inst_ctx.a as i64 } else { 0 }) as u64; let jmp_offset1 = if inst.jmp_offset1 >= 0 { F::from_canonical_u64(inst.jmp_offset1 as u64) @@ -1949,9 +1949,9 @@ impl<'a> Emu<'a> { m32: F::from_bool(inst.m32), addr1: F::from_canonical_u64(addr1), __debug_operation_bus_enabled: F::from_bool( - inst.op_type == ZiskOperationType::Arith || - inst.op_type == ZiskOperationType::Binary || - inst.op_type == ZiskOperationType::BinaryE, + inst.op_type == ZiskOperationType::Arith + || inst.op_type == ZiskOperationType::Binary + || inst.op_type == ZiskOperationType::BinaryE, ), } } diff --git a/emulator/src/emu_options.rs b/emulator/src/emu_options.rs index acd7bae1..0703ab80 100644 --- a/emulator/src/emu_options.rs +++ b/emulator/src/emu_options.rs @@ -98,11 +98,11 @@ impl fmt::Display for EmuOptions { impl EmuOptions { /// Returns true if the configuration allows to emulate in fast mode, maximizing the performance pub fn is_fast(&self) -> bool { - self.trace_steps.is_none() && - (self.print_step.is_none() || (self.print_step.unwrap() == 0)) && - self.trace.is_none() && - !self.log_step && - !self.verbose && - !self.tracerv + self.trace_steps.is_none() + && (self.print_step.is_none() || (self.print_step.unwrap() == 0)) + && self.trace.is_none() + && !self.log_step + && !self.verbose + && !self.tracerv } } diff --git a/emulator/src/stats.rs b/emulator/src/stats.rs index c6a16fee..80b0ffec 100644 --- a/emulator/src/stats.rs +++ b/emulator/src/stats.rs @@ -185,22 +185,22 @@ impl Stats { output += &format!(" COST_STEP: {:02} sec\n", COST_STEP); // Calculate some aggregated counters to be used in the logs - let total_mem_ops = self.mops.mread_na1 + - self.mops.mread_na2 + - self.mops.mread_a + - self.mops.mwrite_na1 + - self.mops.mwrite_na2 + - self.mops.mwrite_a; - let total_mem_align_steps = self.mops.mread_na1 + - self.mops.mread_na2 * 2 + - self.mops.mwrite_na1 * 2 + - self.mops.mwrite_na2 * 4; + let total_mem_ops = self.mops.mread_na1 + + self.mops.mread_na2 + + self.mops.mread_a + + self.mops.mwrite_na1 + + self.mops.mwrite_na2 + + self.mops.mwrite_a; + let total_mem_align_steps = self.mops.mread_na1 + + self.mops.mread_na2 * 2 + + self.mops.mwrite_na1 * 2 + + self.mops.mwrite_na2 * 4; let cost_mem = total_mem_ops as f64 * COST_MEM; - let cost_mem_align = self.mops.mread_na1 as f64 * COST_MEMA_R1 + - self.mops.mread_na2 as f64 * COST_MEMA_R2 + - self.mops.mwrite_na1 as f64 * COST_MEMA_W1 + - self.mops.mwrite_na2 as f64 * COST_MEMA_W2; + let cost_mem_align = self.mops.mread_na1 as f64 * COST_MEMA_R1 + + self.mops.mread_na2 as f64 * COST_MEMA_R2 + + self.mops.mwrite_na1 as f64 * COST_MEMA_W1 + + self.mops.mwrite_na2 as f64 * COST_MEMA_W2; // Declare some total counters for the opcodes let mut total_opcodes: u64 = 0; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index c6b18ce7..802eb5b6 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -232,14 +232,14 @@ impl ArithFullSM { step, opcode, ZiskOperationType::Binary as u64, - aop.d[0] + - CHUNK_SIZE * aop.d[1] + - CHUNK_SIZE.pow(2) * (aop.d[2] + extension.0) + - CHUNK_SIZE.pow(3) * aop.d[3], - aop.b[0] + - CHUNK_SIZE * aop.b[1] + - CHUNK_SIZE.pow(2) * (aop.b[2] + extension.1) + - CHUNK_SIZE.pow(3) * aop.b[3], + aop.d[0] + + CHUNK_SIZE * aop.d[1] + + CHUNK_SIZE.pow(2) * (aop.d[2] + extension.0) + + CHUNK_SIZE.pow(3) * aop.d[3], + aop.b[0] + + CHUNK_SIZE * aop.b[1] + + CHUNK_SIZE.pow(2) * (aop.b[2] + extension.1) + + CHUNK_SIZE.pow(3) * aop.b[3], ) .to_vec()] } else { diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs index c6b8b48a..8ce0274c 100644 --- a/state-machines/arith/src/arith_operation.rs +++ b/state-machines/arith/src/arith_operation.rs @@ -133,22 +133,22 @@ impl ArithOperation { self.op = op; self.input_a = input_a; self.input_b = input_b; - self.div_by_zero = input_b == 0 && - (op == ZiskOp::Div.code() || - op == ZiskOp::Rem.code() || - op == ZiskOp::DivW.code() || - op == ZiskOp::RemW.code() || - op == ZiskOp::Divu.code() || - op == ZiskOp::Remu.code() || - op == ZiskOp::DivuW.code() || - op == ZiskOp::RemuW.code()); - - self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) && - input_a == 0x8000_0000_0000_0000 && - input_b == 0xFFFF_FFFF_FFFF_FFFF) || - ((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) && - input_a == 0x8000_0000 && - input_b == 0xFFFF_FFFF); + self.div_by_zero = input_b == 0 + && (op == ZiskOp::Div.code() + || op == ZiskOp::Rem.code() + || op == ZiskOp::DivW.code() + || op == ZiskOp::RemW.code() + || op == ZiskOp::Divu.code() + || op == ZiskOp::Remu.code() + || op == ZiskOp::DivuW.code() + || op == ZiskOp::RemuW.code()); + + self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) + && input_a == 0x8000_0000_0000_0000 + && input_b == 0xFFFF_FFFF_FFFF_FFFF) + || ((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) + && input_a == 0x8000_0000 + && input_b == 0xFFFF_FFFF); let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b); self.a = Self::u64_to_chunks(a); @@ -578,15 +578,15 @@ impl ArithOperation { assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); - self.range_ab = (range_a3 + range_a1) * 3 + - range_b3 + - range_b1 + - if (range_a1 + range_b1) > 0 { 8 } else { 0 }; + self.range_ab = (range_a3 + range_a1) * 3 + + range_b3 + + range_b1 + + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; - self.range_cd = (range_c3 + range_c1) * 3 + - range_d3 + - range_d1 + - if (range_c1 + range_d1) > 0 { 8 } else { 0 }; + self.range_cd = (range_c3 + range_c1) * 3 + + range_d3 + + range_d1 + + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; } pub fn calculate_chunks(&self) -> [i64; 8] { @@ -614,40 +614,40 @@ impl ArithOperation { let nb_fa = nb * (1 - 2 * na); chunks[0] = fab * a[0] * b[0] // chunk0 - - c[0] + - 2 * np * c[0] + - div * d[0] - - 2 * nr * d[0]; + - c[0] + + 2 * np * c[0] + + div * d[0] + - 2 * nr * d[0]; chunks[1] = fab * a[1] * b[0] // chunk1 - + fab * a[0] * b[1] - - c[1] + - 2 * np * c[1] + - div * d[1] - - 2 * nr * d[1]; + + fab * a[0] * b[1] + - c[1] + + 2 * np * c[1] + + div * d[1] + - 2 * nr * d[1]; chunks[2] = fab * a[2] * b[0] // chunk2 + fab * a[1] * b[1] + fab * a[0] * b[2] + a[0] * nb_fa * m32 - + b[0] * na_fb * m32 - - c[2] + - 2 * np * c[2] + - div * d[2] - - 2 * nr * d[2] - - np * div * m32 + - nr * m32; // div == 0 ==> nr = 0 + + b[0] * na_fb * m32 + - c[2] + + 2 * np * c[2] + + div * d[2] + - 2 * nr * d[2] + - np * div * m32 + + nr * m32; // div == 0 ==> nr = 0 chunks[3] = fab * a[3] * b[0] // chunk3 + fab * a[2] * b[1] + fab * a[1] * b[2] + fab * a[0] * b[3] + a[1] * nb_fa * m32 - + b[1] * na_fb * m32 - - c[3] + - 2 * np * c[3] + - div * d[3] - - 2 * nr * d[3]; + + b[1] * na_fb * m32 + - c[3] + + 2 * np * c[3] + + div * d[3] + - 2 * nr * d[3]; chunks[4] = fab * a[3] * b[1] // chunk4 + fab * a[2] * b[2] @@ -671,23 +671,23 @@ impl ArithOperation { chunks[5] = fab * a[3] * b[2] // chunk5 + fab * a[2] * b[3] + a[1] * nb_fa * (1 - m32) - + b[1] * na_fb * (1 - m32) - - d[1] * (1 - div) + - d[1] * 2 * np * (1 - div); + + b[1] * na_fb * (1 - m32) + - d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); chunks[6] = fab * a[3] * b[3] // chunk6 + a[2] * nb_fa * (1 - m32) - + b[2] * na_fb * (1 - m32) - - d[2] * (1 - div) + - d[2] * 2 * np * (1 - div); + + b[2] * na_fb * (1 - m32) + - d[2] * (1 - div) + + d[2] * 2 * np * (1 - div); // 0x4000_0000_0000_0000__8000_0000_0000_0000 chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 + a[3] * nb_fa * (1 - m32) - + b[3] * na_fb * (1 - m32) - - 0x10000 * np * (1 - div) * (1 - m32) - - d[3] * (1 - div) + - d[3] * 2 * np * (1 - div); + + b[3] * na_fb * (1 - m32) + - 0x10000 * np * (1 - div) * (1 - m32) + - d[3] * (1 - div) + + d[3] * 2 * np * (1 - div); chunks } diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs index 37e0a4c7..67d37120 100644 --- a/state-machines/arith/src/arith_operation_test.rs +++ b/state-machines/arith/src/arith_operation_test.rs @@ -101,15 +101,15 @@ impl ArithOperationTest { fn is_m32_op(op: u8) -> bool { let zisk_op = ZiskOp::try_from_code(op).unwrap(); match zisk_op { - ZiskOp::Mul | - ZiskOp::Mulh | - ZiskOp::Mulsuh | - ZiskOp::Mulu | - ZiskOp::Muluh | - ZiskOp::Divu | - ZiskOp::Remu | - ZiskOp::Div | - ZiskOp::Rem => false, + ZiskOp::Mul + | ZiskOp::Mulh + | ZiskOp::Mulsuh + | ZiskOp::Mulu + | ZiskOp::Muluh + | ZiskOp::Divu + | ZiskOp::Remu + | ZiskOp::Div + | ZiskOp::Rem => false, ZiskOp::MulW | ZiskOp::DivuW | ZiskOp::RemuW | ZiskOp::DivW | ZiskOp::RemW => true, _ => panic!("ArithOperationTest::is_m32_op() Invalid opcode={}", op), } @@ -162,26 +162,26 @@ impl ArithOperationTest { println!("{:#?}", aop); const CHUNK_SIZE: u64 = 0x10000; - let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + - (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); - let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + - (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); let bus_b_low: u64 = aop.b[0] + CHUNK_SIZE * aop.b[1]; let bus_b_high: u64 = aop.b[2] + CHUNK_SIZE * aop.b[3]; let secondary_res: u64 = if aop.main_mul || aop.main_div { 0 } else { 1 }; - let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) + - aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + - aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); - let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) + - aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + - aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); - let bus_res_high = if aop.sext && !aop.div_overflow { 0xFFFF_FFFF } else { 0 } + - (1 - aop.m32 as u64) * bus_res_high_64; + let bus_res_high = if aop.sext && !aop.div_overflow { 0xFFFF_FFFF } else { 0 } + + (1 - aop.m32 as u64) * bus_res_high_64; let expected_a_low = a & 0xFFFF_FFFF; let expected_a_high = (a >> 32) & 0xFFFF_FFFF; diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs index 6fdd3cab..0685b09e 100644 --- a/state-machines/arith/src/arith_range_table_helpers.rs +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -45,16 +45,16 @@ impl ArithRangeTableHelpers { assert!(range_index < 43); assert!(value >= if range_type == NEG { 0x8000 } else { 0 }); assert!( - value <= - match range_type { + value + <= match range_type { FULL => 0xFFFF, POS => 0x7FFF, NEG => 0xFFFF, _ => panic!("Invalid range type"), } ); - OFFSETS[range_index as usize] * 0x8000 + - if range_type == NEG { value - 0x8000 } else { value } as usize + OFFSETS[range_index as usize] * 0x8000 + + if range_type == NEG { value - 0x8000 } else { value } as usize } pub fn get_row_carry_range_check(value: i64) -> usize { assert!(value >= -0xEFFFF); @@ -158,8 +158,8 @@ impl Iterator for ArithRangeTableInputsIterator<'_> { fn next(&mut self) -> Option { if !self.iter_hash { - while self.iter_row < ROWS as u32 && - self.inputs.multiplicity[self.iter_row as usize] == 0 + while self.iter_row < ROWS as u32 + && self.inputs.multiplicity[self.iter_row as usize] == 0 { self.iter_row += 1; } diff --git a/state-machines/arith/src/arith_table_helpers.rs b/state-machines/arith/src/arith_table_helpers.rs index 67f2a730..2c545a98 100644 --- a/state-machines/arith/src/arith_table_helpers.rs +++ b/state-machines/arith/src/arith_table_helpers.rs @@ -17,14 +17,14 @@ impl ArithTableHelpers { div_by_zero: bool, div_overflow: bool, ) -> usize { - let index = (op - FIRST_OP) as u64 * 128 + - na as u64 + - nb as u64 * 2 + - np as u64 * 4 + - nr as u64 * 8 + - sext as u64 * 16 + - div_by_zero as u64 * 32 + - div_overflow as u64 * 64; + let index = (op - FIRST_OP) as u64 * 128 + + na as u64 + + nb as u64 * 2 + + np as u64 * 4 + + nr as u64 * 8 + + sext as u64 * 16 + + div_by_zero as u64 * 32 + + div_overflow as u64 * 64; debug_assert!(index < ARITH_TABLE_ROWS.len() as u64); let row = ARITH_TABLE_ROWS[index as usize]; debug_assert!( @@ -75,18 +75,18 @@ impl ArithTableHelpers { range_ab: u16, range_cd: u16, ) -> usize { - let flags = if m32 { 1 } else { 0 } + - if div { 2 } else { 0 } + - if na { 4 } else { 0 } + - if nb { 8 } else { 0 } + - if np { 16 } else { 0 } + - if nr { 32 } else { 0 } + - if sext { 64 } else { 0 } + - if div_by_zero { 128 } else { 0 } + - if div_overflow { 256 } else { 0 } + - if main_mul { 512 } else { 0 } + - if main_div { 1024 } else { 0 } + - if signed { 2048 } else { 0 }; + let flags = if m32 { 1 } else { 0 } + + if div { 2 } else { 0 } + + if na { 4 } else { 0 } + + if nb { 8 } else { 0 } + + if np { 16 } else { 0 } + + if nr { 32 } else { 0 } + + if sext { 64 } else { 0 } + + if div_by_zero { 128 } else { 0 } + + if div_overflow { 256 } else { 0 } + + if main_mul { 512 } else { 0 } + + if main_div { 1024 } else { 0 } + + if signed { 2048 } else { 0 }; let row = Self::direct_get_row(op, na, nb, np, nr, sext, div_by_zero, div_overflow); assert_eq!( op as u16, ARITH_TABLE[row][0], diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index f0b4ac26..8fcbd7c9 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -251,9 +251,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Min) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Min) + && (plast[i] == 1) + && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if (a_bytes[i] & 0x80) != 0 { 1 } else { 0 }; } @@ -318,9 +318,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Max) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Max) + && (plast[i] == 1) + && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if (a_bytes[i] & 0x80) != 0 { 0 } else { 1 }; } @@ -495,9 +495,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op.eq(&BinaryBasicTableOp::Lt)) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op.eq(&BinaryBasicTableOp::Lt)) + && (plast[i] == 1) + && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if a_bytes[i] & 0x80 != 0 { 1 } else { 0 }; } @@ -737,9 +737,9 @@ impl BinaryBasicSM { if a_bytes[i] <= b_bytes[i] { cout = 1; } - if (binary_basic_table_op == BinaryBasicTableOp::Le) && - (plast[i] == 1) && - (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Le) + && (plast[i] == 1) + && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = c; } @@ -883,9 +883,9 @@ impl BinaryBasicSM { // Set free_in_a_or_c and free_in_b_or_zero for i in 0..HALF_BYTES { - row.free_in_a_or_c[i] = mode64 * - (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + - row.free_in_c[HALF_BYTES - 1]; + row.free_in_a_or_c[i] = mode64 + * (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + + row.free_in_c[HALF_BYTES - 1]; row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; } diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index 82f4c60f..e4a85456 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -97,72 +97,72 @@ impl BinaryBasicTableSM { fn opcode_has_last(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu | - BinaryBasicTableOp::Min | - BinaryBasicTableOp::Maxu | - BinaryBasicTableOp::Max | - BinaryBasicTableOp::LtAbsNP | - BinaryBasicTableOp::LtAbsPN | - BinaryBasicTableOp::Ltu | - BinaryBasicTableOp::Lt | - BinaryBasicTableOp::Gt | - BinaryBasicTableOp::Eq | - BinaryBasicTableOp::Add | - BinaryBasicTableOp::Sub | - BinaryBasicTableOp::Leu | - BinaryBasicTableOp::Le | - BinaryBasicTableOp::And | - BinaryBasicTableOp::Or | - BinaryBasicTableOp::Xor => true, + BinaryBasicTableOp::Minu + | BinaryBasicTableOp::Min + | BinaryBasicTableOp::Maxu + | BinaryBasicTableOp::Max + | BinaryBasicTableOp::LtAbsNP + | BinaryBasicTableOp::LtAbsPN + | BinaryBasicTableOp::Ltu + | BinaryBasicTableOp::Lt + | BinaryBasicTableOp::Gt + | BinaryBasicTableOp::Eq + | BinaryBasicTableOp::Add + | BinaryBasicTableOp::Sub + | BinaryBasicTableOp::Leu + | BinaryBasicTableOp::Le + | BinaryBasicTableOp::And + | BinaryBasicTableOp::Or + | BinaryBasicTableOp::Xor => true, BinaryBasicTableOp::Ext32 => false, } } fn opcode_has_cin(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu | - BinaryBasicTableOp::Min | - BinaryBasicTableOp::Maxu | - BinaryBasicTableOp::Max | - BinaryBasicTableOp::LtAbsNP | - BinaryBasicTableOp::LtAbsPN | - BinaryBasicTableOp::Ltu | - BinaryBasicTableOp::Lt | - BinaryBasicTableOp::Gt | - BinaryBasicTableOp::Eq | - BinaryBasicTableOp::Add | - BinaryBasicTableOp::Sub => true, - - BinaryBasicTableOp::Leu | - BinaryBasicTableOp::Le | - BinaryBasicTableOp::And | - BinaryBasicTableOp::Or | - BinaryBasicTableOp::Xor | - BinaryBasicTableOp::Ext32 => false, + BinaryBasicTableOp::Minu + | BinaryBasicTableOp::Min + | BinaryBasicTableOp::Maxu + | BinaryBasicTableOp::Max + | BinaryBasicTableOp::LtAbsNP + | BinaryBasicTableOp::LtAbsPN + | BinaryBasicTableOp::Ltu + | BinaryBasicTableOp::Lt + | BinaryBasicTableOp::Gt + | BinaryBasicTableOp::Eq + | BinaryBasicTableOp::Add + | BinaryBasicTableOp::Sub => true, + + BinaryBasicTableOp::Leu + | BinaryBasicTableOp::Le + | BinaryBasicTableOp::And + | BinaryBasicTableOp::Or + | BinaryBasicTableOp::Xor + | BinaryBasicTableOp::Ext32 => false, } } fn opcode_result_is_a(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu | - BinaryBasicTableOp::Min | - BinaryBasicTableOp::Maxu | - BinaryBasicTableOp::Max => true, - - BinaryBasicTableOp::LtAbsNP | - BinaryBasicTableOp::LtAbsPN | - BinaryBasicTableOp::Ltu | - BinaryBasicTableOp::Lt | - BinaryBasicTableOp::Gt | - BinaryBasicTableOp::Eq | - BinaryBasicTableOp::Add | - BinaryBasicTableOp::Sub | - BinaryBasicTableOp::Leu | - BinaryBasicTableOp::Le | - BinaryBasicTableOp::And | - BinaryBasicTableOp::Or | - BinaryBasicTableOp::Xor | - BinaryBasicTableOp::Ext32 => false, + BinaryBasicTableOp::Minu + | BinaryBasicTableOp::Min + | BinaryBasicTableOp::Maxu + | BinaryBasicTableOp::Max => true, + + BinaryBasicTableOp::LtAbsNP + | BinaryBasicTableOp::LtAbsPN + | BinaryBasicTableOp::Ltu + | BinaryBasicTableOp::Lt + | BinaryBasicTableOp::Gt + | BinaryBasicTableOp::Eq + | BinaryBasicTableOp::Add + | BinaryBasicTableOp::Sub + | BinaryBasicTableOp::Leu + | BinaryBasicTableOp::Le + | BinaryBasicTableOp::And + | BinaryBasicTableOp::Or + | BinaryBasicTableOp::Xor + | BinaryBasicTableOp::Ext32 => false, } } diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index d1a8f173..35d46dd8 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -60,12 +60,12 @@ impl BinaryExtensionSM { fn opcode_is_shift(opcode: ZiskOp) -> bool { match opcode { - ZiskOp::Sll | - ZiskOp::Srl | - ZiskOp::Sra | - ZiskOp::SllW | - ZiskOp::SrlW | - ZiskOp::SraW => true, + ZiskOp::Sll + | ZiskOp::Srl + | ZiskOp::Sra + | ZiskOp::SllW + | ZiskOp::SrlW + | ZiskOp::SraW => true, ZiskOp::SignExtendB | ZiskOp::SignExtendH | ZiskOp::SignExtendW => false, @@ -77,12 +77,12 @@ impl BinaryExtensionSM { match opcode { ZiskOp::SllW | ZiskOp::SrlW | ZiskOp::SraW => true, - ZiskOp::Sll | - ZiskOp::Srl | - ZiskOp::Sra | - ZiskOp::SignExtendB | - ZiskOp::SignExtendH | - ZiskOp::SignExtendW => false, + ZiskOp::Sll + | ZiskOp::Srl + | ZiskOp::Sra + | ZiskOp::SignExtendB + | ZiskOp::SignExtendH + | ZiskOp::SignExtendW => false, _ => panic!("BinaryExtensionSM::opcode_is_shift() got invalid opcode={:?}", opcode), } diff --git a/state-machines/common/src/instance_observer/inputs_collector.rs b/state-machines/common/src/instance_observer/inputs_collector.rs index a73f90bf..b8200290 100644 --- a/state-machines/common/src/instance_observer/inputs_collector.rs +++ b/state-machines/common/src/instance_observer/inputs_collector.rs @@ -77,8 +77,8 @@ impl InstObserver for InputsCollector { } if self.skipping { - if self.check_point.collect_info.skip == 0 || - self.skipped == self.check_point.collect_info.skip + if self.check_point.collect_info.skip == 0 + || self.skipped == self.check_point.collect_info.skip { self.skipping = false; } else { diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs index acb3345b..40348469 100644 --- a/state-machines/mem/src/input_data_sm.rs +++ b/state-machines/mem/src/input_data_sm.rs @@ -10,8 +10,8 @@ use proofman_common::{AirInstance, FromTrace}; use zisk_core::{INPUT_ADDR, MAX_INPUT_SIZE}; use zisk_pil::{InputDataAirValues, InputDataTrace, ZiskProofValues}; -const INPUT_W_ADDR_INIT: u32 = INPUT_ADDR as u32 >> MEM_BYTES_BITS; -const INPUT_W_ADDR_END: u32 = (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 >> MEM_BYTES_BITS; +pub const INPUT_DATA_W_ADDR_INIT: u32 = INPUT_ADDR as u32 >> MEM_BYTES_BITS; +pub const INPUT_DATA_W_ADDR_END: u32 = (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 >> MEM_BYTES_BITS; #[allow(clippy::assertions_on_constants)] const _: () = { @@ -70,7 +70,7 @@ impl InputDataSM { let is_last_segment = segment_id == num_segments - 1; let input_offset = segment_id * air_rows; let previous_segment = if (segment_id == 0) { - MemPreviousSegment { addr: INPUT_W_ADDR_INIT, step: 0, value: 0 } + MemPreviousSegment { addr: INPUT_DATA_W_ADDR_INIT, step: 0, value: 0 } } else { MemPreviousSegment { addr: inputs[input_offset - 1].addr, @@ -149,7 +149,7 @@ impl InputDataSM { // range of instance let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); self.std.range_check( - F::from_canonical_u32(previous_segment.addr - INPUT_W_ADDR_INIT + 1), + F::from_canonical_u32(previous_segment.addr - INPUT_DATA_W_ADDR_INIT + 1), F::one(), range_id, ); @@ -209,7 +209,7 @@ impl InputDataSM { air_values_mem.segment_last_value[1] = (last_value >> 32) as u32; self.std.range_check( - F::from_canonical_u32(INPUT_W_ADDR_END - last_addr + 1), + F::from_canonical_u32(INPUT_DATA_W_ADDR_END - last_addr + 1), F::one(), range_id, ); @@ -255,6 +255,12 @@ impl InputDataSM { fn get_u16_values(&self, value: u64) -> [u16; 4] { [value as u16, (value >> 16) as u16, (value >> 32) as u16, (value >> 48) as u16] } + pub fn get_from_addr() -> u32 { + INPUT_ADDR as u32 + } + pub fn get_to_addr() -> u32 { + (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 + } } impl MemModule for InputDataSM { diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 05491fbc..369b1e6d 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,4 +1,5 @@ mod input_data_sm; +mod mem_align_planner; mod mem_align_rom_sm; mod mem_align_sm; mod mem_bus_helpers; @@ -6,6 +7,8 @@ mod mem_constants; mod mem_counters; mod mem_helpers; mod mem_module; +mod mem_module_planner; +mod mem_planner; mod mem_proxy; mod mem_proxy_engine; mod mem_sm; @@ -13,6 +16,7 @@ mod mem_unmapped; mod rom_data; pub use input_data_sm::*; +pub use mem_align_planner::*; pub use mem_align_rom_sm::*; pub use mem_align_sm::*; pub use mem_bus_helpers::*; @@ -20,6 +24,8 @@ pub use mem_constants::*; pub use mem_counters::*; pub use mem_helpers::*; pub use mem_module::*; +pub use mem_module_planner::*; +pub use mem_planner::*; pub use mem_proxy::*; pub use mem_proxy_engine::*; pub use mem_sm::*; diff --git a/state-machines/mem/src/mem_align_planner.rs b/state-machines/mem/src/mem_align_planner.rs new file mode 100644 index 00000000..2bdeef94 --- /dev/null +++ b/state-machines/mem/src/mem_align_planner.rs @@ -0,0 +1,125 @@ +use std::sync::Arc; + +use crate::{MemCounters, MemPlanCalculator}; +use sm_common::{CheckPoint, ChunkId, InstanceType, Plan}; +use zisk_pil::{MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; + +pub struct MemAlignPlanner<'a> { + instances: Vec, + num_rows: u32, + current_skip: u32, + current_chunk_id: Option, + current_chunks: Vec, + current_rows_available: u32, + counters: Arc>, +} + +// TODO: dynamic +const MEM_ALIGN_ROWS: usize = 1 << 21; + +impl<'a> MemAlignPlanner<'a> { + pub fn new(counters: Arc>) -> Self { + Self { + instances: Vec::new(), + num_rows: MEM_ALIGN_ROWS as u32, + current_skip: 0, + current_chunk_id: None, + current_chunks: Vec::new(), + current_rows_available: MEM_ALIGN_ROWS as u32, + counters, + } + } + pub fn align_plan(&mut self) -> Vec { + if self.counters.is_empty() { + panic!("MemPlanner::plan() No metrics found"); + } + + let count = self.counters.len(); + for index in 0..count { + let chunk_id = self.counters[index].0; + let counter = self.counters[index].1; + self.set_current_chunk_id(chunk_id); + self.add_to_current_instance(counter.mem_align_rows, &counter.mem_align); + } + self.close_current_instance(); + vec![] + } + fn set_current_chunk_id(&mut self, chunk_id: ChunkId) { + if self.current_chunk_id == Some(chunk_id) && !self.current_chunks.is_empty() { + return; + } + self.current_chunk_id = Some(chunk_id); + if let Err(pos) = self.current_chunks.binary_search(&chunk_id) { + self.current_chunks.insert(pos, chunk_id); + } + } + fn add_to_current_instance(&mut self, total_rows: u32, operations_rows: &[u8]) { + let mut pending_rows = total_rows; + let mut operations_rows_offset: u32 = 0; + loop { + // check if has available rows to add all inside this chunks. + let (count, rows_fit) = if self.current_rows_available >= pending_rows { + // self.current_rows_available -= pending_rows; + (0, pending_rows) + } else { + self.calculate_how_many_operations_fit(operations_rows_offset, operations_rows) + }; + self.current_rows_available -= rows_fit; + pending_rows -= rows_fit; + if self.current_rows_available == 0 { + self.close_current_instance(); + } + operations_rows_offset += count; + self.open_new_instance(operations_rows_offset, pending_rows > 0); + } + } + fn close_current_instance(&mut self) { + // TODO: add instance + if self.current_chunks.is_empty() { + return; + } + // TODO: add multi chunk_id, with skip + let instance = Plan::new( + ZISK_AIRGROUP_ID, + MEM_ALIGN_AIR_IDS[0], + Some(self.instances.len()), + InstanceType::Instance, + Some(CheckPoint::new(self.current_chunks[0], self.current_skip as u64)), + None, + ); + self.instances.push(instance); + self.current_chunks.clear(); + } + fn open_new_instance(&mut self, next_instance_skip: u32, use_current_chunk_id: bool) { + self.current_skip = next_instance_skip; + self.current_rows_available = self.num_rows; + if use_current_chunk_id { + self.current_chunks.push(self.current_chunk_id.unwrap()); + } + } + fn calculate_how_many_operations_fit( + &self, + operations_offset: u32, + operations_rows: &[u8], + ) -> (u32, u32) { + let mut count = 0; + let mut rows = 0; + for row in operations_rows.iter().skip(operations_offset as usize) { + if (rows + *row as u32) > self.current_rows_available { + break; + } + count += 1; + rows += *row as u32; + } + (count, rows) + } +} + +impl MemPlanCalculator for MemAlignPlanner<'_> { + fn plan(&mut self) { + self.align_plan(); + } + fn collect_plans(&mut self) -> Vec { + std::mem::take(&mut self.instances) + } +} diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index cd92743e..b6cd71d0 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -48,9 +48,9 @@ impl MemAlignRomSM { ), MemOp::TwoWrites => ( - 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + - ONE_WORD_COMBINATIONS * OP_SIZES[1] + - TWO_WORD_COMBINATIONS * OP_SIZES[2], + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], false, ), }; diff --git a/state-machines/mem/src/mem_bus_helpers.rs b/state-machines/mem/src/mem_bus_helpers.rs index 3394c495..4cd4b396 100644 --- a/state-machines/mem/src/mem_bus_helpers.rs +++ b/state-machines/mem/src/mem_bus_helpers.rs @@ -23,8 +23,8 @@ impl MemBusHelpers { MAX_MEM_OPS_BY_MAIN_STEP * step + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, bytes as u64, - mem_values[0] as u64, - mem_values[1] as u64, + mem_values[0], + mem_values[1], 0, ] } @@ -43,8 +43,8 @@ impl MemBusHelpers { MAX_MEM_OPS_BY_MAIN_STEP * step + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64, bytes as u64, - mem_values[0] as u64, - mem_values[1] as u64, + mem_values[0], + mem_values[1], value, ] } diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 67f6f391..91e73a34 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -16,9 +16,9 @@ pub const MEM_REGS_ADDR: u32 = 0xA000_0000; pub const MEM_BUS_ID: u16 = 1000; pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; -pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + - MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + - MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; +pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + + MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; diff --git a/state-machines/mem/src/mem_counters.rs b/state-machines/mem/src/mem_counters.rs index ea64af56..828b8aee 100644 --- a/state-machines/mem/src/mem_counters.rs +++ b/state-machines/mem/src/mem_counters.rs @@ -7,18 +7,23 @@ use crate::{ MEMORY_MAX_DIFF, MEMORY_STORE_OP, MEM_BUS_ID, MEM_BYTES_BITS, MEM_REGS_ADDR, MEM_REGS_MASK, }; +use log::info; + #[derive(Debug, Clone, Copy, Default)] pub struct UsesCounter { pub first_step: u64, pub last_step: u64, pub count: u64, + pub last_value: u64, } +#[derive(Default)] pub struct MemCounters { - registers: [UsesCounter; 32], - addr: HashMap, - mem_align: Vec, - mem_align_rows: u32, + pub registers: [UsesCounter; 32], + pub addr: HashMap, + pub addr_sorted: Vec<(u32, UsesCounter)>, + pub mem_align: Vec, + pub mem_align_rows: u32, } impl MemCounters { @@ -27,6 +32,7 @@ impl MemCounters { Self { registers: [empty_counter; 32], addr: HashMap::new(), + addr_sorted: Vec::new(), mem_align: Vec::new(), mem_align_rows: 0, } @@ -52,26 +58,43 @@ impl Metrics for MemCounters { if (addr & MEM_REGS_MASK) == MEM_REGS_ADDR { let reg_index = ((addr >> 3) & 0x1F) as usize; if self.registers[reg_index].count == 0 { - self.registers[reg_index] = - UsesCounter { first_step: step, last_step: step, count: 1 }; + self.registers[reg_index] = UsesCounter { + first_step: step, + last_step: step, + count: 1, + last_value: data[4], + }; } else { self.registers[reg_index].count += 1 + Self::count_extra_internal_reads(self.registers[reg_index].last_step, step); self.registers[reg_index].last_step = step; + self.registers[reg_index].last_value = data[4]; } } else { let aligned = addr & 0x7 == 0 && bytes == 8; + // TODO: last value must be calculated as last value operation + // R: value[4] + // RR: + let last_value = data[4]; if aligned { + // TODO: read, write self.addr .entry(addr_w) .and_modify(|value| { value.count += 1; value.last_step = step; + value.last_value = last_value; }) - .or_insert(UsesCounter { first_step: step, last_step: step, count: 1 }); + .or_insert(UsesCounter { + first_step: step, + last_step: step, + count: 1, + last_value, + }); } else { // TODO: use mem_align helpers - + // TODO: last value must be calculated as last value operation + let last_value = 0; let addr_count = if ((addr + bytes as u32) >> MEM_BYTES_BITS) != addr_w { 2 } else { 1 }; let ops_by_addr = if op == MEMORY_STORE_OP { 2 } else { 1 }; @@ -84,8 +107,16 @@ impl Metrics for MemCounters { value.count += ops_by_addr + Self::count_extra_internal_reads(value.last_step, step); value.last_step = last_step; + value.last_value = last_value }) - .or_insert(UsesCounter { first_step: step, last_step, count: ops_by_addr }); + .or_insert(UsesCounter { + first_step: step, + last_step, + count: ops_by_addr, + last_value, + }); + // if addr_count > 1, then addr_w must be the next (addr_w is expressed in + // MEM_BYTES) addr_w += 1; } let mem_align_op_rows = 1 + addr_count * ops_by_addr as u32; @@ -102,7 +133,13 @@ impl Metrics for MemCounters { fn bus_id(&self) -> Vec { vec![MEM_BUS_ID] } - + fn on_close(&mut self) { + // address must be ordered + info!("[Mem] Closing...."); + let addr_hashmap = std::mem::take(&mut self.addr); + self.addr_sorted = addr_hashmap.into_iter().collect(); + self.addr_sorted.sort_by(|a, b| a.0.cmp(&b.0)); + } fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index 8e70b537..b6cb8eb6 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -68,9 +68,9 @@ pub struct MemHelpers {} impl MemHelpers { pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { - MEM_STEP_BASE + - MAX_MEM_OPS_BY_MAIN_STEP * step + - MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 } } diff --git a/state-machines/mem/src/mem_module_planner.rs b/state-machines/mem/src/mem_module_planner.rs new file mode 100644 index 00000000..782bb718 --- /dev/null +++ b/state-machines/mem/src/mem_module_planner.rs @@ -0,0 +1,249 @@ +use std::sync::Arc; + +use crate::{MemCounters, MemPlanCalculator, UsesCounter, MEMORY_MAX_DIFF}; +use sm_common::{ChunkId, InstanceType, Plan}; + +pub struct MemModulePlanner<'a> { + airgroup_id: usize, + air_id: usize, + from_addr: u32, + to_addr: u32, + rows_available: u32, + instance_rows: u32, + last_step: u64, + last_addr: u32, // addr of last addr uses + last_value: u64, // value of last addr uses + cursors: Vec<(usize, usize)>, + pub instances: Vec, + first_instance: bool, + current_checkpoint_chunks: Vec, + current_checkpoint: MemInstanceCheckPoint, + current_chunk_id: Option, + counters: Arc>, +} + +#[derive(Debug, Default)] +struct MemInstanceCheckPoint { + prev_addr: u32, + skip_internal: u32, + prev_step: u64, + prev_value: u64, +} + +impl<'a> MemModulePlanner<'a> { + pub fn new( + airgroup_id: usize, + air_id: usize, + from_addr: u32, + to_addr: u32, + counters: Arc>, + ) -> Self { + Self { + airgroup_id, + air_id, + from_addr, + to_addr, + last_addr: 0, + last_step: 0, + last_value: 0, + rows_available: 0, + instance_rows: 1 << 21, + cursors: Vec::new(), + instances: Vec::new(), + counters, + first_instance: true, + current_chunk_id: None, + current_checkpoint_chunks: Vec::new(), + current_checkpoint: MemInstanceCheckPoint { + prev_addr: from_addr, + skip_internal: 0, + prev_step: 0, + prev_value: 0, + }, + } + } + pub fn module_plan(&mut self) { + if self.counters.is_empty() { + panic!("MemPlanner::plan() No metrics found"); + } + + // create a list of cursors, this list has the non-empty indexs of metric (couters) and his + // cursor init to first position + self.init_cursors(); + + while !self.cursors.is_empty() { + // searches for the first smallest element in the vector and returns its index. + let (cursor_index, cursor_pos) = self.get_next_cursor_index_and_pos(); + + let chunk_id = self.counters[cursor_index].0; + let addr = self.counters[cursor_index].1.addr_sorted[cursor_pos].0; + let addr_uses = self.counters[cursor_index].1.addr_sorted[cursor_pos].1; + + self.add_to_current_instance(chunk_id, addr, &addr_uses); + } + } + fn init_cursors(&mut self) { + // for each chunk-counter that has addr_sorted element add a cursor to the first element + self.cursors = Vec::new(); + for (index, counter) in self.counters.iter().enumerate() { + if counter.1.addr_sorted.is_empty() { + continue; + } + match counter.1.addr_sorted.binary_search_by(|(key, _)| key.cmp(&self.from_addr)) { + Ok(pos) => self.cursors.push((index, pos)), + Err(pos) => { + if pos < counter.1.addr_sorted.len() && + counter.1.addr_sorted[pos].0 <= self.to_addr + { + self.cursors.push((index, pos)); + } + } + } + } + } + fn get_next_cursor_index_and_pos(&mut self) -> (usize, usize) { + let (min_index, _) = self + .cursors + .iter() + .enumerate() + .min_by_key(|&(_, &(index, cursor))| self.counters[index].1.addr_sorted[cursor].0) + .unwrap(); + let cursor_index = self.cursors[min_index].0; + let cursor_pos = self.cursors[min_index].1; + + // if it's last position, we must remove for list of open_cursors, if not we increment + if cursor_pos + 1 >= self.counters[cursor_index].1.addr_sorted.len() || + self.counters[cursor_index].1.addr_sorted[cursor_pos + 1].0 > self.to_addr + { + self.cursors.remove(min_index); + } else { + self.cursors[min_index].1 += 1; + } + (cursor_index, cursor_pos) + } + /// Add "counter-address" to the current instance + /// If the chunk_id is not in the list, it will be added. This method need to verify the + /// distance between the last addr-step and the current addr-step, if the distance is + /// greater than MEMORY_MAX_DIFF we need to add extra intermediate "steps". + fn add_to_current_instance(&mut self, chunk_id: ChunkId, addr: u32, addr_uses: &UsesCounter) { + self.set_current_chunk_id(chunk_id); + self.add_internal_reads_to_current_instance(addr, addr_uses); + + let mut pending_rows = addr_uses.count; + while pending_rows > 0 { + if self.rows_available as u64 > pending_rows { + self.rows_available -= pending_rows as u32; + break; + } + pending_rows -= self.rows_available as u64; + let skip_internal = self.rows_available; + self.rows_available = 0; + self.close_instance(); + self.open_instance( + addr, + 0, // unknown this intermediate value + 0, // unknown this intermediate step + skip_internal, + ); + + self.rows_available = self.instance_rows; + } + // update last_xxx + self.last_value = addr_uses.last_value; + self.last_step = addr_uses.last_step; + self.last_addr = addr; + } + fn close_instance(&mut self) { + if self.current_checkpoint_chunks.is_empty() { + return; + } + // TODO: add chunks + // for chunk_id in self.current_checkpoint_chunks.iter() { + // instance.add_chunk_id(chunk_id.clone()); + // } + + let checkpoint = std::mem::take(&mut self.current_checkpoint); + let instance = Plan::new( + self.airgroup_id, + self.air_id, + Some(self.instances.len()), + InstanceType::Instance, + None, + Some(Box::new(checkpoint)), + ); + self.instances.push(instance); + self.current_checkpoint_chunks.clear(); + } + fn set_current_chunk_id(&mut self, chunk_id: ChunkId) { + if self.current_chunk_id == Some(chunk_id) && !self.current_checkpoint_chunks.is_empty() { + return; + } + self.current_chunk_id = Some(chunk_id); + if let Err(pos) = self.current_checkpoint_chunks.binary_search(&chunk_id) { + self.current_checkpoint_chunks.insert(pos, chunk_id); + } + } + fn add_internal_reads_to_current_instance(&mut self, addr: u32, addr_uses: &UsesCounter) { + // check internal reads (update last_xxx) + // reopen instance if need and set his chunk_id + if self.last_addr != addr { + return; + } + + let step_diff = addr_uses.first_step - self.last_step; + if step_diff <= MEMORY_MAX_DIFF { + return; + } + + // at this point we need to add internal reads, we calculate how many internal reads we need + let mut internal_rows = (step_diff - 1) / MEMORY_MAX_DIFF; + assert!(internal_rows < self.instance_rows as u64); + + // check if all internal reads fit in the current instance + if internal_rows < self.rows_available as u64 { + self.rows_available -= internal_rows as u32; + } else { + internal_rows -= self.rows_available as u64; + let skip_internal = self.rows_available; + self.rows_available = 0; + self.close_instance(); + self.open_instance( + addr, + self.last_value, + self.last_step + MEMORY_MAX_DIFF * skip_internal as u64, + skip_internal, + ); + + // rows_available is the number of rows after substract "pending" internal rows + self.rows_available = self.instance_rows - internal_rows as u32; + } + } + fn open_instance( + &mut self, + prev_addr: u32, + prev_value: u64, + prev_step: u64, + skip_internal: u32, + ) { + // TODO: add current chunk_id to new instance + self.first_instance = false; + self.current_checkpoint.prev_addr = prev_addr; + self.current_checkpoint.skip_internal = skip_internal; + self.current_checkpoint.prev_step = prev_step; + + // TODO: IMPORTANT review, when change of instance we need to known the previous value on + // write (on read previous value and current must be the same) + self.current_checkpoint.prev_value = prev_value; + + // TODO: add current chunk id + } +} + +impl MemPlanCalculator for MemModulePlanner<'_> { + fn plan(&mut self) { + self.module_plan(); + } + fn collect_plans(&mut self) -> Vec { + std::mem::take(&mut self.instances) + } +} diff --git a/state-machines/mem/src/mem_planner.rs b/state-machines/mem/src/mem_planner.rs new file mode 100644 index 00000000..282d3002 --- /dev/null +++ b/state-machines/mem/src/mem_planner.rs @@ -0,0 +1,69 @@ +use std::sync::Arc; + +use sm_common::{BusDeviceMetrics, ChunkId, Plan, Planner}; +use zisk_pil::{INPUT_DATA_AIR_IDS, MEM_AIR_IDS, ROM_DATA_AIR_IDS, ZISK_AIRGROUP_ID}; + +use crate::{ + MemAlignPlanner, MemCounters, MemModulePlanner, INPUT_DATA_W_ADDR_END, INPUT_DATA_W_ADDR_INIT, + RAM_W_ADDR_END, RAM_W_ADDR_INIT, ROM_DATA_W_ADDR_END, ROM_DATA_W_ADDR_INIT, +}; + +pub trait MemPlanCalculator { + fn plan(&mut self); + fn collect_plans(&mut self) -> Vec; +} + +#[derive(Default)] +pub struct MemPlanner {} + +impl MemPlanner { + pub fn new() -> Self { + Self {} + } +} + +impl Planner for MemPlanner { + fn plan(&self, metrics: Vec<(ChunkId, Box)>) -> Vec { + // convert generic information to specific information + let _counters: Vec<(ChunkId, &MemCounters)> = metrics + .iter() + .map(|(chunk_id, metric)| { + (*chunk_id, metric.as_any().downcast_ref::().unwrap()) + }) + .collect(); + + let counters = Arc::new(_counters); + let mut planners: Vec> = vec![ + Box::new(MemModulePlanner::new( + ZISK_AIRGROUP_ID, + MEM_AIR_IDS[0], + RAM_W_ADDR_INIT, + RAM_W_ADDR_END, + counters.clone(), + )), + Box::new(MemModulePlanner::new( + ZISK_AIRGROUP_ID, + ROM_DATA_AIR_IDS[0], + ROM_DATA_W_ADDR_INIT, + ROM_DATA_W_ADDR_END, + counters.clone(), + )), + Box::new(MemModulePlanner::new( + ZISK_AIRGROUP_ID, + INPUT_DATA_AIR_IDS[0], + INPUT_DATA_W_ADDR_INIT, + INPUT_DATA_W_ADDR_END, + counters.clone(), + )), + Box::new(MemAlignPlanner::new(counters.clone())), + ]; + for item in &mut planners { + item.plan(); + } + let mut plans: Vec = Vec::new(); + for item in &mut planners { + plans.append(&mut item.collect_plans()); + } + plans + } +} diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index f6ee23fd..43142893 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use crate::{ - InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemProxyEngine, MemSM, RomDataSM, + InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemPlanner, MemProxyEngine, MemSM, + RomDataSM, }; use p3_field::PrimeField; use pil_std_lib::Std; @@ -55,10 +56,10 @@ impl ComponentProvider for MemProxy { } fn get_planner(&self) -> Box { - unimplemented!("get_planner for MemProxy"); + Box::new(MemPlanner::new()) } - fn get_instance(&self, iectx: InstanceExpanderCtx) -> Box> { + fn get_instance(&self, _iectx: InstanceExpanderCtx) -> Box> { unimplemented!("get_instance for MemProxy"); } fn get_inputs_generator(&self) -> Option>> { diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index ab731726..3df4fb7b 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -9,8 +9,8 @@ use proofman_common::{AirInstance, FromTrace}; use zisk_core::{RAM_ADDR, RAM_SIZE}; use zisk_pil::{MemAirValues, MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; -const RAM_W_ADDR_INIT: u32 = RAM_ADDR as u32 >> MEM_BYTES_BITS; -const RAM_W_ADDR_END: u32 = (RAM_ADDR + RAM_SIZE - 1) as u32 >> MEM_BYTES_BITS; +pub const RAM_W_ADDR_INIT: u32 = RAM_ADDR as u32 >> MEM_BYTES_BITS; +pub const RAM_W_ADDR_END: u32 = (RAM_ADDR + RAM_SIZE - 1) as u32 >> MEM_BYTES_BITS; const _: () = { // assert!((RAM_SIZE - 1) >> MEM_BYTES_BITS <= MEMORY_MAX_DIFF, "RAM is too large"); @@ -262,6 +262,12 @@ impl MemSM { AirInstance::new_from_trace(FromTrace::new(&mut trace).with_air_values(&mut air_values)) } + pub fn get_from_addr() -> u32 { + RAM_ADDR as u32 + } + pub fn get_to_addr() -> u32 { + (RAM_ADDR + RAM_SIZE - 1) as u32 + } } impl MemModule for MemSM { diff --git a/state-machines/mem/src/rom_data.rs b/state-machines/mem/src/rom_data.rs index 4f93dd6d..e9742b24 100644 --- a/state-machines/mem/src/rom_data.rs +++ b/state-machines/mem/src/rom_data.rs @@ -10,8 +10,8 @@ use proofman_common::{AirInstance, FromTrace}; use zisk_core::{ROM_ADDR, ROM_ADDR_MAX}; use zisk_pil::{RomDataAirValues, RomDataTrace}; -const ROM_W_ADDR: u32 = ROM_ADDR as u32 >> MEM_BYTES_BITS; -const ROM_W_ADDR_END: u32 = ROM_ADDR_MAX as u32 >> MEM_BYTES_BITS; +pub const ROM_DATA_W_ADDR_INIT: u32 = ROM_ADDR as u32 >> MEM_BYTES_BITS; +pub const ROM_DATA_W_ADDR_END: u32 = ROM_ADDR_MAX as u32 >> MEM_BYTES_BITS; const _: () = { // assert!( @@ -63,7 +63,7 @@ impl RomDataSM { let is_last_segment = segment_id == num_segments - 1; let input_offset = segment_id * air_rows; let previous_segment = if (segment_id == 0) { - MemPreviousSegment { addr: ROM_W_ADDR, step: 0, value: 0 } + MemPreviousSegment { addr: ROM_DATA_W_ADDR_INIT, step: 0, value: 0 } } else { MemPreviousSegment { addr: inputs[input_offset - 1].addr, @@ -137,7 +137,7 @@ impl RomDataSM { // range of instance let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); self.std.range_check( - F::from_canonical_u32(previous_segment.addr - ROM_W_ADDR + 1), + F::from_canonical_u32(previous_segment.addr - ROM_DATA_W_ADDR_INIT + 1), F::one(), range_id, ); @@ -189,7 +189,7 @@ impl RomDataSM { air_values_mem.segment_last_value[1] = (last_value >> 32) as u32; self.std.range_check( - F::from_canonical_u32(ROM_W_ADDR_END - last_addr + 1), + F::from_canonical_u32(ROM_DATA_W_ADDR_END - last_addr + 1), F::one(), range_id, ); @@ -218,6 +218,12 @@ impl RomDataSM { fn get_u32_values(&self, value: u64) -> (u32, u32) { (value as u32, (value >> 32) as u32) } + pub fn get_from_addr() -> u32 { + ROM_DATA_W_ADDR_INIT + } + pub fn get_to_addr() -> u32 { + ROM_DATA_W_ADDR_END + } } impl MemModule for RomDataSM { From 0a27aa5c9d9cac6145d97f92c7db0ecd51faf4df Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Tue, 7 Jan 2025 10:48:39 +0000 Subject: [PATCH 06/10] update branch new-arch-mem with new-arch changes --- state-machines/mem/src/mem_align_planner.rs | 7 ++++--- state-machines/mem/src/mem_module_planner.rs | 5 +++-- state-machines/mem/src/mem_proxy.rs | 14 ++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/state-machines/mem/src/mem_align_planner.rs b/state-machines/mem/src/mem_align_planner.rs index 2bdeef94..2f432744 100644 --- a/state-machines/mem/src/mem_align_planner.rs +++ b/state-machines/mem/src/mem_align_planner.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{MemCounters, MemPlanCalculator}; -use sm_common::{CheckPoint, ChunkId, InstanceType, Plan}; +use sm_common::{CheckPoint, ChunkId, CollectInfoSkip, InstanceType, Plan}; use zisk_pil::{MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemAlignPlanner<'a> { @@ -79,16 +79,17 @@ impl<'a> MemAlignPlanner<'a> { return; } // TODO: add multi chunk_id, with skip + let chunks = std::mem::take(&mut self.current_chunks); let instance = Plan::new( ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0], Some(self.instances.len()), InstanceType::Instance, - Some(CheckPoint::new(self.current_chunks[0], self.current_skip as u64)), + CheckPoint::Multiple(chunks), + Some(Box::new(CollectInfoSkip::new(self.current_skip as u64))), None, ); self.instances.push(instance); - self.current_chunks.clear(); } fn open_new_instance(&mut self, next_instance_skip: u32, use_current_chunk_id: bool) { self.current_skip = next_instance_skip; diff --git a/state-machines/mem/src/mem_module_planner.rs b/state-machines/mem/src/mem_module_planner.rs index 782bb718..1a90f203 100644 --- a/state-machines/mem/src/mem_module_planner.rs +++ b/state-machines/mem/src/mem_module_planner.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{MemCounters, MemPlanCalculator, UsesCounter, MEMORY_MAX_DIFF}; -use sm_common::{ChunkId, InstanceType, Plan}; +use sm_common::{CheckPoint, ChunkId, InstanceType, Plan}; pub struct MemModulePlanner<'a> { airgroup_id: usize, @@ -163,16 +163,17 @@ impl<'a> MemModulePlanner<'a> { // } let checkpoint = std::mem::take(&mut self.current_checkpoint); + let chunks = std::mem::take(&mut self.current_checkpoint_chunks); let instance = Plan::new( self.airgroup_id, self.air_id, Some(self.instances.len()), InstanceType::Instance, + CheckPoint::Multiple(chunks), None, Some(Box::new(checkpoint)), ); self.instances.push(instance); - self.current_checkpoint_chunks.clear(); } fn set_current_chunk_id(&mut self, chunk_id: ChunkId) { if self.current_chunk_id == Some(chunk_id) && !self.current_checkpoint_chunks.is_empty() { diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index 43142893..c276148e 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -6,9 +6,7 @@ use crate::{ }; use p3_field::PrimeField; use pil_std_lib::Std; -use sm_common::{ - BusDeviceInstance, BusDeviceMetrics, ComponentProvider, InstanceExpanderCtx, Planner, -}; +use sm_common::{BusDeviceInstance, BusDeviceMetrics, ComponentBuilder, InstanceCtx, Planner}; use zisk_core::ZiskRequiredMemory; pub struct MemProxy { @@ -49,20 +47,20 @@ impl MemProxy { } } -impl ComponentProvider for MemProxy { - fn get_counter(&self) -> Box { +impl ComponentBuilder for MemProxy { + fn build_counter(&self) -> Box { Box::new(MemCounters::new()) // Box::new(MemCounters::new(OPERATION_BUS_ID, vec![zisk_core::ZiskOperationType::Arith])) } - fn get_planner(&self) -> Box { + fn build_planner(&self) -> Box { Box::new(MemPlanner::new()) } - fn get_instance(&self, _iectx: InstanceExpanderCtx) -> Box> { + fn build_inputs_collector(&self, _iectx: InstanceCtx) -> Box> { unimplemented!("get_instance for MemProxy"); } - fn get_inputs_generator(&self) -> Option>> { + fn build_inputs_generator(&self) -> Option>> { unimplemented!("get_instance for MemProxy"); } } From 2f36a0e027d9e242d685d510563fced5b8e9060e Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 9 Jan 2025 21:40:39 +0000 Subject: [PATCH 07/10] memories with new arch --- common/src/data_bus_mem.rs | 45 ++ common/src/lib.rs | 2 + emulator/src/emu.rs | 36 +- state-machines/binary/src/binary_basic.rs | 30 +- state-machines/mem/pil/mem.pil | 21 +- state-machines/mem/src/input_data_sm.rs | 143 ++-- state-machines/mem/src/lib.rs | 12 +- state-machines/mem/src/mem_align_instance.rs | 77 ++ state-machines/mem/src/mem_align_planner.rs | 15 +- state-machines/mem/src/mem_align_sm.rs | 729 ++++++++---------- state-machines/mem/src/mem_constants.rs | 9 +- state-machines/mem/src/mem_counters.rs | 30 +- state-machines/mem/src/mem_helpers.rs | 100 ++- state-machines/mem/src/mem_module.rs | 17 +- state-machines/mem/src/mem_module_instance.rs | 174 +++++ state-machines/mem/src/mem_module_planner.rs | 30 +- state-machines/mem/src/mem_proxy.rs | 44 +- state-machines/mem/src/mem_proxy_engine.rs | 34 - state-machines/mem/src/mem_sm.rs | 181 ++--- state-machines/mem/src/rom_data.rs | 135 ++-- 20 files changed, 1046 insertions(+), 818 deletions(-) create mode 100644 common/src/data_bus_mem.rs create mode 100644 state-machines/mem/src/mem_align_instance.rs create mode 100644 state-machines/mem/src/mem_module_instance.rs diff --git a/common/src/data_bus_mem.rs b/common/src/data_bus_mem.rs new file mode 100644 index 00000000..3a1c7cb7 --- /dev/null +++ b/common/src/data_bus_mem.rs @@ -0,0 +1,45 @@ +pub const MEM_BUS_ID: u16 = 10; + +pub const MEM_BUS_DATA_SIZE: usize = 5; + +const OP: usize = 0; +const ADDR: usize = 1; +const STEP: usize = 2; +const BYTES: usize = 3; +const MEM_VALUE_0: usize = 4; +const MEM_VALUE_1: usize = 5; +const VALUE: usize = 6; + +pub struct MemBusData; + +impl MemBusData { + #[inline(always)] + pub fn get_addr(data: &[u64]) -> u32 { + data[ADDR] as u32 + } + + #[inline(always)] + pub fn get_op(data: &[u64]) -> u8 { + data[OP] as u8 + } + + #[inline(always)] + pub fn get_step(data: &[u64]) -> u64 { + data[STEP] + } + + #[inline(always)] + pub fn get_bytes(data: &[u64]) -> u8 { + data[BYTES] as u8 + } + + #[inline(always)] + pub fn get_value(data: &[u64]) -> u64 { + data[VALUE] + } + + #[inline(always)] + pub fn get_mem_values(data: &[u64]) -> [u64; 2] { + [data[MEM_VALUE_0], data[MEM_VALUE_1]] + } +} diff --git a/common/src/lib.rs b/common/src/lib.rs index d5a23b38..efe38a85 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,7 +1,9 @@ mod data_bus; +mod data_bus_mem; mod data_bus_operation; mod data_bus_rom; pub use data_bus::*; +pub use data_bus_mem::*; pub use data_bus_operation::*; pub use data_bus_rom::*; diff --git a/emulator/src/emu.rs b/emulator/src/emu.rs index 52fddf70..fe8a0f79 100644 --- a/emulator/src/emu.rs +++ b/emulator/src/emu.rs @@ -6,7 +6,9 @@ use crate::{ }; use p3_field::{AbstractField, PrimeField}; use riscv::RiscVRegisters; -use zisk_common::{BusDevice, OperationBusData, RomBusData, OPERATION_BUS_ID, ROM_BUS_ID}; +use zisk_common::{ + BusDevice, OperationBusData, RomBusData, MEM_BUS_ID, OPERATION_BUS_ID, ROM_BUS_ID, +}; // #[cfg(feature = "sp")] // use zisk_core::SRC_SP; use zisk_common::DataBus; @@ -313,7 +315,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.a, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { let (required_address_1, required_address_2) = Mem::required_addresses(address, 8); @@ -333,7 +335,7 @@ impl<'a> Emu<'a> { 8, [raw_data_1, raw_data_2], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } /*println!( "Emu::source_a_mem_reads_consume() mem_leads_index={} value={:x}", @@ -649,7 +651,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.b, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } // Otherwise, get it from memory else if Mem::is_full_aligned(address, 8) { @@ -663,7 +665,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.b, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { let (required_address_1, required_address_2) = Mem::required_addresses(address, 8); @@ -680,7 +682,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.b, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { assert!(*mem_reads_index < mem_reads.len()); let raw_data_1 = mem_reads[*mem_reads_index]; @@ -697,7 +699,7 @@ impl<'a> Emu<'a> { 8, [raw_data_1, raw_data_2], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } } /*println!( @@ -731,7 +733,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.b, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } // Otherwise, get it from memory else if Mem::is_full_aligned(address, instruction.ind_width) { @@ -745,7 +747,7 @@ impl<'a> Emu<'a> { 8, [self.ctx.inst_ctx.b, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { let (required_address_1, required_address_2) = Mem::required_addresses(address, instruction.ind_width); @@ -765,7 +767,7 @@ impl<'a> Emu<'a> { instruction.ind_width as u8, [raw_data, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { assert!(*mem_reads_index < mem_reads.len()); let raw_data_1 = mem_reads[*mem_reads_index]; @@ -786,7 +788,7 @@ impl<'a> Emu<'a> { 8, [raw_data_1, raw_data_2], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } } /*println!( @@ -1092,7 +1094,7 @@ impl<'a> Emu<'a> { value, [value, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } // Otherwise, if not aligned, get old raw data from memory, then write it else if !Mem::is_full_aligned(address, 8) { @@ -1111,7 +1113,7 @@ impl<'a> Emu<'a> { value, [raw_data, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { assert!(*mem_reads_index < mem_reads.len()); let raw_data_1 = mem_reads[*mem_reads_index]; @@ -1128,7 +1130,7 @@ impl<'a> Emu<'a> { value, [raw_data_1, raw_data_2], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } } } @@ -1163,7 +1165,7 @@ impl<'a> Emu<'a> { value, [value, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } // Otherwise, if not aligned, get old raw data from memory, then write it else if !Mem::is_full_aligned(address, instruction.ind_width) { @@ -1182,7 +1184,7 @@ impl<'a> Emu<'a> { value, [raw_data, 0], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } else { assert!(*mem_reads_index < mem_reads.len()); let raw_data_1 = mem_reads[*mem_reads_index]; @@ -1199,7 +1201,7 @@ impl<'a> Emu<'a> { value, [raw_data_1, raw_data_2], ); - data_bus.write_to_bus(OPERATION_BUS_ID, payload.to_vec()); + data_bus.write_to_bus(MEM_BUS_ID, payload.to_vec()); } } } diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index 15c891c8..67b5eee5 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -181,9 +181,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Min) - && (plast[i] == 1) - && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Min) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if (a_bytes[i] & 0x80) != 0 { 1 } else { 0 }; } @@ -248,9 +248,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op == BinaryBasicTableOp::Max) - && (plast[i] == 1) - && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Max) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if (a_bytes[i] & 0x80) != 0 { 0 } else { 1 }; } @@ -425,9 +425,9 @@ impl BinaryBasicSM { } // If the chunk is signed, then the result is the sign of a - if (binary_basic_table_op.eq(&BinaryBasicTableOp::Lt)) - && (plast[i] == 1) - && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op.eq(&BinaryBasicTableOp::Lt)) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = if a_bytes[i] & 0x80 != 0 { 1 } else { 0 }; } @@ -667,9 +667,9 @@ impl BinaryBasicSM { if a_bytes[i] <= b_bytes[i] { cout = 1; } - if (binary_basic_table_op == BinaryBasicTableOp::Le) - && (plast[i] == 1) - && (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) + if (binary_basic_table_op == BinaryBasicTableOp::Le) && + (plast[i] == 1) && + (a_bytes[i] & 0x80) != (b_bytes[i] & 0x80) { cout = c; } @@ -813,9 +813,9 @@ impl BinaryBasicSM { // Set free_in_a_or_c and free_in_b_or_zero for i in 0..HALF_BYTES { - row.free_in_a_or_c[i] = mode64 - * (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) - + row.free_in_c[HALF_BYTES - 1]; + row.free_in_a_or_c[i] = mode64 * + (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + + row.free_in_c[HALF_BYTES - 1]; row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; } diff --git a/state-machines/mem/pil/mem.pil b/state-machines/mem/pil/mem.pil index 27052d91..c8be2c9c 100644 --- a/state-machines/mem/pil/mem.pil +++ b/state-machines/mem/pil/mem.pil @@ -90,17 +90,24 @@ airtemplate Mem(const int N = 2**21, const int id = MEMORY_ID, const int RC = 2, } } - if (!immutable) { + + if (free_input_mem) { + // free input memory must be read-only + const int air.wr = 0; + } else if (inmutable) { + // immutable memory in address change, must be write first or a read of zero. + const expr air.wr = addr_changes * sel; + for (int i = 0; i < RC; i++) { + // addr_changes * (1 - wr) * value[i] === 0; + addr_changes * (1 - sel) * value[i] === 0; + } + } else { col witness air.wr; const expr air.rd = 1 - wr; wr * (1 - wr) === 0; - } else { - // a free input memory must be read-only, an immutable memory must be write - // on first row of new address (addr_changes = 1) - const expr air.wr = free_input_mem ? 0 : addr_changes; + // if wr is 1, sel must be 1 (not allowed writes) + wr * (1 - sel) === 0; } - // if wr is 1, sel must be 1 (not allowed writes) - wr * (1 - sel) === 0; sel * (1 - sel) === 0; diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs index b525f1bf..1c7b8fdc 100644 --- a/state-machines/mem/src/input_data_sm.rs +++ b/state-machines/mem/src/input_data_sm.rs @@ -8,7 +8,7 @@ use p3_field::PrimeField; use pil_std_lib::Std; use proofman_common::{AirInstance, FromTrace}; use zisk_core::{INPUT_ADDR, MAX_INPUT_SIZE}; -use zisk_pil::{InputDataAirValues, InputDataTrace, ZiskProofValues}; +use zisk_pil::{InputDataAirValues, InputDataTrace}; pub const INPUT_DATA_W_ADDR_INIT: u32 = INPUT_ADDR as u32 >> MEM_BYTES_BITS; pub const INPUT_DATA_W_ADDR_END: u32 = (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 >> MEM_BYTES_BITS; @@ -36,62 +36,21 @@ impl InputDataSM { pub fn new(std: Arc>) -> Arc { Arc::new(Self { std: std.clone() }) } - - pub fn prove(&self, inputs: &[MemInput]) { - let mut proof_values = ZiskProofValues::from_vec_guard(self.std.pctx.get_proof_values()); - proof_values.enable_input_data = if inputs.is_empty() { F::zero() } else { F::one() }; - - // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs - // memory only need to process these special inputs, but inputs no change. At end of - // inputs proxy add an extra internal input to jump to last address - - let airgroup_id = InputDataTrace::::AIRGROUP_ID; - let air_id = InputDataTrace::::AIR_ID; - let air_rows = InputDataTrace::::NUM_ROWS; - - // at least one row to go - let count = inputs.len(); - let count_rem = count % air_rows; - let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; - - let mut global_idxs = vec![0; num_segments]; - - #[allow(clippy::needless_range_loop)] - for i in 0..num_segments { - // TODO: Review - if let (true, global_idx) = - self.std.pctx.dctx.write().unwrap().add_instance(airgroup_id, air_id, 1) - { - global_idxs[i] = global_idx; - } - } - - #[allow(clippy::needless_range_loop)] - for segment_id in 0..num_segments { - let is_last_segment = segment_id == num_segments - 1; - let input_offset = segment_id * air_rows; - let previous_segment = if (segment_id == 0) { - MemPreviousSegment { addr: INPUT_DATA_W_ADDR_INIT, step: 0, value: 0 } - } else { - MemPreviousSegment { - addr: inputs[input_offset - 1].addr, - step: inputs[input_offset - 1].step, - value: inputs[input_offset - 1].value, - } - }; - let input_end = - if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; - let mem_ops = &inputs[input_offset..input_end]; - - let air_instance = - self.prove_instance(mem_ops, segment_id, is_last_segment, &previous_segment); - - self.std - .pctx - .air_instance_repo - .add_air_instance(air_instance, Some(global_idxs[segment_id])); - } + fn get_u16_values(&self, value: u64) -> [u16; 4] { + [value as u16, (value >> 16) as u16, (value >> 32) as u16, (value >> 48) as u16] + } + pub fn get_from_addr() -> u32 { + INPUT_ADDR as u32 } + pub fn get_to_addr() -> u32 { + (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 + } +} + +impl MemModule for InputDataSM { + // TODO PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs + // memory only need to process these special inputs, but inputs no change. At end of + // inputs proxy add an extra internal input to jump to last address /// Finalizes the witness accumulation process and triggers the proof generation. /// @@ -100,7 +59,7 @@ impl InputDataSM { /// # Parameters /// /// - `mem_inputs`: A slice of all `ZiskRequiredMemory` inputs - pub fn prove_instance( + fn prove_instance( &self, mem_ops: &[MemInput], segment_id: usize, @@ -155,15 +114,50 @@ impl InputDataSM { range_id, ); - // Fill the remaining rows let mut last_addr: u32 = previous_segment.addr; let mut last_step: u64 = previous_segment.step; let mut last_value: u64 = previous_segment.value; - for (i, mem_op) in mem_ops.iter().enumerate() { + let mut i = 0; + for mem_op in mem_ops.iter() { + let mut internal_reads = (mem_op.addr - last_addr) - 1; + if internal_reads > 1 { + // check if has enough rows to complete the internal reads + regular memory + let incomplete = (i + internal_reads as usize) >= trace.num_rows; + if incomplete { + internal_reads = (trace.num_rows - i) as u32; + } + + trace[i].addr_changes = F::one(); + last_addr += 1; + trace[i].addr = F::from_canonical_u32(last_addr); + + // the step, value of internal reads isn't relevant + last_step = 0; + trace[i].step = F::zero(); + trace[i].sel = F::zero(); + + // setting value to zero, is not relevant for internal reads + last_value = 0; + for j in 0..4 { + trace[i].value_word[j] = F::zero(); + } + i += 1; + + for _j in 1..internal_reads { + trace[i] = trace[i - 1]; + last_addr += 1; + trace[i].addr = F::from_canonical_u32(last_addr); + i += 1; + } + range_check_data[0] += 4 * internal_reads as u64; + if incomplete { + break; + } + } trace[i].addr = F::from_canonical_u32(mem_op.addr); trace[i].step = F::from_canonical_u64(mem_op.step); - trace[i].sel = F::from_bool(!mem_op.is_internal); + trace[i].sel = F::one(); let value = mem_op.value; let value_words = self.get_u16_values(value); @@ -179,16 +173,18 @@ impl InputDataSM { last_addr = mem_op.addr; last_step = mem_op.step; last_value = mem_op.value; + i += 1; } + let count = i; // STEP3. Add dummy rows to the output vector to fill the remaining rows //PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0 - let last_row_idx = mem_ops.len() - 1; + let last_row_idx = count - 1; let addr = trace[last_row_idx].addr; let value = trace[last_row_idx].value_word; - let padding_size = trace.num_rows() - mem_ops.len(); - for i in mem_ops.len()..trace.num_rows() { + let padding_size = trace.num_rows() - count; + for i in count..trace.num_rows() { last_step += 1; // TODO CHECK @@ -218,7 +214,7 @@ impl InputDataSM { // range of chunks let range_id = self.std.get_range(BigInt::from(0), BigInt::from((1 << 16) - 1), None); for (value, &multiplicity) in range_check_data.iter().enumerate() { - if (multiplicity == 0) { + if multiplicity == 0 { continue; } @@ -253,26 +249,7 @@ impl InputDataSM { AirInstance::new_from_trace(FromTrace::new(&mut trace).with_air_values(&mut air_values)) } - fn get_u16_values(&self, value: u64) -> [u16; 4] { - [value as u16, (value >> 16) as u16, (value >> 32) as u16, (value >> 48) as u16] - } - pub fn get_from_addr() -> u32 { - INPUT_ADDR as u32 - } - pub fn get_to_addr() -> u32 { - (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32 - } -} - -impl MemModule for InputDataSM { - fn send_inputs(&self, mem_op: &[MemInput]) { - self.prove(mem_op); - } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { vec![(INPUT_ADDR as u32, (INPUT_ADDR + MAX_INPUT_SIZE - 1) as u32)] } - fn get_flush_input_size(&self) -> u32 { - // self.num_rows as u32 - 0 - } } diff --git a/state-machines/mem/src/lib.rs b/state-machines/mem/src/lib.rs index 369b1e6d..a44e9034 100644 --- a/state-machines/mem/src/lib.rs +++ b/state-machines/mem/src/lib.rs @@ -1,4 +1,5 @@ mod input_data_sm; +mod mem_align_instance; mod mem_align_planner; mod mem_align_rom_sm; mod mem_align_sm; @@ -7,15 +8,17 @@ mod mem_constants; mod mem_counters; mod mem_helpers; mod mem_module; +mod mem_module_instance; mod mem_module_planner; mod mem_planner; mod mem_proxy; -mod mem_proxy_engine; +// mod mem_proxy_engine; mod mem_sm; -mod mem_unmapped; +// mod mem_unmapped; mod rom_data; pub use input_data_sm::*; +pub use mem_align_instance::*; pub use mem_align_planner::*; pub use mem_align_rom_sm::*; pub use mem_align_sm::*; @@ -24,10 +27,11 @@ pub use mem_constants::*; pub use mem_counters::*; pub use mem_helpers::*; pub use mem_module::*; +pub use mem_module_instance::*; pub use mem_module_planner::*; pub use mem_planner::*; pub use mem_proxy::*; -pub use mem_proxy_engine::*; +// pub use mem_proxy_engine::*; pub use mem_sm::*; -pub use mem_unmapped::*; +//pub use mem_unmapped::*; pub use rom_data::*; diff --git a/state-machines/mem/src/mem_align_instance.rs b/state-machines/mem/src/mem_align_instance.rs new file mode 100644 index 00000000..815696a9 --- /dev/null +++ b/state-machines/mem/src/mem_align_instance.rs @@ -0,0 +1,77 @@ +use crate::{MemAlignCheckPoint, MemAlignInput, MemAlignSM, MemHelpers}; +use p3_field::PrimeField; +use proofman_common::{AirInstance, ProofCtx}; +use sm_common::{CheckPoint, Instance, InstanceCtx, InstanceType}; +use std::sync::Arc; +use zisk_common::{BusDevice, BusId, MemBusData}; + +pub struct MemAlignInstance { + checkpoint: MemAlignCheckPoint, + /// Instance context + ictx: InstanceCtx, + + /// Collected inputs + inputs: Vec, + mem_align_sm: Arc>, + pending_count: u32, + skip_pending: u32, +} + +impl MemAlignInstance { + pub fn new(mem_align_sm: Arc>, ictx: InstanceCtx) -> Self { + let checkpoint = + ictx.plan.meta.as_ref().unwrap().downcast_ref::().unwrap().clone(); + + Self { + ictx, + inputs: Vec::new(), + mem_align_sm, + checkpoint: checkpoint.clone(), + skip_pending: checkpoint.skip, + pending_count: checkpoint.count, + } + } +} + +impl Instance for MemAlignInstance { + fn compute_witness(&mut self, _pctx: &ProofCtx) -> Option> { + Some(self.mem_align_sm.prove_instance(&self.inputs, self.checkpoint.rows)) + } + + fn check_point(&self) -> CheckPoint { + self.ictx.plan.check_point.clone() + } + + fn instance_type(&self) -> InstanceType { + InstanceType::Instance + } +} + +impl BusDevice for MemAlignInstance { + fn process_data(&mut self, _bus_id: &BusId, data: &[u64]) -> (bool, Vec<(BusId, Vec)>) { + let addr = MemBusData::get_addr(data); + let bytes = MemBusData::get_bytes(data); + if !MemHelpers::is_aligned(addr, bytes) { + return (false, vec![]) + } + if self.skip_pending > 0 { + self.skip_pending -= 1; + return (false, vec![]) + } + + if self.pending_count == 0 { + return (true, vec![]) + } + self.pending_count -= 1; + self.inputs.push(MemAlignInput { + addr: MemBusData::get_addr(data), + is_write: MemHelpers::is_write(MemBusData::get_op(data)), + width: MemBusData::get_bytes(data), + step: MemBusData::get_step(data), + value: MemBusData::get_value(data), + mem_values: MemBusData::get_mem_values(data), + }); + + (false, vec![]) + } +} diff --git a/state-machines/mem/src/mem_align_planner.rs b/state-machines/mem/src/mem_align_planner.rs index 2f432744..2b9f90ab 100644 --- a/state-machines/mem/src/mem_align_planner.rs +++ b/state-machines/mem/src/mem_align_planner.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{MemCounters, MemPlanCalculator}; -use sm_common::{CheckPoint, ChunkId, CollectInfoSkip, InstanceType, Plan}; +use sm_common::{CheckPoint, ChunkId, InstanceType, Plan}; use zisk_pil::{MEM_ALIGN_AIR_IDS, ZISK_AIRGROUP_ID}; pub struct MemAlignPlanner<'a> { @@ -14,6 +14,13 @@ pub struct MemAlignPlanner<'a> { counters: Arc>, } +#[derive(Clone)] +pub struct MemAlignCheckPoint { + pub skip: u32, + pub count: u32, + pub rows: u32, +} + // TODO: dynamic const MEM_ALIGN_ROWS: usize = 1 << 21; @@ -86,7 +93,11 @@ impl<'a> MemAlignPlanner<'a> { Some(self.instances.len()), InstanceType::Instance, CheckPoint::Multiple(chunks), - Some(Box::new(CollectInfoSkip::new(self.current_skip as u64))), + Some(Box::new(MemAlignCheckPoint { + skip: self.current_skip, + count: 0, + rows: self.num_rows - self.current_rows_available, + })), None, ); self.instances.push(instance); diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index bf7e522a..d6b62403 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -1,5 +1,4 @@ -use core::panic; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use log::info; use num_bigint::BigInt; @@ -7,6 +6,7 @@ use num_traits::cast::ToPrimitive; use p3_field::PrimeField; use pil_std_lib::Std; +use proofman_common::{AirInstance, FromTrace}; use zisk_pil::{MemAlignTrace, MemAlignTraceRow}; use crate::{MemAlignInput, MemAlignRomSM, MemOp}; @@ -43,10 +43,7 @@ pub struct MemAlignResponse { } pub struct MemAlignSM { /// PIL2 standard library - std: Arc>, - - // Computed row information - rows: Mutex>>, + _std: Arc>, #[cfg(feature = "debug_mem_align")] num_computed_rows: Mutex, @@ -69,16 +66,19 @@ impl MemAlignSM { pub fn new(std: Arc>, mem_align_rom_sm: Arc) -> Arc { Arc::new(Self { - std: std.clone(), - rows: Mutex::new(Vec::new()), + _std: std.clone(), #[cfg(feature = "debug_mem_align")] num_computed_rows: Mutex::new(0), mem_align_rom_sm, }) } - #[inline(always)] - pub fn get_mem_op(&self, input: &MemAlignInput, phase: usize) -> MemAlignResponse { + pub fn prove_mem_align_op( + &self, + input: &MemAlignInput, + trace: &mut MemAlignTrace, + index: usize, + ) -> usize { let addr = input.addr; let width = input.width; @@ -114,8 +114,6 @@ impl MemAlignSM { | V6 | V7 | V0 | V1 | V2 | V3 | V4 | V5 | +----+----+====+====+====+====+----+----+ */ - debug_assert!(phase == 0); - // Unaligned memory op information thrown into the bus let step = input.step; let value = input.value; @@ -124,7 +122,7 @@ impl MemAlignSM { let addr_read = addr >> OFFSET_BITS; // Get the aligned value - let value_read = input.mem_values[phase]; + let value_read = input.mem_values[0]; // Get the next pc let next_pc = @@ -209,9 +207,9 @@ impl MemAlignSM { drop(num_rows); // Prove the generated rows - self.prove(&[read_row, value_row]); - - MemAlignResponse { more_addr: false, step, value: None } + trace[index] = read_row; + trace[index + 1] = value_row; + 2 } (true, false) => { /* RWV with offset=3, width=4 @@ -227,7 +225,6 @@ impl MemAlignSM { | V5 | V6 | V7 | V0 | V1 | V2 | V3 | V4 | +----+----+----+====+====+====+====+----+ */ - debug_assert!(phase == 0); // Unaligned memory op information thrown into the bus let step = input.step; @@ -237,7 +234,7 @@ impl MemAlignSM { let addr_read = addr >> OFFSET_BITS; // Get the aligned value - let value_read = input.mem_values[phase]; + let value_read = input.mem_values[0]; // Get the next pc let next_pc = @@ -371,9 +368,10 @@ impl MemAlignSM { drop(num_rows); // Prove the generated rows - self.prove(&[read_row, write_row, value_row]); - - MemAlignResponse { more_addr: false, step, value: Some(value_write) } + trace[index] = read_row; + trace[index + 1] = write_row; + trace[index + 2] = value_row; + 3 } (false, true) => { /* RVR with offset=5, width=8 @@ -389,108 +387,99 @@ impl MemAlignSM { | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | +====+====+====+====+====+----+----+----+ */ - debug_assert!(phase == 0 || phase == 1); - - match phase { - // If phase == 0, do nothing, just ask for more - 0 => MemAlignResponse { more_addr: true, step: input.step, value: None }, - - // Otherwise, do the RVR - 1 => { - // Unaligned memory op information thrown into the bus - let step = input.step; - let value = input.value; - - // Compute the remaining bytes - let rem_bytes = (offset + width) % CHUNK_NUM; - - // Get the aligned address - let addr_first_read = addr >> OFFSET_BITS; - let addr_second_read = addr_first_read + 1; - - // Get the aligned value - let value_first_read = input.mem_values[0]; - let value_second_read = input.mem_values[1]; - - // Get the next pc - let next_pc = - self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); - - let mut first_read_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_first_read), - // delta_addr: F::zero(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_first_read), - // delta_addr: F::zero(), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - let mut second_read_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_second_read), - delta_addr: F::one(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 1), - // reset: F::from_bool(false), - sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - first_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); - if i >= offset { - first_read_row.sel[i] = F::from_bool(true); - } - - value_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); - - if i == offset { - value_row.sel[i] = F::from_bool(true); - } - - second_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); - if i < rem_bytes { - second_read_row.sel[i] = F::from_bool(true); - } - } - let mut _value_first_read = value_first_read; - let mut _value = value; - let mut _value_second_read = value_second_read; - for i in 0..RC { - first_read_row.value[i] = - F::from_canonical_u64(_value_first_read & RC_MASK); - value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); - second_read_row.value[i] = - F::from_canonical_u64(_value_second_read & RC_MASK); - _value_first_read >>= RC_BITS; - _value >>= RC_BITS; - _value_second_read >>= RC_BITS; - } + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; + + // Compute the remaining bytes + let rem_bytes = (offset + width) % CHUNK_NUM; + + // Get the aligned address + let addr_first_read = addr >> OFFSET_BITS; + let addr_second_read = addr_first_read + 1; + + // Get the aligned value + let value_first_read = input.mem_values[0]; + let value_second_read = input.mem_values[1]; + + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoReads, offset, width); + + let mut first_read_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; - #[rustfmt::skip] + let mut value_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; + + let mut second_read_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; + + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i >= offset { + first_read_row.sel[i] = F::from_bool(true); + } + + value_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)); + + if i == offset { + value_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i < rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } + + let mut _value_first_read = value_first_read; + let mut _value = value; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = F::from_canonical_u64(_value_first_read & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_read_row.value[i] = F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] debug_info!( "\nTwo Words Read\n\ Num Rows: {:?}\n\ @@ -525,16 +514,14 @@ impl MemAlignSM { ] ); - #[cfg(feature = "debug_mem_align")] - drop(num_rows); - - // Prove the generated rows - self.prove(&[first_read_row, value_row, second_read_row]); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); - MemAlignResponse { more_addr: false, step, value: None } - } - _ => panic!("Invalid phase={}", phase), - } + // Prove the generated rows + trace[index] = first_read_row; + trace[index + 1] = value_row; + trace[index + 2] = second_read_row; + 3 } (true, true) => { /* RWVWR with offset=6, width=4 @@ -558,227 +545,184 @@ impl MemAlignSM { | R0 | R1 | R2 | R3 | R4 | R5 | R6 | R7 | +====+====+----+----+----+----+----+----+ */ - debug_assert!(phase == 0 || phase == 1); - - match phase { - // If phase == 0, compute the resulting write value and ask for more - 0 => { - // Unaligned memory op information thrown into the bus - let value = input.value; - let step = input.step; + // Unaligned memory op information thrown into the bus + let step = input.step; + let value = input.value; - // Get the aligned value - let value_first_read = input.mem_values[0]; + // Compute the shift + let rem_bytes = (offset + width) % CHUNK_NUM; - // Compute the write value - let value_first_write = { - // Normalize the width - let width_norm = CHUNK_NUM - offset; + // Get the aligned address + let addr_first_read_write = addr >> OFFSET_BITS; + let addr_second_read_write = addr_first_read_write + 1; - let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + // Get the first aligned value + let value_first_read = input.mem_values[0]; - let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // Recompute the first write value + let value_first_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; - // Get the first width bytes of the unaligned value - let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; - // Write zeroes to value_read from offset to offset + width - // and add the value to write to the value read - (value_first_read & !mask) | value_to_write - }; + let mask: u64 = width_bytes << (offset * CHUNK_BITS); - MemAlignResponse { more_addr: true, step, value: Some(value_first_write) } - } - // Otherwise, do the RWVRW - 1 => { - // Unaligned memory op information thrown into the bus - let step = input.step; - let value = input.value; + // Get the first width bytes of the unaligned value + let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); - // Compute the shift - let rem_bytes = (offset + width) % CHUNK_NUM; + // Write zeroes to value_read from offset to offset + width + // and add the value to write to the value read + (value_first_read & !mask) | value_to_write + }; - // Get the aligned address - let addr_first_read_write = addr >> OFFSET_BITS; - let addr_second_read_write = addr_first_read_write + 1; + // Get the second aligned value + let value_second_read = input.mem_values[1]; - // Get the first aligned value - let value_first_read = input.mem_values[0]; + // Compute the second write value + let value_second_write = { + // Normalize the width + let width_norm = CHUNK_NUM - offset; - // Recompute the first write value - let value_first_write = { - // Normalize the width - let width_norm = CHUNK_NUM - offset; + let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; - let width_bytes: u64 = (1 << (width_norm * CHUNK_BITS)) - 1; + // Get the first width bytes of the unaligned value + let value_to_write = (value >> (width_norm * CHUNK_BITS)) & mask; - let mask: u64 = width_bytes << (offset * CHUNK_BITS); + // Write zeroes to value_read from 0 to offset + width + // and add the value to write to the value read + (value_second_read & !mask) | value_to_write + }; - // Get the first width bytes of the unaligned value - let value_to_write = (value & width_bytes) << (offset * CHUNK_BITS); + // Get the next pc + let next_pc = + self.mem_align_rom_sm.calculate_next_pc(MemOp::TwoWrites, offset, width); - // Write zeroes to value_read from offset to offset + width - // and add the value to write to the value read - (value_first_read & !mask) | value_to_write - }; + // RWVWR + let mut first_read_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + // pc: F::from_canonical_u64(0), + reset: F::from_bool(true), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; - // Get the second aligned value - let value_second_read = input.mem_values[1]; + let mut first_write_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc), + // reset: F::from_bool(false), + sel_up_to_down: F::from_bool(true), + ..Default::default() + }; - // Compute the second write value - let value_second_write = { - // Normalize the width - let width_norm = CHUNK_NUM - offset; + let mut value_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_first_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_usize(offset), + width: F::from_canonical_usize(width), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 1), + // reset: F::from_bool(false), + sel_prove: F::from_bool(true), + ..Default::default() + }; - let mask: u64 = (1 << (rem_bytes * CHUNK_BITS)) - 1; + let mut second_write_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step + 1), + addr: F::from_canonical_u32(addr_second_read_write), + delta_addr: F::one(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + wr: F::from_bool(true), + pc: F::from_canonical_u64(next_pc + 2), + // reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; - // Get the first width bytes of the unaligned value - let value_to_write = (value >> (width_norm * CHUNK_BITS)) & mask; + let mut second_read_row = MemAlignTraceRow:: { + step: F::from_canonical_u64(step), + addr: F::from_canonical_u32(addr_second_read_write), + // delta_addr: F::zero(), + offset: F::from_canonical_u64(DEFAULT_OFFSET), + width: F::from_canonical_u64(DEFAULT_WIDTH), + // wr: F::from_bool(false), + pc: F::from_canonical_u64(next_pc + 3), + reset: F::from_bool(false), + sel_down_to_up: F::from_bool(true), + ..Default::default() + }; - // Write zeroes to value_read from 0 to offset + width - // and add the value to write to the value read - (value_second_read & !mask) | value_to_write - }; + for i in 0..CHUNK_NUM { + first_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); + if i < offset { + first_read_row.sel[i] = F::from_bool(true); + } - // Get the next pc - let next_pc = self.mem_align_rom_sm.calculate_next_pc( - MemOp::TwoWrites, - offset, - width, - ); + first_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); + if i >= offset { + first_write_row.sel[i] = F::from_bool(true); + } - // RWVWR - let mut first_read_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_first_read_write), - // delta_addr: F::zero(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - // wr: F::from_bool(false), - // pc: F::from_canonical_u64(0), - reset: F::from_bool(true), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut first_write_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step + 1), - addr: F::from_canonical_u32(addr_first_read_write), - // delta_addr: F::zero(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc), - // reset: F::from_bool(false), - sel_up_to_down: F::from_bool(true), - ..Default::default() - }; - - let mut value_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_first_read_write), - // delta_addr: F::zero(), - offset: F::from_canonical_usize(offset), - width: F::from_canonical_usize(width), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc + 1), - // reset: F::from_bool(false), - sel_prove: F::from_bool(true), - ..Default::default() - }; - - let mut second_write_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step + 1), - addr: F::from_canonical_u32(addr_second_read_write), - delta_addr: F::one(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - wr: F::from_bool(true), - pc: F::from_canonical_u64(next_pc + 2), - // reset: F::from_bool(false), - sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - let mut second_read_row = MemAlignTraceRow:: { - step: F::from_canonical_u64(step), - addr: F::from_canonical_u32(addr_second_read_write), - // delta_addr: F::zero(), - offset: F::from_canonical_u64(DEFAULT_OFFSET), - width: F::from_canonical_u64(DEFAULT_WIDTH), - // wr: F::from_bool(false), - pc: F::from_canonical_u64(next_pc + 3), - reset: F::from_bool(false), - sel_down_to_up: F::from_bool(true), - ..Default::default() - }; - - for i in 0..CHUNK_NUM { - first_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_first_read, i, 0)); - if i < offset { - first_read_row.sel[i] = F::from_bool(true); - } - - first_write_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_first_write, i, 0)); - if i >= offset { - first_write_row.sel[i] = F::from_bool(true); - } - - value_row.reg[i] = { - if i < rem_bytes { - second_write_row.reg[i] - } else if i >= offset { - first_write_row.reg[i] - } else { - F::from_canonical_u64(Self::get_byte( - value, - i, - CHUNK_NUM - offset, - )) - } - }; - if i == offset { - value_row.sel[i] = F::from_bool(true); - } - - second_write_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); - if i < rem_bytes { - second_write_row.sel[i] = F::from_bool(true); - } - - second_read_row.reg[i] = - F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); - if i >= rem_bytes { - second_read_row.sel[i] = F::from_bool(true); - } + value_row.reg[i] = { + if i < rem_bytes { + second_write_row.reg[i] + } else if i >= offset { + first_write_row.reg[i] + } else { + F::from_canonical_u64(Self::get_byte(value, i, CHUNK_NUM - offset)) } + }; + if i == offset { + value_row.sel[i] = F::from_bool(true); + } - let mut _value_first_read = value_first_read; - let mut _value_first_write = value_first_write; - let mut _value = value; - let mut _value_second_write = value_second_write; - let mut _value_second_read = value_second_read; - for i in 0..RC { - first_read_row.value[i] = - F::from_canonical_u64(_value_first_read & RC_MASK); - first_write_row.value[i] = - F::from_canonical_u64(_value_first_write & RC_MASK); - value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); - second_write_row.value[i] = - F::from_canonical_u64(_value_second_write & RC_MASK); - second_read_row.value[i] = - F::from_canonical_u64(_value_second_read & RC_MASK); - _value_first_read >>= RC_BITS; - _value_first_write >>= RC_BITS; - _value >>= RC_BITS; - _value_second_write >>= RC_BITS; - _value_second_read >>= RC_BITS; - } + second_write_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_write, i, 0)); + if i < rem_bytes { + second_write_row.sel[i] = F::from_bool(true); + } + + second_read_row.reg[i] = + F::from_canonical_u64(Self::get_byte(value_second_read, i, 0)); + if i >= rem_bytes { + second_read_row.sel[i] = F::from_bool(true); + } + } - #[rustfmt::skip] + let mut _value_first_read = value_first_read; + let mut _value_first_write = value_first_write; + let mut _value = value; + let mut _value_second_write = value_second_write; + let mut _value_second_read = value_second_read; + for i in 0..RC { + first_read_row.value[i] = F::from_canonical_u64(_value_first_read & RC_MASK); + first_write_row.value[i] = F::from_canonical_u64(_value_first_write & RC_MASK); + value_row.value[i] = F::from_canonical_u64(_value & RC_MASK); + second_write_row.value[i] = + F::from_canonical_u64(_value_second_write & RC_MASK); + second_read_row.value[i] = F::from_canonical_u64(_value_second_read & RC_MASK); + _value_first_read >>= RC_BITS; + _value_first_write >>= RC_BITS; + _value >>= RC_BITS; + _value_second_write >>= RC_BITS; + _value_second_read >>= RC_BITS; + } + + #[rustfmt::skip] debug_info!( "\nTwo Words Write\n\ Num Rows: {:?}\n\ @@ -829,22 +773,16 @@ impl MemAlignSM { ] ); - #[cfg(feature = "debug_mem_align")] - drop(num_rows); - - // Prove the generated rows - self.prove(&[ - first_read_row, - first_write_row, - value_row, - second_write_row, - second_read_row, - ]); + #[cfg(feature = "debug_mem_align")] + drop(num_rows); - MemAlignResponse { more_addr: false, step, value: Some(value_second_write) } - } - _ => panic!("Invalid phase={}", phase), - } + // Prove the generated rows + trace[index] = first_read_row; + trace[index + 1] = first_write_row; + trace[index + 2] = value_row; + trace[index + 3] = second_write_row; + trace[index + 4] = second_read_row; + 5 } } } @@ -854,67 +792,54 @@ impl MemAlignSM { (value >> (chunk * CHUNK_BITS)) & CHUNK_BITS_MASK } - pub fn prove(&self, computed_rows: &[MemAlignTraceRow]) { - if let Ok(mut rows) = self.rows.lock() { - rows.extend_from_slice(computed_rows); - - #[cfg(feature = "debug_mem_align")] - { - let mut num_rows = self.num_computed_rows.lock().unwrap(); - *num_rows += computed_rows.len(); - drop(num_rows); - } - - let num_rows = MemAlignTrace::::NUM_ROWS; + pub fn prove_instance(&self, mem_ops: &[MemAlignInput], used_rows: u32) -> AirInstance { + let mut trace = MemAlignTrace::::new(); + let mut reg_range_check = [0u64; 1 << CHUNK_BITS]; - while rows.len() >= num_rows { - let num_drained = std::cmp::min(num_rows, rows.len()); - let drained_rows = rows.drain(..num_drained).collect::>(); - - self.fill_new_air_instance(&drained_rows); - } - } - } + let num_rows = trace.num_rows(); + info!( + "{}: ยทยทยท Creating Mem Align instance [{} / {} rows filled {:.2}%]", + Self::MY_NAME, + used_rows, + num_rows, + used_rows as f64 / num_rows as f64 * 100.0 + ); - fn fill_new_air_instance(&self, rows: &[MemAlignTraceRow]) { - // Get the Mem Align AIR - let air_mem_align_rows = MemAlignTrace::::NUM_ROWS; - let rows_len = rows.len(); - - // You cannot feed to the AIR more rows than it has - debug_assert!(rows_len <= air_mem_align_rows); - - let mut trace_buffer: MemAlignTrace = MemAlignTrace::new(); - - let mut reg_range_check: Vec = vec![0; 1 << CHUNK_BITS]; - // Add the input rows to the trace - for (i, &row) in rows.iter().enumerate() { - // Store the entire row - trace_buffer[i] = row; - // Store the value of all reg columns so that they can be range checked - for j in 0..CHUNK_NUM { - let element = - row.reg[j].as_canonical_biguint().to_usize().expect("Cannot convert to usize"); - reg_range_check[element] += 1; + let mut index = 0; + for input in mem_ops.iter() { + let count = self.prove_mem_align_op(&input, &mut trace, index); + for i in 0..count { + for j in 0..CHUNK_NUM { + let element = trace[index + i].reg[j] + .as_canonical_biguint() + .to_usize() + .expect("Cannot convert to usize"); + reg_range_check[element] += 1; + } } + index += count; } - - // Pad the remaining rows with trivially satisfying rows + let padding_size = num_rows - index; let padding_row = MemAlignTraceRow:: { reset: F::from_bool(true), ..Default::default() }; - let padding_size = air_mem_align_rows - rows_len; // Store the padding rows - for i in rows_len..air_mem_align_rows { - trace_buffer[i] = padding_row; + for i in index..num_rows { + trace[i] = padding_row; } - // Store the value of all padding reg columns so that they can be range checked - for _ in 0..CHUNK_NUM { - reg_range_check[0] += padding_size as u64; - } + // Compute the program multiplicity + let mem_align_rom_sm = self.mem_align_rom_sm.clone(); + mem_align_rom_sm.update_padding_row(padding_size as u64); + + reg_range_check[0] += CHUNK_NUM as u64 * padding_size as u64; + self.update_std_range_check(®_range_check); + AirInstance::new_from_trace(FromTrace::new(&mut trace)) + } + + fn update_std_range_check(&self, reg_range_check: &[u64]) { // Perform the range checks - let std = self.std.clone(); + let std = self._std.clone(); let range_id = std.get_range(BigInt::from(0), BigInt::from(CHUNK_BITS_MASK), None); for (value, &multiplicity) in reg_range_check.iter().enumerate() { std.range_check( @@ -923,27 +848,5 @@ impl MemAlignSM { range_id, ); } - - // Compute the program multiplicity - let mem_align_rom_sm = self.mem_align_rom_sm.clone(); - mem_align_rom_sm.update_padding_row(padding_size as u64); - - info!( - "{}: ยทยทยท Creating Mem Align instance [{} / {} rows filled {:.2}%]", - Self::MY_NAME, - rows_len, - air_mem_align_rows, - rows_len as f64 / air_mem_align_rows as f64 * 100.0 - ); - - // Add a new Mem Align instance - // let air_instance = AirInstance::new( - // sctx, - // ZISK_AIRGROUP_ID, - // MEM_ALIGN_AIR_IDS[0], - // None, - // trace_buffer.buffer.unwrap(), - // ); - // pctx.air_instance_repo.add_air_instance(air_instance, None); } } diff --git a/state-machines/mem/src/mem_constants.rs b/state-machines/mem/src/mem_constants.rs index 91e73a34..7b9a5cf7 100644 --- a/state-machines/mem/src/mem_constants.rs +++ b/state-machines/mem/src/mem_constants.rs @@ -1,6 +1,7 @@ -pub const MEM_ADDR_MASK: u32 = 0xFFFF_FFF8; pub const MEM_BYTES_BITS: u32 = 3; pub const MEM_BYTES: u32 = 1 << MEM_BYTES_BITS; +pub const MEM_ADDR_ALIGN_MASK: u32 = MEM_BYTES - 1; +pub const MEM_ADDR_MASK: u32 = 0xFFFF_FFF8; pub const MEM_STEP_BASE: u64 = 1; pub const MAX_MEM_STEP_OFFSET: u64 = 2; @@ -16,9 +17,9 @@ pub const MEM_REGS_ADDR: u32 = 0xA000_0000; pub const MEM_BUS_ID: u16 = 1000; pub const MAX_MAIN_STEP: u64 = 0x1FFF_FFFF_FFFF_FFFF; -pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE - + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP - + MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; +pub const MAX_MEM_STEP: u64 = MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * MAX_MAIN_STEP + + MAX_MEM_OPS_BY_STEP_OFFSET * MAX_MEM_STEP_OFFSET; pub const MAX_MEM_ADDR: u64 = 0xFFFF_FFFF; diff --git a/state-machines/mem/src/mem_counters.rs b/state-machines/mem/src/mem_counters.rs index 828b8aee..8b4376d4 100644 --- a/state-machines/mem/src/mem_counters.rs +++ b/state-machines/mem/src/mem_counters.rs @@ -3,9 +3,7 @@ use std::collections::HashMap; use sm_common::Metrics; use zisk_common::{BusDevice, BusId}; -use crate::{ - MEMORY_MAX_DIFF, MEMORY_STORE_OP, MEM_BUS_ID, MEM_BYTES_BITS, MEM_REGS_ADDR, MEM_REGS_MASK, -}; +use crate::{MemHelpers, MEM_BUS_ID, MEM_BYTES_BITS, MEM_REGS_ADDR, MEM_REGS_MASK}; use log::info; @@ -37,14 +35,6 @@ impl MemCounters { mem_align_rows: 0, } } - pub fn count_extra_internal_reads(previous_step: u64, step: u64) -> u64 { - let diff = step - previous_step; - if diff > MEMORY_MAX_DIFF { - (diff - 1) / MEMORY_MAX_DIFF - } else { - 0 - } - } } impl Metrics for MemCounters { @@ -65,8 +55,11 @@ impl Metrics for MemCounters { last_value: data[4], }; } else { - self.registers[reg_index].count += - 1 + Self::count_extra_internal_reads(self.registers[reg_index].last_step, step); + // TODO: this only applies to non-imputable memories (mem) + self.registers[reg_index].count += 1 + MemHelpers::get_extra_internal_reads( + self.registers[reg_index].last_step, + step, + ); self.registers[reg_index].last_step = step; self.registers[reg_index].last_value = data[4]; } @@ -95,9 +88,8 @@ impl Metrics for MemCounters { // TODO: use mem_align helpers // TODO: last value must be calculated as last value operation let last_value = 0; - let addr_count = - if ((addr + bytes as u32) >> MEM_BYTES_BITS) != addr_w { 2 } else { 1 }; - let ops_by_addr = if op == MEMORY_STORE_OP { 2 } else { 1 }; + let addr_count = if MemHelpers::is_double(addr, bytes) { 2 } else { 1 }; + let ops_by_addr = if MemHelpers::is_write(op) { 2 } else { 1 }; let last_step = step + ops_by_addr - 1; for index in 0..addr_count { @@ -105,7 +97,11 @@ impl Metrics for MemCounters { .entry(addr_w + index) .and_modify(|value| { value.count += ops_by_addr + - Self::count_extra_internal_reads(value.last_step, step); + MemHelpers::get_extra_internal_reads_by_addr( + addr_w + index, + value.last_step, + step, + ); value.last_step = last_step; value.last_value = last_value }) diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index b6cb8eb6..148f3a0d 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -1,8 +1,9 @@ use crate::{ - MemAlignResponse, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_OPS_BY_STEP_OFFSET, MEM_STEP_BASE, + MemAlignResponse, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_OPS_BY_STEP_OFFSET, MEMORY_MAX_DIFF, + MEMORY_STORE_OP, MEM_ADDR_ALIGN_MASK, MEM_BYTES_BITS, MEM_STEP_BASE, }; use std::fmt; -use zisk_core::ZiskRequiredMemory; +use zisk_core::{ZiskRequiredMemory, RAM_ADDR}; #[allow(dead_code)] fn format_u64_hex(value: u64) -> String { @@ -27,11 +28,10 @@ pub struct MemAlignInput { #[derive(Debug, Clone)] pub struct MemInput { - pub addr: u32, // address in word native format means byte_address / MEM_BYTES - pub is_write: bool, // it's a write operation - pub is_internal: bool, // internal operation, don't send this operation to bus - pub step: u64, // mem_step = f(main_step, main_step_offset) - pub value: u64, // value to read or write + pub addr: u32, // address in word native format means byte_address / MEM_BYTES + pub is_write: bool, // it's a write operation + pub step: u64, // mem_step = f(main_step, main_step_offset) + pub value: u64, // value to read or write } impl MemAlignInput { @@ -68,12 +68,90 @@ pub struct MemHelpers {} impl MemHelpers { pub fn main_step_to_address_step(step: u64, step_offset: u8) -> u64 { - MEM_STEP_BASE - + MAX_MEM_OPS_BY_MAIN_STEP * step - + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 + MEM_STEP_BASE + + MAX_MEM_OPS_BY_MAIN_STEP * step + + MAX_MEM_OPS_BY_STEP_OFFSET * step_offset as u64 + } + pub fn is_aligned(addr: u32, width: u8) -> bool { + addr & MEM_ADDR_ALIGN_MASK == 0 && width == 8 + } + pub fn get_addr_w(addr: u32) -> u32 { + addr >> MEM_BYTES_BITS + } + #[inline(always)] + pub fn get_read_step(step: u64) -> u64 { + step + } + #[inline(always)] + pub fn get_write_step(step: u64) -> u64 { + step + 1 + } + #[inline(always)] + pub fn is_double(addr: u32, bytes: u8) -> bool { + addr & MEM_ADDR_ALIGN_MASK + bytes as u32 > 8 + } + #[inline(always)] + pub fn is_write(op: u8) -> bool { + op == MEMORY_STORE_OP + } + #[inline(always)] + pub fn get_byte_offset(addr: u32) -> u8 { + (addr & MEM_ADDR_ALIGN_MASK) as u8 + } + #[inline(always)] + pub fn step_extra_reads_enabled(addr_w: u32) -> bool { + addr_w as u64 >= RAM_ADDR + } + #[inline(always)] + pub fn get_extra_internal_reads(previous_step: u64, step: u64) -> u64 { + let diff = step - previous_step; + if diff > MEMORY_MAX_DIFF { + (diff - 1) / MEMORY_MAX_DIFF + } else { + 0 + } + } + #[inline(always)] + pub fn get_extra_internal_reads_by_addr(addr_w: u32, previous_step: u64, step: u64) -> u64 { + if Self::step_extra_reads_enabled(addr_w) { + Self::get_extra_internal_reads(previous_step, step) + } else { + 0 + } } -} + #[cfg(target_endian = "big")] + compile_error!("This code requires a little-endian machine."); + pub fn get_write_values(addr: u32, bytes: u8, value: u64, read_values: [u64; 2]) -> [u64; 2] { + let is_double = Self::is_double(addr, bytes); + let offset = Self::get_byte_offset(addr); + let value = match bytes { + 1 => value & 0xFF, + 2 => value & 0xFFFF, + 4 => value & 0xFFFF_FFFF, + 8 => value, + _ => panic!("Invalid bytes value"), + }; + let byte_mask = match bytes { + 1 => 0xFFu64, + 2 => 0xFFFFu64, + 4 => 0xFFFF_FFFFu64, + 8 => 0xFFFF_FFFF_FFFF_FFFFu64, + _ => panic!("Invalid bytes value"), + }; + + let lo_mask = !(byte_mask << offset); + let lo_write = (lo_mask & read_values[0]) | (value << offset); + if !is_double { + return [lo_write, read_values[1]] + } + + let hi_mask = !(byte_mask >> (8 - offset)); + let hi_write = (hi_mask & read_values[1]) | (value >> (8 - offset)); + + return [lo_write, hi_write] + } +} impl fmt::Debug for MemAlignResponse { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( diff --git a/state-machines/mem/src/mem_module.rs b/state-machines/mem/src/mem_module.rs index 59308fd3..cfe3a42b 100644 --- a/state-machines/mem/src/mem_module.rs +++ b/state-machines/mem/src/mem_module.rs @@ -1,9 +1,10 @@ -use crate::{MemHelpers, MemInput, MEM_BYTES}; +use crate::{MemHelpers, MemInput, MemPreviousSegment, MEM_BYTES}; +use proofman_common::AirInstance; use zisk_core::ZiskRequiredMemory; impl MemInput { - pub fn new(addr: u32, is_write: bool, step: u64, value: u64, is_internal: bool) -> Self { - MemInput { addr, is_write, step, value, is_internal } + pub fn new(addr: u32, is_write: bool, step: u64, value: u64) -> Self { + MemInput { addr, is_write, step, value } } pub fn from(mem_op: &ZiskRequiredMemory) -> Self { match mem_op { @@ -12,7 +13,6 @@ impl MemInput { MemInput { addr: address >> 3, is_write: *is_write, - is_internal: false, step: MemHelpers::main_step_to_address_step(*step, *step_offset), value: *value, } @@ -25,7 +25,12 @@ impl MemInput { } pub trait MemModule: Send + Sync { - fn send_inputs(&self, mem_op: &[MemInput]); + fn prove_instance( + &self, + mem_ops: &[MemInput], + segment_id: usize, + is_last_segment: bool, + previous_segment: &MemPreviousSegment, + ) -> AirInstance; fn get_addr_ranges(&self) -> Vec<(u32, u32)>; - fn get_flush_input_size(&self) -> u32; } diff --git a/state-machines/mem/src/mem_module_instance.rs b/state-machines/mem/src/mem_module_instance.rs new file mode 100644 index 00000000..885ba2f4 --- /dev/null +++ b/state-machines/mem/src/mem_module_instance.rs @@ -0,0 +1,174 @@ +use crate::{MemHelpers, MemInput, MemInstanceCheckPoint, MemModule, MemPreviousSegment}; +use p3_field::PrimeField; +use proofman_common::{AirInstance, ProofCtx}; +use proofman_util::{timer_start_debug, timer_stop_and_log_debug}; +use sm_common::{CheckPoint, Instance, InstanceCtx, InstanceType}; +use std::sync::Arc; +use zisk_common::{BusDevice, BusId, MemBusData}; + +pub struct MemModuleInstance { + /// Binary Basic state machine + mem_check_point: MemInstanceCheckPoint, + /// Instance context + ictx: InstanceCtx, + + /// Collected inputs + inputs: Vec, + module: Arc>, +} + +impl MemModuleInstance { + pub fn new(module: Arc>, ictx: InstanceCtx) -> Self { + let mem_check_point = ictx + .plan + .meta + .as_ref() + .unwrap() + .downcast_ref::() + .unwrap() + .clone(); + Self { ictx, inputs: Vec::new(), mem_check_point, module } + } + + fn process_unaligned_data(&mut self, data: &[u64]) { + let addr = MemBusData::get_addr(data); + let addr_w = MemHelpers::get_addr_w(addr); + let bytes = MemBusData::get_bytes(data); + let is_write = MemHelpers::is_write(MemBusData::get_op(data)); + if MemHelpers::is_double(addr, bytes) { + if is_write { + self.process_unaligned_double_write(addr_w, bytes, data); + } else { + self.process_unaligned_double_read(addr_w, data); + } + } else { + if is_write { + self.process_unaligned_single_write(addr_w, bytes, data); + } else { + self.process_unaligned_single_read(addr_w, data); + } + } + } + + fn process_unaligned_single_read(&mut self, addr_w: u32, data: &[u64]) { + let value = MemBusData::get_mem_values(data)[0]; + let step = MemBusData::get_step(data); + self.filtered_inputs_push(addr_w, step, false, value); + } + + fn process_unaligned_single_write(&mut self, addr_w: u32, bytes: u8, data: &[u64]) { + let read_values = MemBusData::get_mem_values(data); + let write_values = MemHelpers::get_write_values( + MemBusData::get_addr(data), + bytes, + MemBusData::get_value(data), + read_values, + ); + let step = MemBusData::get_step(data); + self.filtered_inputs_push(addr_w, MemHelpers::get_read_step(step), false, read_values[0]); + self.filtered_inputs_push(addr_w, MemHelpers::get_write_step(step), true, write_values[0]); + } + + fn process_unaligned_double_read(&mut self, addr_w: u32, data: &[u64]) { + let read_values = MemBusData::get_mem_values(data); + let step = MemBusData::get_step(data); + self.filtered_inputs_push(addr_w, step, false, read_values[0]); + self.filtered_inputs_push(addr_w + 1, step, true, read_values[1]); + } + + fn process_unaligned_double_write(&mut self, addr_w: u32, bytes: u8, data: &[u64]) { + let read_values = MemBusData::get_mem_values(data); + let write_values = MemHelpers::get_write_values( + MemBusData::get_addr(data), + bytes, + MemBusData::get_value(data), + read_values, + ); + let step = MemBusData::get_step(data); + let read_step = MemHelpers::get_read_step(step); + let write_step = MemHelpers::get_write_step(step); + + // IMPORTANT: inputs must be ordered by step + self.filtered_inputs_push(addr_w, read_step, false, read_values[0]); + self.filtered_inputs_push(addr_w + 1, read_step, false, read_values[1]); + + self.filtered_inputs_push(addr_w, write_step, true, write_values[0]); + self.filtered_inputs_push(addr_w + 1, write_step, true, write_values[1]); + } + + fn discart_addr_step(&self, addr: u32, step: u64) -> bool { + if addr < self.mem_check_point.prev_addr || addr > self.mem_check_point.last_addr { + return true; + } + + if addr == self.mem_check_point.prev_addr && step < self.mem_check_point.prev_step { + return true; + } + + if addr == self.mem_check_point.last_addr && step > self.mem_check_point.last_step { + return true; + } + + true + } + fn filtered_inputs_push(&mut self, addr_w: u32, step: u64, is_write: bool, value: u64) { + if !self.discart_addr_step(addr_w, step) { + self.inputs.push(MemInput::new(addr_w, is_write, step, value)); + } + } + fn prepare_inputs(&mut self) { + // sort all instance inputs + timer_start_debug!(MEM_SORT); + self.inputs.sort_by_key(|input| input.addr); + timer_stop_and_log_debug!(MEM_SORT); + } +} + +impl Instance for MemModuleInstance { + fn compute_witness(&mut self, _pctx: &ProofCtx) -> Option> { + let prev_segment = MemPreviousSegment { + addr: self.mem_check_point.prev_addr, + step: self.mem_check_point.prev_step, + value: self.mem_check_point.prev_value, + }; + + self.prepare_inputs(); + + Some(self.module.prove_instance( + &self.inputs, + 0, + self.mem_check_point.is_last_segment, + &prev_segment, + )) + } + + fn check_point(&self) -> CheckPoint { + self.ictx.plan.check_point.clone() + } + + fn instance_type(&self) -> InstanceType { + InstanceType::Instance + } +} + +impl BusDevice for MemModuleInstance { + fn process_data(&mut self, _bus_id: &BusId, data: &[u64]) -> (bool, Vec<(BusId, Vec)>) { + let addr = MemBusData::get_addr(data); + let bytes = MemBusData::get_bytes(data); + if !MemHelpers::is_aligned(addr, bytes) { + self.process_unaligned_data(data); + return (false, vec![]) + } + + let addr_w = MemHelpers::get_addr_w(addr); + let step = MemBusData::get_step(data); + let is_write = MemHelpers::is_write(MemBusData::get_op(data)); + if is_write { + self.filtered_inputs_push(addr_w, step, true, MemBusData::get_value(data)); + } else { + self.filtered_inputs_push(addr_w, step, false, MemBusData::get_mem_values(data)[0]); + } + + (false, vec![]) + } +} diff --git a/state-machines/mem/src/mem_module_planner.rs b/state-machines/mem/src/mem_module_planner.rs index 1a90f203..838e9b3d 100644 --- a/state-machines/mem/src/mem_module_planner.rs +++ b/state-machines/mem/src/mem_module_planner.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{MemCounters, MemPlanCalculator, UsesCounter, MEMORY_MAX_DIFF}; +use crate::{MemCounters, MemHelpers, MemPlanCalculator, UsesCounter, MEMORY_MAX_DIFF}; use sm_common::{CheckPoint, ChunkId, InstanceType, Plan}; pub struct MemModulePlanner<'a> { @@ -22,12 +22,15 @@ pub struct MemModulePlanner<'a> { counters: Arc>, } -#[derive(Debug, Default)] -struct MemInstanceCheckPoint { - prev_addr: u32, - skip_internal: u32, - prev_step: u64, - prev_value: u64, +#[derive(Debug, Default, Clone)] +pub struct MemInstanceCheckPoint { + pub prev_addr: u32, + pub last_addr: u32, + pub skip_internal: u32, + pub is_last_segment: bool, + pub prev_step: u64, + pub last_step: u64, + pub prev_value: u64, } impl<'a> MemModulePlanner<'a> { @@ -56,9 +59,12 @@ impl<'a> MemModulePlanner<'a> { current_checkpoint_chunks: Vec::new(), current_checkpoint: MemInstanceCheckPoint { prev_addr: from_addr, + last_addr: 0, + last_step: 0, skip_internal: 0, prev_step: 0, prev_value: 0, + is_last_segment: false, }, } } @@ -162,7 +168,10 @@ impl<'a> MemModulePlanner<'a> { // instance.add_chunk_id(chunk_id.clone()); // } - let checkpoint = std::mem::take(&mut self.current_checkpoint); + let mut checkpoint = std::mem::take(&mut self.current_checkpoint); + checkpoint.last_addr = self.last_addr; + checkpoint.last_step = self.last_step; + let chunks = std::mem::take(&mut self.current_checkpoint_chunks); let instance = Plan::new( self.airgroup_id, @@ -191,6 +200,11 @@ impl<'a> MemModulePlanner<'a> { return; } + // TODO: dynamic by addr mapping + if !MemHelpers::step_extra_reads_enabled(addr) { + return; + } + let step_diff = addr_uses.first_step - self.last_step; if step_diff <= MEMORY_MAX_DIFF { return; diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index c276148e..d9879a8d 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use crate::{ - InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemPlanner, MemProxyEngine, MemSM, + InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemModuleInstance, MemPlanner, MemSM, RomDataSM, }; use p3_field::PrimeField; use pil_std_lib::Std; use sm_common::{BusDeviceInstance, BusDeviceMetrics, ComponentBuilder, InstanceCtx, Planner}; -use zisk_core::ZiskRequiredMemory; +use zisk_pil::{InputDataTrace, MemTrace, RomDataTrace}; pub struct MemProxy { // Secondary State machines mem_sm: Arc>, - mem_align_sm: Arc>, + _mem_align_sm: Arc>, _mem_align_rom_sm: Arc, input_data_sm: Arc>, rom_data_sm: Arc>, @@ -27,24 +27,13 @@ impl MemProxy { let rom_data_sm = RomDataSM::new(std.clone()); Arc::new(Self { - mem_align_sm, + _mem_align_sm: mem_align_sm, _mem_align_rom_sm: mem_align_rom_sm, mem_sm, input_data_sm, rom_data_sm, }) } - - pub fn prove( - &self, - mem_operations: &mut Vec, - ) -> Result<(), Box> { - let mut engine = MemProxyEngine::::new(self.mem_align_sm.clone()); - engine.add_module("mem", self.mem_sm.clone()); - engine.add_module("input_data", self.input_data_sm.clone()); - engine.add_module("row_data", self.rom_data_sm.clone()); - engine.prove(mem_operations) - } } impl ComponentBuilder for MemProxy { @@ -57,10 +46,29 @@ impl ComponentBuilder for MemProxy { Box::new(MemPlanner::new()) } - fn build_inputs_collector(&self, _iectx: InstanceCtx) -> Box> { - unimplemented!("get_instance for MemProxy"); + fn build_inputs_collector(&self, ictx: InstanceCtx) -> Box> { + match ictx.plan.air_id { + id if id == MemTrace::::AIR_ID => { + Box::new(MemModuleInstance::new(self.mem_sm.clone(), ictx)) + } + id if id == RomDataTrace::::AIR_ID => { + Box::new(MemModuleInstance::new(self.rom_data_sm.clone(), ictx)) + } + id if id == InputDataTrace::::AIR_ID => { + Box::new(MemModuleInstance::new(self.input_data_sm.clone(), ictx)) + } + /* id if id == ArithTableTrace::::AIR_ID => { + table_instance!(ArithTableInstance, ArithTableSM, ArithTableTrace); + Box::new(ArithTableInstance::new(self.arith_table_sm.clone(), ictx)) + } + id if id == ArithRangeTableTrace::::AIR_ID => { + table_instance!(ArithRangeTableInstance, ArithRangeTableSM, ArithRangeTableTrace); + Box::new(ArithRangeTableInstance::new(self.arith_range_table_sm.clone(), ictx)) + }*/ + _ => panic!("Memory::get_instance() Unsupported air_id: {:?}", ictx.plan.air_id), + } } fn build_inputs_generator(&self) -> Option>> { - unimplemented!("get_instance for MemProxy"); + None } } diff --git a/state-machines/mem/src/mem_proxy_engine.rs b/state-machines/mem/src/mem_proxy_engine.rs index 3ca6bf5d..66b8c4ce 100644 --- a/state-machines/mem/src/mem_proxy_engine.rs +++ b/state-machines/mem/src/mem_proxy_engine.rs @@ -163,40 +163,6 @@ impl MemProxyEngine { } } - pub fn add_module(&mut self, name: &str, module: Arc>) { - if self.modules.is_empty() { - self.current_module = String::from(name); - } - let module_id = self.modules.len() as u8; - self.modules.push(module.clone()); - - let ranges = module.get_addr_ranges(); - let flush_input_size = module.get_flush_input_size(); - - for range in ranges.iter() { - debug_info!("adding range 0x{:X} 0x{:X} to {}", range.0, range.1, name); - self.insert_address_range(range.0, range.1, module_id); - } - self.modules_data.push(MemModuleData { - name: String::from(name), - inputs: Vec::new(), - flush_input_size: if flush_input_size == 0 { - 0xFFFF_FFFF_FFFF_FFFF - } else { - flush_input_size as usize - }, - }); - } - /* insert in sort way the address map and verify that */ - fn insert_address_range(&mut self, from_addr: u32, to_addr: u32, module_id: u8) { - let region = AddressRegion { from_addr, to_addr, module_id }; - if let Some(index) = self.addr_map.iter().position(|x| x.from_addr >= from_addr) { - self.addr_map.insert(index, region); - } else { - self.addr_map.push(region); - } - } - pub fn prove( &mut self, mem_operations: &mut Vec, diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index ee5f12b2..031ff3ca 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -7,7 +7,7 @@ use pil_std_lib::Std; use proofman_common::{AirInstance, FromTrace}; use zisk_core::{RAM_ADDR, RAM_SIZE}; -use zisk_pil::{MemAirValues, MemTrace, MEM_AIR_IDS, ZISK_AIRGROUP_ID}; +use zisk_pil::{MemAirValues, MemTrace}; pub const RAM_W_ADDR_INIT: u32 = RAM_ADDR as u32 >> MEM_BYTES_BITS; pub const RAM_W_ADDR_END: u32 = (RAM_ADDR + RAM_SIZE - 1) as u32 >> MEM_BYTES_BITS; @@ -37,7 +37,7 @@ pub struct MemoryAirValues { pub segment_last_step: u64, pub segment_last_value: [u32; 2], } -#[derive(Debug)] +#[derive(Debug, Default)] pub struct MemPreviousSegment { pub addr: u32, pub step: u64, @@ -49,60 +49,15 @@ impl MemSM { pub fn new(std: Arc>) -> Arc { Arc::new(Self { std: std.clone() }) } - - pub fn prove(&self, inputs: &[MemInput]) { - // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs - // memory only need to process these special inputs, but inputs no change. At end of - // inputs proxy add an extra internal input to jump to last address - - let air_rows = MemTrace::::NUM_ROWS; - - // at least one row to go - let count = inputs.len(); - let count_rem = count % air_rows; - let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; - - let mut global_idxs = vec![0; num_segments]; - - #[allow(clippy::needless_range_loop)] - for i in 0..num_segments { - // TODO: Review - if let (true, global_idx) = self.std.pctx.dctx.write().unwrap().add_instance( - ZISK_AIRGROUP_ID, - MEM_AIR_IDS[0], - 1, - ) { - global_idxs[i] = global_idx; - } - } - - #[allow(clippy::needless_range_loop)] - for segment_id in 0..num_segments { - let is_last_segment = segment_id == num_segments - 1; - let input_offset = segment_id * air_rows; - let previous_segment = if (segment_id == 0) { - MemPreviousSegment { addr: RAM_W_ADDR_INIT, step: 0, value: 0 } - } else { - MemPreviousSegment { - addr: inputs[input_offset - 1].addr, - step: inputs[input_offset - 1].step, - value: inputs[input_offset - 1].value, - } - }; - let input_end = - if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; - let mem_ops = &inputs[input_offset..input_end]; - - let air_instance = - self.prove_instance(mem_ops, segment_id, is_last_segment, &previous_segment); - - self.std - .pctx - .air_instance_repo - .add_air_instance(air_instance, Some(global_idxs[segment_id])); - } + pub fn get_from_addr() -> u32 { + RAM_ADDR as u32 } + pub fn get_to_addr() -> u32 { + (RAM_ADDR + RAM_SIZE - 1) as u32 + } +} +impl MemModule for MemSM { /// Finalizes the witness accumulation process and triggers the proof generation. /// /// This method is invoked by the executor when no further witness data remains to be added. @@ -110,7 +65,7 @@ impl MemSM { /// # Parameters /// /// - `mem_inputs`: A slice of all `MemoryInput` inputs - pub fn prove_instance( + fn prove_instance( &self, mem_ops: &[MemInput], segment_id: usize, @@ -157,30 +112,91 @@ impl MemSM { // index it's value - 1, for this reason no add +1 range_check_data[(previous_segment.addr - RAM_W_ADDR_INIT) as usize] += 1; // TODO - // Fill the remaining rows let mut last_addr: u32 = previous_segment.addr; let mut last_step: u64 = previous_segment.step; let mut last_value: u64 = previous_segment.value; - for (i, mem_op) in mem_ops.iter().enumerate() { - trace[i].addr = F::from_canonical_u32(mem_op.addr); - trace[i].step = F::from_canonical_u64(mem_op.step); - trace[i].sel = F::from_bool(!mem_op.is_internal); - trace[i].wr = F::from_bool(mem_op.is_write); - - let (low_val, high_val) = (mem_op.value as u32, (mem_op.value >> 32) as u32); - trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + let mut i = 0; + let mut increment; + let f_max_increment = F::from_canonical_u64(MEMORY_MAX_DIFF); + for mem_op in mem_ops.iter() { + let mut step = mem_op.step; + // set the common values of trace between internal reads and regular memory operation + trace[i].addr = F::from_canonical_u32(mem_op.addr); let addr_changes = last_addr != mem_op.addr; trace[i].addr_changes = if addr_changes { F::one() } else { F::zero() }; - let increment = if addr_changes { - // (mem_op.addr - last_addr + if i == 0 && segment_id == 0 { 1 } else { 0 }) as u64 - (mem_op.addr - last_addr) as u64 + if addr_changes { + increment = (mem_op.addr - last_addr) as u64; } else { - mem_op.step - last_step - }; + increment = step - last_step; + if increment > MEMORY_MAX_DIFF { + // calculate the number of internal reads + let mut internal_reads = (increment - 1) / MEMORY_MAX_DIFF; + + // check if has enough rows to complete the internal reads + regular memory + let incomplete = (i + internal_reads as usize) >= trace.num_rows; + if incomplete { + internal_reads = (trace.num_rows - i) as u64; + } + + // without address changes, the internal reads before write must use the last + // value, in the case of reads value and the last value are the same + let (low_val, high_val) = (last_value as u32, (last_value >> 32) as u32); + trace[i].value = + [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + // it's intenal + trace[i].sel = F::zero(); + + // in internal reads the increment is always the max increment + trace[i].increment = f_max_increment; + + // internal reads always must be read + trace[i].wr = F::zero(); + + // setting step + trace[i].step = F::from_canonical_u64(step); + last_step = step; + step += MEMORY_MAX_DIFF; + + i += 1; + + // the trace values of the rest of internal reads are equal to previous, only + // change the value of step + for _j in 1..internal_reads { + trace[i] = trace[i - 1]; + trace[i].step = F::from_canonical_u64(step); + last_step = step; + step += MEMORY_MAX_DIFF; + i += 1; + } + + // increase the multiplicity of internal reads + range_check_data[(MEMORY_MAX_DIFF - 1) as usize] += internal_reads as u64; + + // control the edge case when there aren't enough rows to complete the internal + // reads or regular memory operation + if incomplete { + last_addr = mem_op.addr; + break; + } + // copy last trace for the regular memory operation (addr, addr_changes) + trace[i] = trace[i - 1]; + increment -= internal_reads * MEMORY_MAX_DIFF; + } + } + + // set specific values of trace for regular memory operation + let (low_val, high_val) = (mem_op.value as u32, (mem_op.value >> 32) as u32); + trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; + + trace[i].step = F::from_canonical_u64(step); + trace[i].sel = F::one(); trace[i].increment = F::from_canonical_u64(increment); + trace[i].wr = F::from_bool(mem_op.is_write); + i += 1; // Store the value of incremenet so it can be range checked if increment <= MEMORY_MAX_DIFF || increment == 0 { @@ -191,19 +207,21 @@ impl MemSM { } last_addr = mem_op.addr; - last_step = mem_op.step; + last_step = step; last_value = mem_op.value; + i += 1; } + let count = i; // STEP3. Add dummy rows to the output vector to fill the remaining rows // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd // = 1, wr = 0 - let last_row_idx = mem_ops.len() - 1; + let last_row_idx = count - 1; let addr = trace[last_row_idx].addr; let value = trace[last_row_idx].value; - let padding_size = trace.num_rows - mem_ops.len(); - for i in mem_ops.len()..trace.num_rows { + let padding_size = trace.num_rows - count; + for i in count..trace.num_rows { last_step += 1; trace[i].addr = addr; trace[i].step = F::from_canonical_u64(last_step); @@ -232,7 +250,7 @@ impl MemSM { // TODO: Perform the range checks let range_id = self.std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); for (value, &multiplicity) in range_check_data.iter().enumerate() { - if (multiplicity == 0) { + if multiplicity == 0 { continue; } self.std.range_check( @@ -262,22 +280,7 @@ impl MemSM { AirInstance::new_from_trace(FromTrace::new(&mut trace).with_air_values(&mut air_values)) } - pub fn get_from_addr() -> u32 { - RAM_ADDR as u32 - } - pub fn get_to_addr() -> u32 { - (RAM_ADDR + RAM_SIZE - 1) as u32 - } -} - -impl MemModule for MemSM { - fn send_inputs(&self, mem_op: &[MemInput]) { - self.prove(mem_op); - } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { vec![(RAM_ADDR as u32, (RAM_ADDR + RAM_SIZE - 1) as u32)] } - fn get_flush_input_size(&self) -> u32 { - 0 - } } diff --git a/state-machines/mem/src/rom_data.rs b/state-machines/mem/src/rom_data.rs index b64d0c9e..5d2f6ec7 100644 --- a/state-machines/mem/src/rom_data.rs +++ b/state-machines/mem/src/rom_data.rs @@ -32,59 +32,18 @@ impl RomDataSM { Arc::new(Self { std: std.clone() }) } - pub fn prove(&self, inputs: &[MemInput]) { - // PRE: proxy calculate if exists jmp on step out-of-range, adding internal inputs - // memory only need to process these special inputs, but inputs no change. At end of - // inputs proxy add an extra internal input to jump to last address - - let airgroup_id = RomDataTrace::::AIRGROUP_ID; - let air_id = RomDataTrace::::AIR_ID; - let air_rows = RomDataTrace::::NUM_ROWS; - - // at least one row to go - let count = inputs.len(); - let count_rem = count % air_rows; - let num_segments = (count / air_rows) + if count_rem > 0 { 1 } else { 0 }; - - let mut global_idxs = vec![0; num_segments]; - - #[allow(clippy::needless_range_loop)] - for i in 0..num_segments { - // TODO: Review - if let (true, global_idx) = - self.std.pctx.dctx.write().unwrap().add_instance(airgroup_id, air_id, 1) - { - global_idxs[i] = global_idx; - } - } - - #[allow(clippy::needless_range_loop)] - for segment_id in 0..num_segments { - let is_last_segment = segment_id == num_segments - 1; - let input_offset = segment_id * air_rows; - let previous_segment = if (segment_id == 0) { - MemPreviousSegment { addr: ROM_DATA_W_ADDR_INIT, step: 0, value: 0 } - } else { - MemPreviousSegment { - addr: inputs[input_offset - 1].addr, - step: inputs[input_offset - 1].step, - value: inputs[input_offset - 1].value, - } - }; - let input_end = - if (input_offset + air_rows) > count { count } else { input_offset + air_rows }; - let mem_ops = &inputs[input_offset..input_end]; - - let air_instance = - self.prove_instance(mem_ops, segment_id, is_last_segment, &previous_segment); - - self.std - .pctx - .air_instance_repo - .add_air_instance(air_instance, Some(global_idxs[segment_id])); - } + fn get_u32_values(&self, value: u64) -> (u32, u32) { + (value as u32, (value >> 32) as u32) + } + pub fn get_from_addr() -> u32 { + ROM_DATA_W_ADDR_INIT } + pub fn get_to_addr() -> u32 { + ROM_DATA_W_ADDR_END + } +} +impl MemModule for RomDataSM { /// Finalizes the witness accumulation process and triggers the proof generation. /// /// This method is invoked by the executor when no further witness data remains to be added. @@ -92,7 +51,7 @@ impl RomDataSM { /// # Parameters /// /// - `mem_inputs`: A slice of all `MemoryInput` inputs - pub fn prove_instance( + fn prove_instance( &self, mem_ops: &[MemInput], segment_id: usize, @@ -108,19 +67,6 @@ impl RomDataSM { trace.num_rows() ); - // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR - // segments In a Memory AIR instance, the first row is reserved as a dummy row. - // This dummy row is used to facilitate the continuation state between different AIR - // segments. It ensures seamless transitions when multiple AIR segments are - // processed consecutively. This design avoids discontinuities in memory access - // patterns and ensures that the memory trace is continuous, For this reason we use - // AIR num_rows - 1 as the number of rows in each memory AIR instance - - // Create a vector of Mem0Row instances, one for each memory operation - // Recall that first row is a dummy row used for the continuations between AIR segments - // The length of the vector is the number of input memory operations plus one because - // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows - let mut air_values_mem = MemoryAirValues { segment_id: segment_id as u32, is_first_segment: segment_id == 0, @@ -147,10 +93,38 @@ impl RomDataSM { let mut last_step: u64 = previous_segment.step; let mut last_value: u64 = previous_segment.value; - for (i, mem_op) in mem_ops.iter().enumerate() { + let mut i = 0; + for mem_op in mem_ops.iter() { + let mut internal_reads = (mem_op.addr - last_addr) - 1; + if internal_reads > 1 { + // check if has enough rows to complete the internal reads + regular memory + let incomplete = (i + internal_reads as usize) >= trace.num_rows; + if incomplete { + internal_reads = (trace.num_rows - i) as u32; + } + + trace[i].addr_changes = F::one(); + last_addr += 1; + trace[i].addr = F::from_canonical_u32(last_addr); + trace[i].value = [F::zero(), F::zero()]; + trace[i].sel = F::zero(); + i += 1; + + for _j in 1..internal_reads { + trace[i] = trace[i - 1]; + last_addr += 1; + trace[i].addr = F::from_canonical_u32(last_addr); + i += 1; + } + // the step, value of internal reads isn't relevant + trace[i].sel = F::zero(); + if incomplete { + break; + } + } trace[i].addr = F::from_canonical_u32(mem_op.addr); trace[i].step = F::from_canonical_u64(mem_op.step); - trace[i].sel = F::from_bool(!mem_op.is_internal); + trace[i].sel = F::one(); let (low_val, high_val) = self.get_u32_values(mem_op.value); trace[i].value = [F::from_canonical_u32(low_val), F::from_canonical_u32(high_val)]; @@ -162,17 +136,17 @@ impl RomDataSM { last_addr = mem_op.addr; last_step = mem_op.step; last_value = mem_op.value; + i += 1; } - + let count = i; // STEP3. Add dummy rows to the output vector to fill the remaining rows // PADDING: At end of memory fill with same addr, incrementing step, same value, sel = 0, rd // = 1, wr = 0 - let last_row_idx = mem_ops.len() - 1; + let last_row_idx = count - 1; let addr = trace[last_row_idx].addr; let value = trace[last_row_idx].value; - let padding_size = trace.num_rows() - mem_ops.len(); - for i in mem_ops.len()..trace.num_rows() { + for i in count..trace.num_rows() { last_step += 1; trace[i].addr = addr; trace[i].step = F::from_canonical_u64(last_step); @@ -215,26 +189,7 @@ impl RomDataSM { AirInstance::new_from_trace(FromTrace::new(&mut trace).with_air_values(&mut air_values)) } - fn get_u32_values(&self, value: u64) -> (u32, u32) { - (value as u32, (value >> 32) as u32) - } - pub fn get_from_addr() -> u32 { - ROM_DATA_W_ADDR_INIT - } - pub fn get_to_addr() -> u32 { - ROM_DATA_W_ADDR_END - } -} - -impl MemModule for RomDataSM { - fn send_inputs(&self, mem_op: &[MemInput]) { - self.prove(mem_op); - } fn get_addr_ranges(&self) -> Vec<(u32, u32)> { vec![(ROM_ADDR as u32, ROM_ADDR_MAX as u32)] } - fn get_flush_input_size(&self) -> u32 { - // self.num_rows as u32 - 0 - } } From 7b76461d9b2c68381f048721fdb6a9f92e8a50cf Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 9 Jan 2025 21:43:08 +0000 Subject: [PATCH 08/10] cargo fmt --- cli/src/commands/install_toolchain.rs | 8 +- core/src/elf2rom.rs | 6 +- core/src/zisk_inst.rs | 26 ++--- emulator/src/emu_options.rs | 12 +- emulator/src/stats.rs | 28 ++--- state-machines/arith/src/arith_full.rs | 16 +-- state-machines/arith/src/arith_operation.rs | 110 +++++++++--------- .../arith/src/arith_operation_test.rs | 42 +++---- .../arith/src/arith_range_table_helpers.rs | 12 +- .../binary/src/binary_basic_table.rs | 110 +++++++++--------- state-machines/binary/src/binary_extension.rs | 24 ++-- state-machines/mem/src/mem_align_rom_sm.rs | 6 +- 12 files changed, 200 insertions(+), 200 deletions(-) diff --git a/cli/src/commands/install_toolchain.rs b/cli/src/commands/install_toolchain.rs index 9ad8799a..2b8d6e7e 100644 --- a/cli/src/commands/install_toolchain.rs +++ b/cli/src/commands/install_toolchain.rs @@ -36,10 +36,10 @@ impl InstallToolchainCmd { if let Ok(entry) = entry { let entry_path = entry.path(); let entry_name = entry_path.file_name().unwrap(); - if entry_path.is_dir() - && entry_name != "bin" - && entry_name != "circuits" - && entry_name != "toolchains" + if entry_path.is_dir() && + entry_name != "bin" && + entry_name != "circuits" && + entry_name != "toolchains" { if let Err(err) = fs::remove_dir_all(&entry_path) { println!("Failed to remove directory {:?}: {}", entry_path, err); diff --git a/core/src/elf2rom.rs b/core/src/elf2rom.rs index 217b7753..0209aa27 100644 --- a/core/src/elf2rom.rs +++ b/core/src/elf2rom.rs @@ -53,9 +53,9 @@ pub fn elf2rom(elf_file: String) -> Result> { // Add init data as a read/write memory section, initialized by code // If the data is a writable memory section, add it to the ROM memory using Zisk // copy instructions - if (section_header.sh_flags & SHF_WRITE as u64) != 0 - && addr >= RAM_ADDR - && addr + data.len() as u64 <= RAM_ADDR + RAM_SIZE + if (section_header.sh_flags & SHF_WRITE as u64) != 0 && + addr >= RAM_ADDR && + addr + data.len() as u64 <= RAM_ADDR + RAM_SIZE { //println! {"elf2rom() new RW from={:x} length={:x}={}", addr, data.len(), //data.len()}; diff --git a/core/src/zisk_inst.rs b/core/src/zisk_inst.rs index 989d6d50..be6e575c 100644 --- a/core/src/zisk_inst.rs +++ b/core/src/zisk_inst.rs @@ -228,19 +228,19 @@ impl ZiskInst { /// Constructs a `flags`` bitmap made of combinations of fields of the Zisk instruction. This /// field is used by the PIL to proof some of the operations. pub fn get_flags(&self) -> u64 { - let flags: u64 = 1 - | (((self.a_src == SRC_IMM) as u64) << 1) - | (((self.a_src == SRC_MEM) as u64) << 2) - | (((self.a_src == SRC_STEP) as u64) << 3) - | (((self.b_src == SRC_IMM) as u64) << 4) - | (((self.b_src == SRC_MEM) as u64) << 5) - | ((self.is_external_op as u64) << 6) - | ((self.store_ra as u64) << 7) - | (((self.store == STORE_MEM) as u64) << 8) - | (((self.store == STORE_IND) as u64) << 9) - | ((self.set_pc as u64) << 10) - | ((self.m32 as u64) << 11) - | (((self.b_src == SRC_IND) as u64) << 12); + let flags: u64 = 1 | + (((self.a_src == SRC_IMM) as u64) << 1) | + (((self.a_src == SRC_MEM) as u64) << 2) | + (((self.a_src == SRC_STEP) as u64) << 3) | + (((self.b_src == SRC_IMM) as u64) << 4) | + (((self.b_src == SRC_MEM) as u64) << 5) | + ((self.is_external_op as u64) << 6) | + ((self.store_ra as u64) << 7) | + (((self.store == STORE_MEM) as u64) << 8) | + (((self.store == STORE_IND) as u64) << 9) | + ((self.set_pc as u64) << 10) | + ((self.m32 as u64) << 11) | + (((self.b_src == SRC_IND) as u64) << 12); flags } diff --git a/emulator/src/emu_options.rs b/emulator/src/emu_options.rs index 0703ab80..acd7bae1 100644 --- a/emulator/src/emu_options.rs +++ b/emulator/src/emu_options.rs @@ -98,11 +98,11 @@ impl fmt::Display for EmuOptions { impl EmuOptions { /// Returns true if the configuration allows to emulate in fast mode, maximizing the performance pub fn is_fast(&self) -> bool { - self.trace_steps.is_none() - && (self.print_step.is_none() || (self.print_step.unwrap() == 0)) - && self.trace.is_none() - && !self.log_step - && !self.verbose - && !self.tracerv + self.trace_steps.is_none() && + (self.print_step.is_none() || (self.print_step.unwrap() == 0)) && + self.trace.is_none() && + !self.log_step && + !self.verbose && + !self.tracerv } } diff --git a/emulator/src/stats.rs b/emulator/src/stats.rs index 80b0ffec..c6a16fee 100644 --- a/emulator/src/stats.rs +++ b/emulator/src/stats.rs @@ -185,22 +185,22 @@ impl Stats { output += &format!(" COST_STEP: {:02} sec\n", COST_STEP); // Calculate some aggregated counters to be used in the logs - let total_mem_ops = self.mops.mread_na1 - + self.mops.mread_na2 - + self.mops.mread_a - + self.mops.mwrite_na1 - + self.mops.mwrite_na2 - + self.mops.mwrite_a; - let total_mem_align_steps = self.mops.mread_na1 - + self.mops.mread_na2 * 2 - + self.mops.mwrite_na1 * 2 - + self.mops.mwrite_na2 * 4; + let total_mem_ops = self.mops.mread_na1 + + self.mops.mread_na2 + + self.mops.mread_a + + self.mops.mwrite_na1 + + self.mops.mwrite_na2 + + self.mops.mwrite_a; + let total_mem_align_steps = self.mops.mread_na1 + + self.mops.mread_na2 * 2 + + self.mops.mwrite_na1 * 2 + + self.mops.mwrite_na2 * 4; let cost_mem = total_mem_ops as f64 * COST_MEM; - let cost_mem_align = self.mops.mread_na1 as f64 * COST_MEMA_R1 - + self.mops.mread_na2 as f64 * COST_MEMA_R2 - + self.mops.mwrite_na1 as f64 * COST_MEMA_W1 - + self.mops.mwrite_na2 as f64 * COST_MEMA_W2; + let cost_mem_align = self.mops.mread_na1 as f64 * COST_MEMA_R1 + + self.mops.mread_na2 as f64 * COST_MEMA_R2 + + self.mops.mwrite_na1 as f64 * COST_MEMA_W1 + + self.mops.mwrite_na2 as f64 * COST_MEMA_W2; // Declare some total counters for the opcodes let mut total_opcodes: u64 = 0; diff --git a/state-machines/arith/src/arith_full.rs b/state-machines/arith/src/arith_full.rs index be7b71a2..8e3a2b72 100644 --- a/state-machines/arith/src/arith_full.rs +++ b/state-machines/arith/src/arith_full.rs @@ -235,14 +235,14 @@ impl ArithFullSM { step, opcode, ZiskOperationType::Binary as u64, - aop.d[0] - + CHUNK_SIZE * aop.d[1] - + CHUNK_SIZE.pow(2) * (aop.d[2] + extension.0) - + CHUNK_SIZE.pow(3) * aop.d[3], - aop.b[0] - + CHUNK_SIZE * aop.b[1] - + CHUNK_SIZE.pow(2) * (aop.b[2] + extension.1) - + CHUNK_SIZE.pow(3) * aop.b[3], + aop.d[0] + + CHUNK_SIZE * aop.d[1] + + CHUNK_SIZE.pow(2) * (aop.d[2] + extension.0) + + CHUNK_SIZE.pow(3) * aop.d[3], + aop.b[0] + + CHUNK_SIZE * aop.b[1] + + CHUNK_SIZE.pow(2) * (aop.b[2] + extension.1) + + CHUNK_SIZE.pow(3) * aop.b[3], ) .to_vec()] } else { diff --git a/state-machines/arith/src/arith_operation.rs b/state-machines/arith/src/arith_operation.rs index 8ce0274c..c6b8b48a 100644 --- a/state-machines/arith/src/arith_operation.rs +++ b/state-machines/arith/src/arith_operation.rs @@ -133,22 +133,22 @@ impl ArithOperation { self.op = op; self.input_a = input_a; self.input_b = input_b; - self.div_by_zero = input_b == 0 - && (op == ZiskOp::Div.code() - || op == ZiskOp::Rem.code() - || op == ZiskOp::DivW.code() - || op == ZiskOp::RemW.code() - || op == ZiskOp::Divu.code() - || op == ZiskOp::Remu.code() - || op == ZiskOp::DivuW.code() - || op == ZiskOp::RemuW.code()); - - self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) - && input_a == 0x8000_0000_0000_0000 - && input_b == 0xFFFF_FFFF_FFFF_FFFF) - || ((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) - && input_a == 0x8000_0000 - && input_b == 0xFFFF_FFFF); + self.div_by_zero = input_b == 0 && + (op == ZiskOp::Div.code() || + op == ZiskOp::Rem.code() || + op == ZiskOp::DivW.code() || + op == ZiskOp::RemW.code() || + op == ZiskOp::Divu.code() || + op == ZiskOp::Remu.code() || + op == ZiskOp::DivuW.code() || + op == ZiskOp::RemuW.code()); + + self.div_overflow = ((op == ZiskOp::Div.code() || op == ZiskOp::Rem.code()) && + input_a == 0x8000_0000_0000_0000 && + input_b == 0xFFFF_FFFF_FFFF_FFFF) || + ((op == ZiskOp::DivW.code() || op == ZiskOp::RemW.code()) && + input_a == 0x8000_0000 && + input_b == 0xFFFF_FFFF); let [a, b, c, d] = Self::calculate_abcd_from_ab(op, input_a, input_b); self.a = Self::u64_to_chunks(a); @@ -578,15 +578,15 @@ impl ArithOperation { assert!(range_c1 == 0 || range_c3 == 0, "range_c1:{} range_c3:{}", range_c1, range_c3); assert!(range_d1 == 0 || range_d3 == 0, "range_d1:{} range_d3:{}", range_d1, range_d3); - self.range_ab = (range_a3 + range_a1) * 3 - + range_b3 - + range_b1 - + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; + self.range_ab = (range_a3 + range_a1) * 3 + + range_b3 + + range_b1 + + if (range_a1 + range_b1) > 0 { 8 } else { 0 }; - self.range_cd = (range_c3 + range_c1) * 3 - + range_d3 - + range_d1 - + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; + self.range_cd = (range_c3 + range_c1) * 3 + + range_d3 + + range_d1 + + if (range_c1 + range_d1) > 0 { 8 } else { 0 }; } pub fn calculate_chunks(&self) -> [i64; 8] { @@ -614,40 +614,40 @@ impl ArithOperation { let nb_fa = nb * (1 - 2 * na); chunks[0] = fab * a[0] * b[0] // chunk0 - - c[0] - + 2 * np * c[0] - + div * d[0] - - 2 * nr * d[0]; + - c[0] + + 2 * np * c[0] + + div * d[0] - + 2 * nr * d[0]; chunks[1] = fab * a[1] * b[0] // chunk1 - + fab * a[0] * b[1] - - c[1] - + 2 * np * c[1] - + div * d[1] - - 2 * nr * d[1]; + + fab * a[0] * b[1] - + c[1] + + 2 * np * c[1] + + div * d[1] - + 2 * nr * d[1]; chunks[2] = fab * a[2] * b[0] // chunk2 + fab * a[1] * b[1] + fab * a[0] * b[2] + a[0] * nb_fa * m32 - + b[0] * na_fb * m32 - - c[2] - + 2 * np * c[2] - + div * d[2] - - 2 * nr * d[2] - - np * div * m32 - + nr * m32; // div == 0 ==> nr = 0 + + b[0] * na_fb * m32 - + c[2] + + 2 * np * c[2] + + div * d[2] - + 2 * nr * d[2] - + np * div * m32 + + nr * m32; // div == 0 ==> nr = 0 chunks[3] = fab * a[3] * b[0] // chunk3 + fab * a[2] * b[1] + fab * a[1] * b[2] + fab * a[0] * b[3] + a[1] * nb_fa * m32 - + b[1] * na_fb * m32 - - c[3] - + 2 * np * c[3] - + div * d[3] - - 2 * nr * d[3]; + + b[1] * na_fb * m32 - + c[3] + + 2 * np * c[3] + + div * d[3] - + 2 * nr * d[3]; chunks[4] = fab * a[3] * b[1] // chunk4 + fab * a[2] * b[2] @@ -671,23 +671,23 @@ impl ArithOperation { chunks[5] = fab * a[3] * b[2] // chunk5 + fab * a[2] * b[3] + a[1] * nb_fa * (1 - m32) - + b[1] * na_fb * (1 - m32) - - d[1] * (1 - div) - + d[1] * 2 * np * (1 - div); + + b[1] * na_fb * (1 - m32) - + d[1] * (1 - div) + + d[1] * 2 * np * (1 - div); chunks[6] = fab * a[3] * b[3] // chunk6 + a[2] * nb_fa * (1 - m32) - + b[2] * na_fb * (1 - m32) - - d[2] * (1 - div) - + d[2] * 2 * np * (1 - div); + + b[2] * na_fb * (1 - m32) - + d[2] * (1 - div) + + d[2] * 2 * np * (1 - div); // 0x4000_0000_0000_0000__8000_0000_0000_0000 chunks[7] = 0x10000 * na * nb * (1 - m32) // chunk7 + a[3] * nb_fa * (1 - m32) - + b[3] * na_fb * (1 - m32) - - 0x10000 * np * (1 - div) * (1 - m32) - - d[3] * (1 - div) - + d[3] * 2 * np * (1 - div); + + b[3] * na_fb * (1 - m32) - + 0x10000 * np * (1 - div) * (1 - m32) - + d[3] * (1 - div) + + d[3] * 2 * np * (1 - div); chunks } diff --git a/state-machines/arith/src/arith_operation_test.rs b/state-machines/arith/src/arith_operation_test.rs index 67d37120..37e0a4c7 100644 --- a/state-machines/arith/src/arith_operation_test.rs +++ b/state-machines/arith/src/arith_operation_test.rs @@ -101,15 +101,15 @@ impl ArithOperationTest { fn is_m32_op(op: u8) -> bool { let zisk_op = ZiskOp::try_from_code(op).unwrap(); match zisk_op { - ZiskOp::Mul - | ZiskOp::Mulh - | ZiskOp::Mulsuh - | ZiskOp::Mulu - | ZiskOp::Muluh - | ZiskOp::Divu - | ZiskOp::Remu - | ZiskOp::Div - | ZiskOp::Rem => false, + ZiskOp::Mul | + ZiskOp::Mulh | + ZiskOp::Mulsuh | + ZiskOp::Mulu | + ZiskOp::Muluh | + ZiskOp::Divu | + ZiskOp::Remu | + ZiskOp::Div | + ZiskOp::Rem => false, ZiskOp::MulW | ZiskOp::DivuW | ZiskOp::RemuW | ZiskOp::DivW | ZiskOp::RemW => true, _ => panic!("ArithOperationTest::is_m32_op() Invalid opcode={}", op), } @@ -162,26 +162,26 @@ impl ArithOperationTest { println!("{:#?}", aop); const CHUNK_SIZE: u64 = 0x10000; - let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) - + (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); - let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) - + (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + let bus_a_low: u64 = aop.div as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_a_high: u64 = aop.div as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + (1 - aop.div as u64) * (aop.a[2] + aop.a[3] * CHUNK_SIZE); let bus_b_low: u64 = aop.b[0] + CHUNK_SIZE * aop.b[1]; let bus_b_high: u64 = aop.b[2] + CHUNK_SIZE * aop.b[3]; let secondary_res: u64 = if aop.main_mul || aop.main_div { 0 } else { 1 }; - let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) - + aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) - + aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); + let bus_res_low = secondary_res * (aop.d[0] + aop.d[1] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[0] + aop.c[1] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[0] + aop.a[1] * CHUNK_SIZE); - let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) - + aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) - + aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); + let bus_res_high_64 = secondary_res * (aop.d[2] + aop.d[3] * CHUNK_SIZE) + + aop.main_mul as u64 * (aop.c[2] + aop.c[3] * CHUNK_SIZE) + + aop.main_div as u64 * (aop.a[2] + aop.a[3] * CHUNK_SIZE); - let bus_res_high = if aop.sext && !aop.div_overflow { 0xFFFF_FFFF } else { 0 } - + (1 - aop.m32 as u64) * bus_res_high_64; + let bus_res_high = if aop.sext && !aop.div_overflow { 0xFFFF_FFFF } else { 0 } + + (1 - aop.m32 as u64) * bus_res_high_64; let expected_a_low = a & 0xFFFF_FFFF; let expected_a_high = (a >> 32) & 0xFFFF_FFFF; diff --git a/state-machines/arith/src/arith_range_table_helpers.rs b/state-machines/arith/src/arith_range_table_helpers.rs index 0685b09e..6fdd3cab 100644 --- a/state-machines/arith/src/arith_range_table_helpers.rs +++ b/state-machines/arith/src/arith_range_table_helpers.rs @@ -45,16 +45,16 @@ impl ArithRangeTableHelpers { assert!(range_index < 43); assert!(value >= if range_type == NEG { 0x8000 } else { 0 }); assert!( - value - <= match range_type { + value <= + match range_type { FULL => 0xFFFF, POS => 0x7FFF, NEG => 0xFFFF, _ => panic!("Invalid range type"), } ); - OFFSETS[range_index as usize] * 0x8000 - + if range_type == NEG { value - 0x8000 } else { value } as usize + OFFSETS[range_index as usize] * 0x8000 + + if range_type == NEG { value - 0x8000 } else { value } as usize } pub fn get_row_carry_range_check(value: i64) -> usize { assert!(value >= -0xEFFFF); @@ -158,8 +158,8 @@ impl Iterator for ArithRangeTableInputsIterator<'_> { fn next(&mut self) -> Option { if !self.iter_hash { - while self.iter_row < ROWS as u32 - && self.inputs.multiplicity[self.iter_row as usize] == 0 + while self.iter_row < ROWS as u32 && + self.inputs.multiplicity[self.iter_row as usize] == 0 { self.iter_row += 1; } diff --git a/state-machines/binary/src/binary_basic_table.rs b/state-machines/binary/src/binary_basic_table.rs index 89eea8d0..e5047c5d 100644 --- a/state-machines/binary/src/binary_basic_table.rs +++ b/state-machines/binary/src/binary_basic_table.rs @@ -100,72 +100,72 @@ impl BinaryBasicTableSM { fn opcode_has_last(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu - | BinaryBasicTableOp::Min - | BinaryBasicTableOp::Maxu - | BinaryBasicTableOp::Max - | BinaryBasicTableOp::LtAbsNP - | BinaryBasicTableOp::LtAbsPN - | BinaryBasicTableOp::Ltu - | BinaryBasicTableOp::Lt - | BinaryBasicTableOp::Gt - | BinaryBasicTableOp::Eq - | BinaryBasicTableOp::Add - | BinaryBasicTableOp::Sub - | BinaryBasicTableOp::Leu - | BinaryBasicTableOp::Le - | BinaryBasicTableOp::And - | BinaryBasicTableOp::Or - | BinaryBasicTableOp::Xor => true, + BinaryBasicTableOp::Minu | + BinaryBasicTableOp::Min | + BinaryBasicTableOp::Maxu | + BinaryBasicTableOp::Max | + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub | + BinaryBasicTableOp::Leu | + BinaryBasicTableOp::Le | + BinaryBasicTableOp::And | + BinaryBasicTableOp::Or | + BinaryBasicTableOp::Xor => true, BinaryBasicTableOp::Ext32 => false, } } fn opcode_has_cin(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu - | BinaryBasicTableOp::Min - | BinaryBasicTableOp::Maxu - | BinaryBasicTableOp::Max - | BinaryBasicTableOp::LtAbsNP - | BinaryBasicTableOp::LtAbsPN - | BinaryBasicTableOp::Ltu - | BinaryBasicTableOp::Lt - | BinaryBasicTableOp::Gt - | BinaryBasicTableOp::Eq - | BinaryBasicTableOp::Add - | BinaryBasicTableOp::Sub => true, - - BinaryBasicTableOp::Leu - | BinaryBasicTableOp::Le - | BinaryBasicTableOp::And - | BinaryBasicTableOp::Or - | BinaryBasicTableOp::Xor - | BinaryBasicTableOp::Ext32 => false, + BinaryBasicTableOp::Minu | + BinaryBasicTableOp::Min | + BinaryBasicTableOp::Maxu | + BinaryBasicTableOp::Max | + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub => true, + + BinaryBasicTableOp::Leu | + BinaryBasicTableOp::Le | + BinaryBasicTableOp::And | + BinaryBasicTableOp::Or | + BinaryBasicTableOp::Xor | + BinaryBasicTableOp::Ext32 => false, } } fn opcode_result_is_a(opcode: BinaryBasicTableOp) -> bool { match opcode { - BinaryBasicTableOp::Minu - | BinaryBasicTableOp::Min - | BinaryBasicTableOp::Maxu - | BinaryBasicTableOp::Max => true, - - BinaryBasicTableOp::LtAbsNP - | BinaryBasicTableOp::LtAbsPN - | BinaryBasicTableOp::Ltu - | BinaryBasicTableOp::Lt - | BinaryBasicTableOp::Gt - | BinaryBasicTableOp::Eq - | BinaryBasicTableOp::Add - | BinaryBasicTableOp::Sub - | BinaryBasicTableOp::Leu - | BinaryBasicTableOp::Le - | BinaryBasicTableOp::And - | BinaryBasicTableOp::Or - | BinaryBasicTableOp::Xor - | BinaryBasicTableOp::Ext32 => false, + BinaryBasicTableOp::Minu | + BinaryBasicTableOp::Min | + BinaryBasicTableOp::Maxu | + BinaryBasicTableOp::Max => true, + + BinaryBasicTableOp::LtAbsNP | + BinaryBasicTableOp::LtAbsPN | + BinaryBasicTableOp::Ltu | + BinaryBasicTableOp::Lt | + BinaryBasicTableOp::Gt | + BinaryBasicTableOp::Eq | + BinaryBasicTableOp::Add | + BinaryBasicTableOp::Sub | + BinaryBasicTableOp::Leu | + BinaryBasicTableOp::Le | + BinaryBasicTableOp::And | + BinaryBasicTableOp::Or | + BinaryBasicTableOp::Xor | + BinaryBasicTableOp::Ext32 => false, } } diff --git a/state-machines/binary/src/binary_extension.rs b/state-machines/binary/src/binary_extension.rs index c4007a4c..0b623a3f 100644 --- a/state-machines/binary/src/binary_extension.rs +++ b/state-machines/binary/src/binary_extension.rs @@ -46,12 +46,12 @@ impl BinaryExtensionSM { fn opcode_is_shift(opcode: ZiskOp) -> bool { match opcode { - ZiskOp::Sll - | ZiskOp::Srl - | ZiskOp::Sra - | ZiskOp::SllW - | ZiskOp::SrlW - | ZiskOp::SraW => true, + ZiskOp::Sll | + ZiskOp::Srl | + ZiskOp::Sra | + ZiskOp::SllW | + ZiskOp::SrlW | + ZiskOp::SraW => true, ZiskOp::SignExtendB | ZiskOp::SignExtendH | ZiskOp::SignExtendW => false, @@ -63,12 +63,12 @@ impl BinaryExtensionSM { match opcode { ZiskOp::SllW | ZiskOp::SrlW | ZiskOp::SraW => true, - ZiskOp::Sll - | ZiskOp::Srl - | ZiskOp::Sra - | ZiskOp::SignExtendB - | ZiskOp::SignExtendH - | ZiskOp::SignExtendW => false, + ZiskOp::Sll | + ZiskOp::Srl | + ZiskOp::Sra | + ZiskOp::SignExtendB | + ZiskOp::SignExtendH | + ZiskOp::SignExtendW => false, _ => panic!("BinaryExtensionSM::opcode_is_shift() got invalid opcode={:?}", opcode), } diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index b6cd71d0..cd92743e 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -48,9 +48,9 @@ impl MemAlignRomSM { ), MemOp::TwoWrites => ( - 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] - + ONE_WORD_COMBINATIONS * OP_SIZES[1] - + TWO_WORD_COMBINATIONS * OP_SIZES[2], + 1 + ONE_WORD_COMBINATIONS * OP_SIZES[0] + + ONE_WORD_COMBINATIONS * OP_SIZES[1] + + TWO_WORD_COMBINATIONS * OP_SIZES[2], false, ), }; From 40b8a8891e96ead53542ed887f61b649ad185302 Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 9 Jan 2025 21:46:05 +0000 Subject: [PATCH 09/10] cargo clippy --- state-machines/mem/src/mem_align_sm.rs | 2 +- state-machines/mem/src/mem_helpers.rs | 4 ++-- state-machines/mem/src/mem_module_instance.rs | 8 +++----- state-machines/mem/src/mem_sm.rs | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/state-machines/mem/src/mem_align_sm.rs b/state-machines/mem/src/mem_align_sm.rs index d6b62403..40eb1832 100644 --- a/state-machines/mem/src/mem_align_sm.rs +++ b/state-machines/mem/src/mem_align_sm.rs @@ -807,7 +807,7 @@ impl MemAlignSM { let mut index = 0; for input in mem_ops.iter() { - let count = self.prove_mem_align_op(&input, &mut trace, index); + let count = self.prove_mem_align_op(input, &mut trace, index); for i in 0..count { for j in 0..CHUNK_NUM { let element = trace[index + i].reg[j] diff --git a/state-machines/mem/src/mem_helpers.rs b/state-machines/mem/src/mem_helpers.rs index 148f3a0d..af66f7b9 100644 --- a/state-machines/mem/src/mem_helpers.rs +++ b/state-machines/mem/src/mem_helpers.rs @@ -88,7 +88,7 @@ impl MemHelpers { } #[inline(always)] pub fn is_double(addr: u32, bytes: u8) -> bool { - addr & MEM_ADDR_ALIGN_MASK + bytes as u32 > 8 + (addr & MEM_ADDR_ALIGN_MASK) + bytes as u32 > 8 } #[inline(always)] pub fn is_write(op: u8) -> bool { @@ -149,7 +149,7 @@ impl MemHelpers { let hi_mask = !(byte_mask >> (8 - offset)); let hi_write = (hi_mask & read_values[1]) | (value >> (8 - offset)); - return [lo_write, hi_write] + [lo_write, hi_write] } } impl fmt::Debug for MemAlignResponse { diff --git a/state-machines/mem/src/mem_module_instance.rs b/state-machines/mem/src/mem_module_instance.rs index 885ba2f4..8cce3a1f 100644 --- a/state-machines/mem/src/mem_module_instance.rs +++ b/state-machines/mem/src/mem_module_instance.rs @@ -41,12 +41,10 @@ impl MemModuleInstance { } else { self.process_unaligned_double_read(addr_w, data); } + } else if is_write { + self.process_unaligned_single_write(addr_w, bytes, data); } else { - if is_write { - self.process_unaligned_single_write(addr_w, bytes, data); - } else { - self.process_unaligned_single_read(addr_w, data); - } + self.process_unaligned_single_read(addr_w, data); } } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index 031ff3ca..d26b73bc 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -174,7 +174,7 @@ impl MemModule for MemSM { } // increase the multiplicity of internal reads - range_check_data[(MEMORY_MAX_DIFF - 1) as usize] += internal_reads as u64; + range_check_data[(MEMORY_MAX_DIFF - 1) as usize] += internal_reads; // control the edge case when there aren't enough rows to complete the internal // reads or regular memory operation From 723988336f38cb86657b058d2c938144ee46d88e Mon Sep 17 00:00:00 2001 From: zkronos73 Date: Thu, 9 Jan 2025 22:55:01 +0000 Subject: [PATCH 10/10] mem_align_rom multiplicity, remove some obsolete comments --- state-machines/mem/src/input_data_sm.rs | 18 +-------- state-machines/mem/src/mem_align_rom_sm.rs | 47 +++++----------------- state-machines/mem/src/mem_proxy.rs | 35 +++++++--------- state-machines/mem/src/mem_sm.rs | 45 ++++++++++++--------- 4 files changed, 52 insertions(+), 93 deletions(-) diff --git a/state-machines/mem/src/input_data_sm.rs b/state-machines/mem/src/input_data_sm.rs index 1c7b8fdc..c1e6c4a1 100644 --- a/state-machines/mem/src/input_data_sm.rs +++ b/state-machines/mem/src/input_data_sm.rs @@ -75,23 +75,7 @@ impl MemModule for InputDataSM { trace.num_rows() ); - // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR - // segments In a Memory AIR instance, the first row is reserved as a dummy row. - // This dummy row is used to facilitate the continuation state between different AIR - // segments. It ensures seamless transitions when multiple AIR segments are - // processed consecutively. This design avoids discontinuities in memory access - // patterns and ensures that the memory trace is continuous, For this reason we use - // AIR num_rows - 1 as the number of rows in each memory AIR instance - - // Create a vector of Mem0Row instances, one for each memory operation - // Recall that first row is a dummy row used for the continuations between AIR segments - // The length of the vector is the number of input memory operations plus one because - // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows - - //println! {"InputDataSM::prove_instance() mem_ops.len={} prover_buffer.len={} - // air.num_rows={}", mem_ops.len(), prover_buffer.len(), air.num_rows()}; - - let mut range_check_data: Vec = vec![0; 1 << 16]; + let mut range_check_data = Box::new([0u64; 1 << 16]); let mut air_values_mem = MemoryAirValues { segment_id: segment_id as u32, diff --git a/state-machines/mem/src/mem_align_rom_sm.rs b/state-machines/mem/src/mem_align_rom_sm.rs index cd92743e..5445ed58 100644 --- a/state-machines/mem/src/mem_align_rom_sm.rs +++ b/state-machines/mem/src/mem_align_rom_sm.rs @@ -121,6 +121,16 @@ impl MemAlignRomSM { } } + pub fn detach_multiplicity(&self) -> Vec { + let multiplicity = self.multiplicity.lock().unwrap(); + let mut multiplicity_vec = vec![0; MemAlignRomTrace::::NUM_ROWS]; + for (row_idx, multiplicity) in multiplicity.iter() { + assert!(*row_idx < MemAlignRomTrace::::NUM_ROWS as u64); + multiplicity_vec[*row_idx as usize] = *multiplicity; + } + multiplicity_vec + } + pub fn update_padding_row(&self, padding_len: u64) { // Update entry at the padding row (pos = 0) with the given padding length self.update_multiplicity_by_row_idx(0, padding_len); @@ -130,41 +140,4 @@ impl MemAlignRomSM { let mut multiplicity = self.multiplicity.lock().unwrap(); *multiplicity.entry(row_idx).or_insert(0) += mul; } - - pub fn create_air_instance(&self) { - // Get the contexts - // let wcm = self.wcm.clone(); - // let pctx = wcm.get_pctx(); - // let sctx = wcm.get_sctx(); - - // Get the Mem Align ROM AIR - // let air_mem_align_rom = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_ROM_AIR_IDS[0]); - // let air_mem_align_rom_rows = air_mem_align_rom.num_rows(); - - // let mut trace_buffer: MemAlignRomTrace<'_, _> = MemAlignRomTrace::new(); - - // // Initialize the trace buffer to zero - // for i in 0..air_mem_align_rom_rows { - // trace_buffer[i] = MemAlignRomTraceRow { multiplicity: F::zero() }; - // } - - // // Fill the trace buffer with the multiplicity values - // if let Ok(multiplicity) = self.multiplicity.lock() { - // for (row_idx, multiplicity) in multiplicity.iter() { - // trace_buffer[*row_idx as usize] = - // MemAlignRomTraceRow { multiplicity: F::from_canonical_u64(*multiplicity) }; - // } - // } - - // info!("{}: ยทยทยท Creating Mem Align Rom instance", Self::MY_NAME,); - - // let air_instance = AirInstance::new( - // sctx, - // ZISK_AIRGROUP_ID, - // MEM_ALIGN_ROM_AIR_IDS[0], - // None, - // trace_buffer.buffer.unwrap(), - // ); - // pctx.air_instance_repo.add_air_instance(air_instance, None); - } } diff --git a/state-machines/mem/src/mem_proxy.rs b/state-machines/mem/src/mem_proxy.rs index d9879a8d..87526790 100644 --- a/state-machines/mem/src/mem_proxy.rs +++ b/state-machines/mem/src/mem_proxy.rs @@ -1,19 +1,21 @@ use std::sync::Arc; use crate::{ - InputDataSM, MemAlignRomSM, MemAlignSM, MemCounters, MemModuleInstance, MemPlanner, MemSM, - RomDataSM, + InputDataSM, MemAlignInstance, MemAlignRomSM, MemAlignSM, MemCounters, MemModuleInstance, + MemPlanner, MemSM, RomDataSM, }; use p3_field::PrimeField; use pil_std_lib::Std; -use sm_common::{BusDeviceInstance, BusDeviceMetrics, ComponentBuilder, InstanceCtx, Planner}; -use zisk_pil::{InputDataTrace, MemTrace, RomDataTrace}; +use sm_common::{ + table_instance, BusDeviceInstance, BusDeviceMetrics, ComponentBuilder, InstanceCtx, Planner, +}; +use zisk_pil::{InputDataTrace, MemAlignRomTrace, MemAlignTrace, MemTrace, RomDataTrace}; pub struct MemProxy { // Secondary State machines mem_sm: Arc>, - _mem_align_sm: Arc>, - _mem_align_rom_sm: Arc, + mem_align_sm: Arc>, + mem_align_rom_sm: Arc, input_data_sm: Arc>, rom_data_sm: Arc>, } @@ -26,13 +28,7 @@ impl MemProxy { let input_data_sm = InputDataSM::new(std.clone()); let rom_data_sm = RomDataSM::new(std.clone()); - Arc::new(Self { - _mem_align_sm: mem_align_sm, - _mem_align_rom_sm: mem_align_rom_sm, - mem_sm, - input_data_sm, - rom_data_sm, - }) + Arc::new(Self { mem_align_sm, mem_align_rom_sm, mem_sm, input_data_sm, rom_data_sm }) } } @@ -57,14 +53,13 @@ impl ComponentBuilder for MemProxy { id if id == InputDataTrace::::AIR_ID => { Box::new(MemModuleInstance::new(self.input_data_sm.clone(), ictx)) } - /* id if id == ArithTableTrace::::AIR_ID => { - table_instance!(ArithTableInstance, ArithTableSM, ArithTableTrace); - Box::new(ArithTableInstance::new(self.arith_table_sm.clone(), ictx)) + id if id == MemAlignTrace::::AIR_ID => { + Box::new(MemAlignInstance::new(self.mem_align_sm.clone(), ictx)) + } + id if id == MemAlignRomTrace::::AIR_ID => { + table_instance!(MemAlignRomInstance, MemAlignRomSM, MemAlignRomTrace); + Box::new(MemAlignRomInstance::new(self.mem_align_rom_sm.clone(), ictx)) } - id if id == ArithRangeTableTrace::::AIR_ID => { - table_instance!(ArithRangeTableInstance, ArithRangeTableSM, ArithRangeTableTrace); - Box::new(ArithRangeTableInstance::new(self.arith_range_table_sm.clone(), ictx)) - }*/ _ => panic!("Memory::get_instance() Unsupported air_id: {:?}", ictx.plan.air_id), } } diff --git a/state-machines/mem/src/mem_sm.rs b/state-machines/mem/src/mem_sm.rs index d26b73bc..0e0c1ab1 100644 --- a/state-machines/mem/src/mem_sm.rs +++ b/state-machines/mem/src/mem_sm.rs @@ -81,20 +81,13 @@ impl MemModule for MemSM { trace.num_rows, ); - // In a Mem AIR instance the first row is a dummy row used for the continuations between AIR - // segments In a Memory AIR instance, the first row is reserved as a dummy row. - // This dummy row is used to facilitate the continuation state between different AIR - // segments. It ensures seamless transitions when multiple AIR segments are - // processed consecutively. This design avoids discontinuities in memory access - // patterns and ensures that the memory trace is continuous, For this reason we use - // AIR num_rows - 1 as the number of rows in each memory AIR instance + let std = self.std.clone(); + let range_id = std.get_range(BigInt::from(1), BigInt::from(MEMORY_MAX_DIFF), None); + let mut range_check_data = Box::new([0u16; MEMORY_MAX_DIFF as usize]); + let f_range_check_max_value = F::from_canonical_u64(0xFFFF + 1); - // Create a vector of Mem0Row instances, one for each memory operation - // Recall that first row is a dummy row used for the continuations between AIR segments - // The length of the vector is the number of input memory operations plus one because - // in the prove_witnesses method we drain the memory operations in chunks of n - 1 rows - - let mut range_check_data: Vec = vec![0; MEMORY_MAX_DIFF as usize]; + // use special counter for internal reads + let mut range_check_data_max = 0u64; let mut air_values_mem = MemoryAirValues { segment_id: segment_id as u32, @@ -173,8 +166,7 @@ impl MemModule for MemSM { i += 1; } - // increase the multiplicity of internal reads - range_check_data[(MEMORY_MAX_DIFF - 1) as usize] += internal_reads; + range_check_data_max += internal_reads; // control the edge case when there aren't enough rows to complete the internal // reads or regular memory operation @@ -199,8 +191,18 @@ impl MemModule for MemSM { i += 1; // Store the value of incremenet so it can be range checked - if increment <= MEMORY_MAX_DIFF || increment == 0 { - range_check_data[(increment - 1) as usize] += 1; + let range_index = increment as usize - 1; + if range_index < MEMORY_MAX_DIFF as usize { + if range_check_data[range_index] == 0xFFFF { + range_check_data[range_index] = 0; + std.range_check( + F::from_canonical_u64(increment), + f_range_check_max_value, + range_id, + ); + } else { + range_check_data[range_index] += 1; + } } else { panic!("MemSM: increment's out of range: {} i:{} addr_changes:{} mem_op.addr:0x{:X} last_addr:0x{:X} mem_op.step:{} last_step:{}", increment, i, addr_changes as u8, mem_op.addr, last_addr, mem_op.step, last_step); @@ -241,7 +243,7 @@ impl MemModule for MemSM { // Store the value of trivial increment so that they can be range checked // value = 1 => index = 0 - range_check_data[0] += padding_size as u64; + self.std.range_check(F::zero(), F::from_canonical_usize(padding_size), range_id); // no add extra +1 because index = value - 1 // RAM_W_ADDR_END - last_addr + 1 - 1 = RAM_W_ADDR_END - last_addr @@ -255,10 +257,15 @@ impl MemModule for MemSM { } self.std.range_check( F::from_canonical_usize(value + 1), - F::from_canonical_u64(multiplicity), + F::from_canonical_u16(multiplicity), range_id, ); } + self.std.range_check( + f_range_check_max_value, + F::from_canonical_u64(range_check_data_max), + range_id, + ); let mut air_values = MemAirValues::::new(); air_values.segment_id = F::from_canonical_u32(air_values_mem.segment_id);