Skip to content

Commit

Permalink
Merge pull request #195 from 0xPolygonHermez/fix/mem
Browse files Browse the repository at this point in the history
Bugs in memory and binary extension
  • Loading branch information
hecmas authored Dec 19, 2024
1 parent 3fc9061 commit ee6d994
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 37 deletions.
1 change: 0 additions & 1 deletion state-machines/arith/src/arith_full.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ impl<F: PrimeField> ArithFullSM<F> {

if !binary_inputs.is_empty() {
timer_start_trace!(ARITH_BINARY);
info!("{}: ··· calling binary_sm", Self::MY_NAME);
self.binary_sm.prove(binary_inputs.as_slice(), false);
timer_stop_and_log_trace!(ARITH_BINARY);
}
Expand Down
2 changes: 1 addition & 1 deletion state-machines/binary/src/binary_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ impl<F: Field> BinaryBasicSM<F> {
assert!(operations.len() <= air.num_rows());

info!(
"{}: ··· Creating Binary basic instance [{} / {} rows filled {:.2}%]",
"{}: ··· Creating Binary instance [{} / {} rows filled {:.2}%]",
Self::MY_NAME,
operations.len(),
air.num_rows(),
Expand Down
28 changes: 20 additions & 8 deletions state-machines/binary/src/binary_basic_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::{
Arc, Mutex,
};

use log::info;
use p3_field::Field;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -261,19 +260,32 @@ impl<F: Field> BinaryBasicTableSM<F> {

if is_myne {
// Create the prover buffer
let trace: BinaryTableTrace<'_, _> = BinaryTableTrace::new(self.num_rows);
let num_rows = self.num_rows;
let trace: BinaryTableTrace<'_, _> = BinaryTableTrace::new(num_rows);
let mut prover_buffer = trace.buffer.unwrap();

prover_buffer[0..self.num_rows]
let non_zero_multiplicities = prover_buffer[0..num_rows]
.par_iter_mut()
.enumerate()
.for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i]));

info!(
"{}: ··· Creating Binary basic table instance [{} rows filled 100%]",
.map(|(i, input)| {
*input = F::from_canonical_u64(multiplicity_[i]);
if multiplicity_[i] != 0 {
Some(1)
} else {
None
}
})
.filter_map(|x| x)
.sum::<usize>();

log::info!(
"{}: ··· Creating Binary Table instance [{} / {} rows used {:.2}%]",
Self::MY_NAME,
self.num_rows,
non_zero_multiplicities,
num_rows,
non_zero_multiplicities as f64 / num_rows as f64 * 100.0
);

let air_instance = AirInstance::new(
self.wcm.get_sctx(),
ZISK_AIRGROUP_ID,
Expand Down
4 changes: 2 additions & 2 deletions state-machines/binary/src/binary_extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ impl<F: PrimeField> BinaryExtensionSM<F> {
if ((a_bytes[j] as u64) & SIGN_BYTE) != 0 {
out = (a_bytes[j] as u64) << 8 | SE_MASK_16;
} else {
out = a_bytes[j] as u64;
out = (a_bytes[j] as u64) << 8;
}
} else {
out = 0;
Expand Down Expand Up @@ -391,7 +391,7 @@ impl<F: PrimeField> BinaryExtensionSM<F> {
assert!(operations.len() <= air.num_rows());

info!(
"{}: ··· Creating Binary extension instance [{} / {} rows filled {:.2}%]",
"{}: ··· Creating Binary Extension instance [{} / {} rows filled {:.2}%]",
Self::MY_NAME,
operations.len(),
air.num_rows(),
Expand Down
28 changes: 19 additions & 9 deletions state-machines/binary/src/binary_extension_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::{
Arc, Mutex,
};

use log::info;
use p3_field::Field;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -136,19 +135,30 @@ impl<F: Field> BinaryExtensionTableSM<F> {
dctx.distribute_multiplicity(&mut multiplicity_, owner);

if is_myne {
let trace: BinaryExtensionTableTrace<'_, _> =
BinaryExtensionTableTrace::new(self.num_rows);
let num_rows = self.num_rows;
let trace: BinaryExtensionTableTrace<'_, _> = BinaryExtensionTableTrace::new(num_rows);
let mut prover_buffer = trace.buffer.unwrap();

prover_buffer[0..self.num_rows]
let non_zero_multiplicities = prover_buffer[0..num_rows]
.par_iter_mut()
.enumerate()
.for_each(|(i, input)| *input = F::from_canonical_u64(multiplicity_[i]));

info!(
"{}: ··· Creating Binary extension table instance [{} rows filled 100%]",
.map(|(i, input)| {
*input = F::from_canonical_u64(multiplicity_[i]);
if multiplicity_[i] != 0 {
Some(1)
} else {
None
}
})
.filter_map(|x| x)
.sum::<usize>();

log::info!(
"{}: ··· Creating Binary Extension Table instance [{} / {} rows used {:.2}%]",
Self::MY_NAME,
self.num_rows,
non_zero_multiplicities,
num_rows,
non_zero_multiplicities as f64 / num_rows as f64 * 100.0
);

let air_instance = AirInstance::new(
Expand Down
13 changes: 9 additions & 4 deletions state-machines/mem/src/mem_align_rom_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::{
},
};

use log::info;
use p3_field::PrimeField;
use proofman::{WitnessComponent, WitnessManager};
use proofman_common::AirInstance;
Expand Down Expand Up @@ -43,7 +42,7 @@ pub enum ExtensionTableSMErr {
}

impl<F: PrimeField> MemAlignRomSM<F> {
const MY_NAME: &'static str = "MemAlignRom";
const MY_NAME: &'static str = "MemAlROM";

pub fn new(wcm: Arc<WitnessManager<F>>) -> Arc<Self> {
let pctx = wcm.get_pctx();
Expand Down Expand Up @@ -196,9 +195,15 @@ impl<F: PrimeField> MemAlignRomSM<F> {
trace_buffer[*row_idx as usize] =
MemAlignRomRow { multiplicity: F::from_canonical_u64(*multiplicity) };
}
}

info!("{}: ··· Creating Mem Align Rom instance", Self::MY_NAME,);
log::info!(
"{}: ··· Creating Mem Align ROM instance [{} / {} rows executed {:.2}%]",
Self::MY_NAME,
multiplicity.len(),
air_mem_align_rom_rows,
multiplicity.len() as f64 / air_mem_align_rom_rows as f64 * 100.0
);
}

let air_instance = AirInstance::new(
sctx,
Expand Down
14 changes: 12 additions & 2 deletions state-machines/mem/src/mem_align_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,7 @@ impl<F: PrimeField> MemAlignSM<F> {

pub fn prove(&self, computed_rows: &[MemAlignRow<F>]) {
if let Ok(mut rows) = self.rows.lock() {
let previous_num_rows = rows.len();
rows.extend_from_slice(computed_rows);

#[cfg(feature = "debug_mem_align")]
Expand All @@ -924,8 +925,17 @@ impl<F: PrimeField> MemAlignSM<F> {
let air_mem_align = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_ALIGN_AIR_IDS[0]);

while rows.len() >= air_mem_align.num_rows() {
let num_drained = std::cmp::min(air_mem_align.num_rows(), rows.len());
let drained_rows = rows.drain(..num_drained).collect::<Vec<_>>();
// Find the correct cutting point
let cutting_point =
if previous_num_rows + computed_rows.len() == air_mem_align.num_rows() {
air_mem_align.num_rows()
} else {
// This is the case where previous_num_rows + computed_rows.len() >
// air_mem_align.num_rows() In this case, we prove
// computed_rows in the next air instance
previous_num_rows
};
let drained_rows = rows.drain(..cutting_point).collect::<Vec<_>>();

self.fill_new_air_instance(&drained_rows);
}
Expand Down
13 changes: 8 additions & 5 deletions state-machines/mem/src/mem_proxy_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ use crate::{
MAX_MAIN_STEP, MAX_MEM_ADDR, MAX_MEM_OPS_BY_MAIN_STEP, MAX_MEM_STEP, MAX_MEM_STEP_OFFSET,
MEMORY_MAX_DIFF, MEM_ADDR_MASK, MEM_BYTES, MEM_BYTES_BITS,
};
use log::info;

use p3_field::PrimeField;
use proofman_util::{timer_start_debug, timer_stop_and_log_debug};
Expand All @@ -104,7 +103,7 @@ macro_rules! debug_info {
($prefix:expr, $($arg:tt)*) => {
#[cfg(feature = "debug_mem_proxy_engine")]
{
info!(concat!("MemProxy: ",$prefix), $($arg)*);
log::info!(concat!("MemProxy: ",$prefix), $($arg)*);
}
};
}
Expand Down Expand Up @@ -133,6 +132,7 @@ pub struct MemProxyEngine<F: PrimeField> {
mem_align_sm: Arc<MemAlignSM<F>>,
next_open_addr: u32,
next_open_step: u64,
last_value: u64,
last_addr: u32,
last_step: u64,
intermediate_cases: u32,
Expand All @@ -156,6 +156,7 @@ impl<F: PrimeField> MemProxyEngine<F> {
mem_align_sm,
next_open_addr: NO_OPEN_ADDR,
next_open_step: NO_OPEN_STEP,
last_value: 0,
last_addr: 0xFFFF_FFFF,
last_step: 0,
intermediate_cases: 0,
Expand Down Expand Up @@ -370,11 +371,12 @@ impl<F: PrimeField> MemProxyEngine<F> {

// check if step difference is too large
if self.last_addr == w_addr && (step - self.last_step) > MEMORY_MAX_DIFF {
self.push_intermediate_internal_reads(w_addr, value, self.last_step, step);
self.push_intermediate_internal_reads(w_addr, self.last_value, self.last_step, step);
}

self.last_step = step;
self.last_addr = w_addr;
self.last_value = value;

let mem_op = MemInput { step, is_write, is_internal: false, addr: w_addr, value };
debug_info!(
Expand Down Expand Up @@ -542,9 +544,10 @@ impl<F: PrimeField> MemProxyEngine<F> {
);
module.send_inputs(&self.modules_data[module_id].inputs);
}
info!(
debug_info!(
"MemProxy: ··· Intermediate reads [cases:{} steps:{}]",
self.intermediate_cases, self.intermediate_steps
self.intermediate_cases,
self.intermediate_steps
);
}
/// Fetches the address map, defining and calculating all necessary structures to manage the
Expand Down
2 changes: 0 additions & 2 deletions state-machines/mem/src/mem_sm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ pub struct MemPreviousSegment {
#[allow(unused, unused_variables)]
impl<F: PrimeField> MemSM<F> {
pub fn new(wcm: Arc<WitnessManager<F>>, std: Arc<Std<F>>) -> Arc<Self> {
let pctx = wcm.get_pctx();
let air = pctx.pilout.get_air(ZISK_AIRGROUP_ID, MEM_AIR_IDS[0]);
let mem_sm =
Self { wcm: wcm.clone(), std: std.clone(), registered_predecessors: AtomicU32::new(0) };
let mem_sm = Arc::new(mem_sm);
Expand Down
16 changes: 13 additions & 3 deletions state-machines/rom/src/rom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct RomSM<F> {
}

impl<F: Field> RomSM<F> {
const MY_NAME: &'static str = "ROM ";

pub fn new(wcm: Arc<WitnessManager<F>>) -> Arc<Self> {
let rom_sm = Self { wcm: wcm.clone() };
let rom_sm = Arc::new(rom_sm);
Expand All @@ -42,9 +44,9 @@ impl<F: Field> RomSM<F> {

// Create an empty ROM trace
let pilout = Pilout::pilout();
let num_rows = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows();
let rom_trace_len = pilout.get_air(ZISK_AIRGROUP_ID, ROM_AIR_IDS[0]).num_rows();

let mut rom_trace = RomTrace::new(num_rows);
let mut rom_trace = RomTrace::new(rom_trace_len);

// For every instruction in the rom, fill its corresponding ROM trace
let main_trace_len = pilout.get_air(ZISK_AIRGROUP_ID, MAIN_AIR_IDS[0]).num_rows() as u64;
Expand Down Expand Up @@ -73,10 +75,18 @@ impl<F: Field> RomSM<F> {
}

// Padd with zeroes
for i in rom.insts.len()..num_rows {
for i in rom.insts.len()..rom_trace_len {
rom_trace[i] = RomRow::default();
}

log::info!(
"{}: ··· Creating ROM instance [{} / {} rows executed {:.2}%]",
Self::MY_NAME,
pc_histogram.map.len(),
rom_trace_len,
pc_histogram.map.len() as f64 / rom_trace_len as f64 * 100.0
);

let mut air_instance = AirInstance::new(
sctx.clone(),
ZISK_AIRGROUP_ID,
Expand Down

0 comments on commit ee6d994

Please sign in to comment.