Skip to content

Commit

Permalink
Optimizing the binary component (#167)
Browse files Browse the repository at this point in the history
* Optimizing the binary

* Updating the executor
  • Loading branch information
hecmas authored and RogerTaule committed Nov 20, 2024
1 parent 0c00e2f commit 85ff00c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pil/src/pil_helpers/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ trace!(ArithRangeTableRow, ArithRangeTableTrace<F> {
});

trace!(BinaryRow, BinaryTrace<F> {
m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, multiplicity: F, main_step: F,
m_op: F, mode32: F, free_in_a: [F; 8], free_in_b: [F; 8], free_in_c: [F; 8], carry: [F; 8], use_last_carry: F, op_is_min_max: F, cout: F, result_is_a: F, use_last_carry_mode32: F, use_last_carry_mode64: F, m_op_or_ext: F, free_in_a_or_c: [F; 4], free_in_b_or_zero: [F; 4], multiplicity: F, main_step: F,
});

trace!(BinaryTableRow, BinaryTableTrace<F> {
Expand Down
86 changes: 53 additions & 33 deletions state-machines/binary/pil/binary.pil
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID)
// Default values
const int bits = 64;
const int bytes = bits / 8;
const int half_bytes = bytes / 2;

// Main values
const int input_chunks = 2;
Expand All @@ -80,51 +81,70 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID)

// Secondary columns
col witness use_last_carry; // 1 if the operation uses the last carry as its result
col witness op_is_min_max; // 1 if op ∈ {MINU,MIN,MAXU,MAX}
col witness op_is_min_max; // 1 if the operation is any of the MIN/MAX operations

const expr cout32 = carry[bytes/2-1];
const expr mode64 = 1 - mode32;
const expr cout32 = carry[half_bytes-1];
const expr cout64 = carry[bytes-1];
expr cout = (1-mode32) * (cout64 - cout32) + cout32;

use_last_carry * (1 - use_last_carry) === 0;
op_is_min_max * (1 - op_is_min_max) === 0;
cout32*(1 - cout32) === 0;
cout64*(1 - cout64) === 0;

// Constraints to check the correctness of each binary operation
// Auxiliary columns (primarily used to optimize lookups, but can be substituted with expressions)
col witness cout;
col witness result_is_a;
col witness use_last_carry_mode32;
col witness use_last_carry_mode64;
cout === mode64 * (cout64 - cout32) + cout32;
result_is_a === op_is_min_max * cout;
use_last_carry_mode32 === mode32 * use_last_carry;
use_last_carry_mode64 === mode64 * use_last_carry;

/*
opid last a b c cin cout
───────────────────────────────────────────────────────────────
m_op 0 a0 b0 c0 0 carry0
m_op 0 a1 b1 c1 carry0 carry1
m_op 0 a2 b2 c2 carry1 carry2
m_op 0 a3 b3 c3 carry2 carry3 + 2*use_last_carry
m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4
m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5
m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6
m_op|EXT_32 1 a7|c3 b7|0 c7 carry6 carry7 + 2*use_last_carry
Constraints to check the correctness of each binary operation
opid last a b c cin cout + flags
───────────────────────────────────────────────────────────────-------------------------------------------------
m_op 0 a0 b0 c0 0 carry0 + 2*op_is_min_max + 4*result_is_a
m_op 0 a1 b1 c1 carry0 carry1 + 2*op_is_min_max + 4*result_is_a
m_op 0 a2 b2 c2 carry1 carry2 + 2*op_is_min_max + 4*result_is_a
m_op 0|1 a3 b3 c3 carry2 carry3 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32
m_op|EXT_32 0 a4|c3 b4|0 c4 carry3 carry4 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0 a5|c3 b5|0 c5 carry4 carry5 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0 a6|c3 b6|0 c6 carry5 carry6 + 2*op_is_min_max + 4*result_is_a
m_op|EXT_32 0|1 a7|c3 b7|0 c7 carry6 carry7 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64
───────────────────────────────────────────────────────────────-------------------------------------------------
Perform, at the byte level, lookups against the binary table on inputs:
[last, m_op, a, b, cin, c, cout + flags]
where last indicates whether the byte is the last one in the operation
*/

// Perform, at the byte level, lookups against the binary table on inputs:
// [last, m_op, a, b, cin, c, cout + flags]
// where last indicates whether the byte is the last one in the operation

lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*op_is_min_max*cout]);
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[0], free_in_b[0], 0, free_in_c[0], carry[0] + 2*op_is_min_max + 4*result_is_a]);

expr _m_op = (1-mode32) * (m_op - EXT_32_OP) + EXT_32_OP;
// More auxiliary columns
col witness m_op_or_ext;
col witness free_in_a_or_c[half_bytes];
col witness free_in_b_or_zero[half_bytes];
m_op_or_ext === mode64 * (m_op - EXT_32_OP) + EXT_32_OP;
int j = 0;
for (int i = 1; i < bytes; i++) {
expr _free_in_a = (1-mode32) * (free_in_a[i] - free_in_c[bytes/2-1]) + free_in_c[bytes/2-1];
expr _free_in_b = (1-mode32) * free_in_b[i];

if (i < bytes/2 - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]);
} else if (i == bytes/2 - 1) {
lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*mode32]);
} else if (i < bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*op_is_min_max*cout]);
} else {
lookup_assumes(BINARY_TABLE_ID, [1-mode32, _m_op, _free_in_a, _free_in_b, carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*op_is_min_max*cout + 8*use_last_carry*(1-mode32)]);
}
if (i >= half_bytes) {
free_in_a_or_c[j] === mode64 * (free_in_a[i] - free_in_c[half_bytes-1]) + free_in_c[half_bytes-1];
free_in_b_or_zero[j] === mode64 * free_in_b[i];
}

if (i < half_bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]);
} else if (i == half_bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [mode32, m_op, free_in_a[i], free_in_b[i], carry[i-1], free_in_c[i], cout32 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode32]);
} else if (i < bytes - 1) {
lookup_assumes(BINARY_TABLE_ID, [0, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], carry[i] + 2*op_is_min_max + 4*result_is_a]);
j++;
} else {
lookup_assumes(BINARY_TABLE_ID, [mode64, m_op_or_ext, free_in_a_or_c[j], free_in_b_or_zero[j], carry[i-1], free_in_c[i], cout64 + 2*op_is_min_max + 4*result_is_a + 8*use_last_carry_mode64]);
j++;
}
}

// Constraints to make sure that this component is called from the main component
Expand Down Expand Up @@ -164,5 +184,5 @@ airtemplate Binary(const int N = 2**21, const int operation_bus_id = BINARY_ID)

col witness multiplicity;
col witness main_step;
lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, (1-op_is_min_max)*cout], multiplicity);
lookup_proves(OPERATION_BUS_ID, [main_step, op, ...a, ...b, ...c, cout - result_is_a], multiplicity);
}
35 changes: 32 additions & 3 deletions state-machines/binary/src/binary_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use zisk_pil::*;

use crate::{BinaryBasicTableOp, BinaryBasicTableSM};

const BYTES: usize = 8;
const HALF_BYTES: usize = BYTES / 2;

pub struct BinaryBasicSM<F> {
wcm: Arc<WitnessManager<F>>,

Expand Down Expand Up @@ -158,6 +161,7 @@ impl<F: Field> BinaryBasicSM<F> {
let opcode = ZiskOp::try_from_code(operation.opcode).expect("Invalid ZiskOp opcode");
let mode32 = Self::opcode_is_32_bits(opcode);
row.mode32 = F::from_bool(mode32);
let mode64 = F::from_bool(!mode32);

// Set c_filtered
let c_filtered = if mode32 { c & 0xFFFFFFFF } else { c };
Expand Down Expand Up @@ -667,6 +671,33 @@ impl<F: Field> BinaryBasicSM<F> {
_ => panic!("BinaryBasicSM::process_slice() found invalid opcode={}", operation.opcode),
}

// Set cout
let cout32 = row.carry[HALF_BYTES - 1];
let cout64 = row.carry[BYTES - 1];
row.cout = mode64 * (cout64 - cout32) + cout32;

// Set result_is_a
row.result_is_a = row.op_is_min_max * row.cout;

// Set use_last_carry_mode32 and use_last_carry_mode64
row.use_last_carry_mode32 = F::from_bool(mode32) * row.use_last_carry;
row.use_last_carry_mode64 = mode64 * row.use_last_carry;

// Set micro opcode
row.m_op = F::from_canonical_u8(binary_basic_table_op as u8);

// Set m_op_or_ext
let ext_32_op = F::from_canonical_u8(BinaryBasicTableOp::Ext32 as u8);
row.m_op_or_ext = mode64 * (row.m_op - ext_32_op) + ext_32_op;

// Set free_in_a_or_c and free_in_b_or_zero
for i in 0..HALF_BYTES {
row.free_in_a_or_c[i] = mode64 *
(row.free_in_a[i + HALF_BYTES] - row.free_in_c[HALF_BYTES - 1]) +
row.free_in_c[HALF_BYTES - 1];
row.free_in_b_or_zero[i] = mode64 * row.free_in_b[i + HALF_BYTES];
}

if row.use_last_carry == F::one() {
// Set first and last elements
row.free_in_c[7] = row.free_in_c[0];
Expand All @@ -676,9 +707,6 @@ impl<F: Field> BinaryBasicSM<F> {
// TODO: Find duplicates of this trace and reuse them by increasing their multiplicity.
row.multiplicity = F::one();

// Set micro opcode
row.m_op = F::from_canonical_u8(binary_basic_table_op as u8);

// Return
row
}
Expand Down Expand Up @@ -732,6 +760,7 @@ impl<F: Field> BinaryBasicSM<F> {
timer_start_trace!(BINARY_PADDING);
let padding_row = BinaryRow::<F> {
m_op: F::from_canonical_u8(0x20),
m_op_or_ext: F::from_canonical_u8(0x20),
multiplicity: F::zero(),
main_step: F::zero(), /* TODO: remove, since main_step is just for
* debugging */
Expand Down

0 comments on commit 85ff00c

Please sign in to comment.