From 85ff00c95c724a2c4846f433f30e856df0b9643e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A9ctor=20Masip=20Ardevol?= Date: Wed, 20 Nov 2024 16:52:54 +0100 Subject: [PATCH] Optimizing the binary component (#167) * Optimizing the binary * Updating the executor --- pil/src/pil_helpers/traces.rs | 2 +- state-machines/binary/pil/binary.pil | 86 ++++++++++++++--------- state-machines/binary/src/binary_basic.rs | 35 ++++++++- 3 files changed, 86 insertions(+), 37 deletions(-) diff --git a/pil/src/pil_helpers/traces.rs b/pil/src/pil_helpers/traces.rs index f3cdf5e7..da3c9ff1 100644 --- a/pil/src/pil_helpers/traces.rs +++ b/pil/src/pil_helpers/traces.rs @@ -24,7 +24,7 @@ trace!(ArithRangeTableRow, ArithRangeTableTrace { }); trace!(BinaryRow, BinaryTrace { - m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F, + m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F, }); trace!(BinaryTableRow, BinaryTableTrace { diff --git a/state-machines/binary/pil/binary.pil b/state-machines/binary/pil/binary.pil index b6c48f73..d23b7d06 100644 --- a/state-machines/binary/pil/binary.pil +++ b/state-machines/binary/pil/binary.pil @@ -65,6 +65,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) // Default values const int bits = 64; const int bytes = bits / 8; + const int half_bytes = bytes / 2; // Main values const int input_chunks = 2; @@ -80,51 +81,70 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) // Secondary columns col witness use_last_carry; // 1 if the operation uses the last carry as its result - col witness op_is_min_max; // 1 if op ∈ {MINU,MIN,MAXU,MAX} + col 
witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations - const expr cout32 = carry[bytes/2-1]; + const expr mode64 = 1 - mode32; + const expr cout32 = carry[half_bytes-1]; const expr cout64 = carry[bytes-1]; - expr cout = (1-mode32) * (cout64 - cout32) + cout32; use_last_carry * (1 - use_last_carry) === 0; op_is_min_max * (1 - op_is_min_max) === 0; cout32*(1 - cout32) === 0; cout64*(1 - cout64) === 0; - // Constraints to check the correctness of each binary operation + // Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions) + col witness cout; + col witness result_is_a; + col witness use_last_carry_mode32; + col witness use_last_carry_mode64; + cout === mode64 * (cout64 - cout32) + cout32; + result_is_a === op_is_min_max * cout; + use_last_carry_mode32 === mode32 * use_last_carry; + use_last_carry_mode64 === mode64 * use_last_carry; + /* - opid last a b c cin cout - ─────────────────────────────────────────────────────────────── - m_op 0 a0 b0 c0 0 carry0 - m_op 0 a1 b1 c1 carry0 carry1 - m_op 0 a2 b2 c2 carry1 carry2 - m_op 0 a3 b3 c3 carry2 carry3 + 2*use_last_carry - m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 - m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 - m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 - m_op|EXT_32 1 a7|c3 b7|0 c7 carry6 carry7 + 2*use_last_carry + Constraints to check the correctness of each binary operation + opid last a b c cin cout + flags + ───────────────────────────────────────────────────────────────------------------------------------------------- + m_op 0 a0 b0 c0 0 carry0 + 2*op_is_min_max + 4*result_is_a + m_op 0 a1 b1 c1 carry0 carry1 + 2*op_is_min_max + 4*result_is_a + m_op 0 a2 b2 c2 carry1 carry2 + 2*op_is_min_max + 4*result_is_a + m_op 0|1 a3 b3 c3 carry2 carry3 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32 + m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 + 2*op_is_min_max + 
4*result_is_a + m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 + 2*op_is_min_max + 4*result_is_a + m_op|EXT_32 0|1 a7|c3 b7|0 c7 carry6 carry7 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64 + ───────────────────────────────────────────────────────────────------------------------------------------------- + Perform, at the byte level, lookups against the binary table on inputs: + [last, m_op, a, b, cin, c, cout + flags] + where last indicates whether the byte is the last one in the operation */ - // Perform, at the byte level, lookups against the binary table on inputs: - // [last, m_op, a, b, cin, c, cout + flags] - // where last indicates whether the byte is the last one in the operation - - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*op_is_min_max*cout]); + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]); - expr _m_op = (1-mode32) * (m_op - EXT_32_OP) + EXT_32_OP; + // More auxiliary columns + col witness m_op_or_ext; + col witness free_in_a_or_c[half_bytes]; + col witness free_in_b_or_zero[half_bytes]; + m_op_or_ext === mode64 * (m_op - EXT_32_OP) + EXT_32_OP; + int j = 0; for (int i = 1; i < bytes; i++) { - expr _free_in_a = (1-mode32) * (free_in_a[i] - free_in_c[bytes/2-1]) + free_in_c[bytes/2-1]; - expr _free_in_b = (1-mode32) * free_in_b[i]; - - if (i < bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]); - } else if (i == bytes/2 - 1) { - lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*mode32]); - } else if (i < bytes - 1) { - lookup_assumes(BINARY_TABLE_ID, [0, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 
4*op_is_min_max*cout]); - } else { - lookup_assumes(BINARY_TABLE_ID, [1-mode32, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*(1-mode32)]); - } + if (i >= half_bytes) { + free_in_a_or_c[j] === mode64 * (free_in_a[i] - free_in_c[half_bytes-1]) + free_in_c[half_bytes-1]; + free_in_b_or_zero[j] === mode64 * free_in_b[i]; + } + + if (i < half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + } else if (i == half_bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32]); + } else if (i < bytes - 1) { + lookup_assumes(BINARY_TABLE_ID, [0, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]); + j++; + } else { + lookup_assumes(BINARY_TABLE_ID, [mode64, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64]); + j++; + } } // Constraints to make sure that this component is called from the main component @@ -164,5 +184,5 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID) col witness multiplicity; col witness main_step; - lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, (1-op_is_min_max)*cout], multiplicity); + lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, cout - result_is_a], multiplicity); } \ No newline at end of file diff --git a/state-machines/binary/src/binary_basic.rs b/state-machines/binary/src/binary_basic.rs index 765ed11f..40cb3698 100644 --- a/state-machines/binary/src/binary_basic.rs +++ b/state-machines/binary/src/binary_basic.rs @@ -16,6 +16,9 @@ use zisk_pil::*; use crate::{BinaryBasicTableOp, BinaryBasicTableSM}; +const 
BYTES: usize = 8; +const HALF_BYTES: usize = BYTES / 2; + pub struct BinaryBasicSM { wcm: Arc>, @@ -158,6 +161,7 @@ impl BinaryBasicSM { let opcode = ZiskOp::try_from_code(operation.opcode).expect("Invalid ZiskOp opcode"); let mode32 = Self::opcode_is_32_bits(opcode); row.mode32 = F::from_bool(mode32); + let mode64 = F::from_bool(!mode32); // Set c_filtered let c_filtered = if mode32 { c & 0xFFFFFFFF } else { c }; @@ -667,6 +671,33 @@ impl BinaryBasicSM { _ => panic!("BinaryBasicSM::process_slice() found invalid opcode={}", operation.opcode), } + // Set cout + let cout32 = row.carry[HALF_BYTES - 1]; + let cout64 = row.carry[BYTES - 1]; + row.cout = mode64 * (cout64 - cout32) + cout32; + + // Set result_is_a + row.result_is_a = row.op_is_min_max * row.cout; + + // Set use_last_carry_mode32 and use_last_carry_mode64 + row.use_last_carry_mode32 = F::from_bool(mode32) * row.use_last_carry; + row.use_last_carry_mode64 = mode64 * row.use_last_carry; + + // Set micro opcode + row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); + + // Set m_op_or_ext + let ext_32_op = F::from_canonical_u8(BinaryBasicTableOp::Ext32 as u8); + row.m_op_or_ext = mode64 * (row.m_op - ext_32_op) + ext_32_op; + + // Set free_in_a_or_c and free_in_b_or_zero + for i in 0..HALF_BYTES { + row.free_in_a_or_c[i] = mode64 * + (row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) + + row.free_in_c[HALF_BYTES - 1]; + row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES]; + } + if row.use_last_carry == F::one() { // Set first and last elements row.free_in_c[7] = row.free_in_c[0]; @@ -676,9 +707,6 @@ impl BinaryBasicSM { // TODO: Find duplicates of this trace and reuse them by increasing their multiplicity. 
row.multiplicity = F::one(); - // Set micro opcode - row.m_op = F::from_canonical_u8(binary_basic_table_op as u8); - // Return row } @@ -732,6 +760,7 @@ impl BinaryBasicSM { timer_start_trace!(BINARY_PADDING); let padding_row = BinaryRow::<F> { m_op: F::from_canonical_u8(0x20), + m_op_or_ext: F::from_canonical_u8(0x20), multiplicity: F::zero(), main_step: F::zero(), /* TODO: remove, since main_step is just for * debugging */