From c5b09354672b7d7b414ed9fae6321dea3d688b5f Mon Sep 17 00:00:00 2001 From: "Mariano A. Nicolini" Date: Fri, 10 Nov 2023 12:49:34 -0300 Subject: [PATCH] refactor: add step abstraction to trace (#662) * Start with multi row steps on trace * Add table and step view functionality * Save work in progress * Remove unused feature from cairo-vm * Add comment in Cargo.toml * Fix some AIRs with the refactor * Change Frame to Table struct in stark proof for OODS * Further refactor in Table APIs * Apply clippy suggestion * Fix trace oods evaluations name in proof struct * Save work in progress * Fix unhandled merge conflict * Fix some Cairo constraints * Fix permutation argument Cairo constraints * Fix all Cairo constraints * Fix all compilation * Fix all tests * Apply clippy suggestion * Remove argument in empty() function * revert a mistaken change * Remove unused code * Use step size constant in validate_trace function * Remove unused method and remove hardcoded 1 in prover * Remove hardcoded value in trace construction * Add documentation to code * Add more documentation * Change hardcoded step size * Remove useless test --- provers/cairo/src/air.rs | 473 +++++++++--------- provers/cairo/src/execution_trace.rs | 11 +- provers/cairo/src/tests/integration_tests.rs | 4 +- provers/stark/src/debug.rs | 6 +- provers/stark/src/examples/dummy_air.rs | 19 +- .../src/examples/fibonacci_2_cols_shifted.rs | 18 +- .../stark/src/examples/fibonacci_2_columns.rs | 17 +- provers/stark/src/examples/fibonacci_rap.rs | 27 +- provers/stark/src/examples/quadratic_air.rs | 13 +- .../stark/src/examples/simple_fibonacci.rs | 16 +- provers/stark/src/frame.rs | 55 +- provers/stark/src/prover.rs | 14 +- provers/stark/src/table.rs | 84 +++- provers/stark/src/trace.rs | 103 ++-- provers/stark/src/traits.rs | 2 + provers/stark/src/verifier.rs | 2 +- 16 files changed, 504 insertions(+), 360 deletions(-) diff --git a/provers/cairo/src/air.rs b/provers/cairo/src/air.rs index 858d44f5c..77d2816fb 100644 --- a/provers/cairo/src/air.rs +++ b/provers/cairo/src/air.rs @@ -14,7 +14,7 @@ use stark_platinum_prover::{ frame::Frame, proof::{options::ProofOptions, stark::StarkProof}, prover::{IsStarkProver, Prover, ProvingError}, - trace::TraceTable, + trace::{StepView, TraceTable}, traits::AIR, transcript::{IsStarkTranscript, StoneProverTranscript}, verifier::{IsStarkVerifier, Verifier}, @@ -623,6 +623,8 @@ impl AIR for CairoAIR { type RAPChallenges = CairoRAPChallenges; type PublicInputs = PublicInputs; + const STEP_SIZE: usize = 1; + /// Creates a new CairoAIR from proof_options /// /// # Arguments @@ -786,7 +788,10 @@ impl AIR for CairoAIR { let aux_table = Table::new(aux_data, self.number_auxiliary_rap_columns()); - TraceTable { table: aux_table } + TraceTable { + table: aux_table, + step_size: Self::STEP_SIZE, + } } fn build_rap_challenges( @@ -937,10 +942,14 @@ impl AIR for CairoAIR { /// From the Cairo whitepaper, section 9.10 fn compute_instr_constraints(constraints: &mut [Felt252], frame: &Frame) { // These constraints are only applied over elements of the same row. - let curr = frame.get_row(0); + let curr = frame.get_evaluation_step(0); + + let flags: Vec<&Felt252> = (0..16) + .map(|col_idx| curr.get_evaluation_element(0, col_idx)) + .collect(); // Bit constraints - for (i, flag) in curr[0..16].iter().enumerate() { + for (i, flag) in flags.clone().into_iter().enumerate() { constraints[i] = match i { 0..=14 => flag * (flag - Felt252::one()), 15 => *flag, @@ -955,92 +964,135 @@ fn compute_instr_constraints(constraints: &mut [Felt252], frame: &Frame) { // These constraints are only applied over elements of the same row. - let curr = frame.get_row(0); + let curr = frame.get_evaluation_step(0); + + let ap = curr.get_evaluation_element(0, FRAME_AP); + let fp = curr.get_evaluation_element(0, FRAME_FP); + let pc = curr.get_evaluation_element(0, FRAME_PC); + + let dst_fp = curr.get_evaluation_element(0, F_DST_FP); + let off_dst = curr.get_evaluation_element(0, OFF_DST); + let dst_addr = curr.get_evaluation_element(0, FRAME_DST_ADDR); - let ap = &curr[FRAME_AP]; - let fp = &curr[FRAME_FP]; - let pc = &curr[FRAME_PC]; + let op0_fp = curr.get_evaluation_element(0, F_OP_0_FP); + let off_op0 = curr.get_evaluation_element(0, OFF_OP0); + let op0_addr = curr.get_evaluation_element(0, FRAME_OP0_ADDR); + + let op1_val = curr.get_evaluation_element(0, F_OP_1_VAL); + let op1_ap = curr.get_evaluation_element(0, F_OP_1_AP); + let op1_fp = curr.get_evaluation_element(0, F_OP_1_FP); + let op0 = curr.get_evaluation_element(0, FRAME_OP0); + let off_op1 = curr.get_evaluation_element(0, OFF_OP1); + let op1_addr = curr.get_evaluation_element(0, FRAME_OP1_ADDR); let one = Felt252::one(); let b15 = Felt252::from(2).pow(15u32); - constraints[DST_ADDR] = - curr[F_DST_FP] * fp + (one - curr[F_DST_FP]) * ap + (curr[OFF_DST] - b15) - - curr[FRAME_DST_ADDR]; + constraints[DST_ADDR] = dst_fp * fp + (one - dst_fp) * ap + (off_dst - b15) - dst_addr; - constraints[OP0_ADDR] = - curr[F_OP_0_FP] * fp + (one - curr[F_OP_0_FP]) * ap + (curr[OFF_OP0] - b15) - - curr[FRAME_OP0_ADDR]; + constraints[OP0_ADDR] = op0_fp * fp + (one - op0_fp) * ap + (off_op0 - b15) - op0_addr; - constraints[OP1_ADDR] = curr[F_OP_1_VAL] * pc - + curr[F_OP_1_AP] * ap - + curr[F_OP_1_FP] * fp - + (one - curr[F_OP_1_VAL] - curr[F_OP_1_AP] - curr[F_OP_1_FP]) * curr[FRAME_OP0] - + (curr[OFF_OP1] - b15) - - curr[FRAME_OP1_ADDR]; + constraints[OP1_ADDR] = op1_val * pc + + op1_ap * ap + + op1_fp * fp + + (one - op1_val - op1_ap - op1_fp) * op0 + + (off_op1 - b15) + - op1_addr; } fn compute_register_constraints(constraints: &mut [Felt252], frame: &Frame) { - let curr = frame.get_row(0); - let next = frame.get_row(1); + let curr = frame.get_evaluation_step(0); + let next = frame.get_evaluation_step(1); let one = Felt252::one(); let two = Felt252::from(2); + let ap = curr.get_evaluation_element(0, FRAME_AP); + let next_ap = next.get_evaluation_element(0, FRAME_AP); + let ap_add = curr.get_evaluation_element(0, F_AP_ADD); + let res = curr.get_evaluation_element(0, FRAME_RES); + let ap_one = curr.get_evaluation_element(0, F_AP_ONE); + let opc_call = curr.get_evaluation_element(0, F_OPC_CALL); + + let opc_ret = curr.get_evaluation_element(0, F_OPC_RET); + let dst = curr.get_evaluation_element(0, FRAME_DST); + let fp = curr.get_evaluation_element(0, FRAME_FP); + let next_fp = next.get_evaluation_element(0, FRAME_FP); + + let t1 = curr.get_evaluation_element(0, FRAME_T1); + let pc_jnz = curr.get_evaluation_element(0, F_PC_JNZ); + let pc = curr.get_evaluation_element(0, FRAME_PC); + let next_pc = next.get_evaluation_element(0, FRAME_PC); + + let t0 = curr.get_evaluation_element(0, FRAME_T0); + let op1 = curr.get_evaluation_element(0, FRAME_OP1); + let pc_abs = curr.get_evaluation_element(0, F_PC_ABS); + let pc_rel = curr.get_evaluation_element(0, F_PC_REL); + // ap and fp constraints - constraints[NEXT_AP] = - curr[FRAME_AP] + curr[F_AP_ADD] * curr[FRAME_RES] + curr[F_AP_ONE] + curr[F_OPC_CALL] * two - - next[FRAME_AP]; + constraints[NEXT_AP] = ap + ap_add * res + ap_one + opc_call * two - next_ap; - constraints[NEXT_FP] = curr[F_OPC_RET] * curr[FRAME_DST] - + curr[F_OPC_CALL] * (curr[FRAME_AP] + two) - + (one - curr[F_OPC_RET] - curr[F_OPC_CALL]) * curr[FRAME_FP] - - next[FRAME_FP]; + constraints[NEXT_FP] = + opc_ret * dst + opc_call * (ap + two) + (one - opc_ret - opc_call) * fp - next_fp; // pc constraints - constraints[NEXT_PC_1] = (curr[FRAME_T1] - curr[F_PC_JNZ]) - * (next[FRAME_PC] - (curr[FRAME_PC] + frame_inst_size(curr))); - - constraints[NEXT_PC_2] = curr[FRAME_T0] * (next[FRAME_PC] - (curr[FRAME_PC] + curr[FRAME_OP1])) - + (one - curr[F_PC_JNZ]) * next[FRAME_PC] - - ((one - curr[F_PC_ABS] - curr[F_PC_REL] - curr[F_PC_JNZ]) - * (curr[FRAME_PC] + frame_inst_size(curr)) - + curr[F_PC_ABS] * curr[FRAME_RES] - + curr[F_PC_REL] * (curr[FRAME_PC] + curr[FRAME_RES])); - - constraints[T0] = curr[F_PC_JNZ] * curr[FRAME_DST] - curr[FRAME_T0]; - constraints[T1] = curr[FRAME_T0] * curr[FRAME_RES] - curr[FRAME_T1]; + constraints[NEXT_PC_1] = (t1 - pc_jnz) * (next_pc - (pc + frame_inst_size(curr))); + + constraints[NEXT_PC_2] = t0 * (next_pc - (pc + op1)) + (one - pc_jnz) * next_pc + - ((one - pc_abs - pc_rel - pc_jnz) * (pc + frame_inst_size(curr)) + + pc_abs * res + + pc_rel * (pc + res)); + + constraints[T0] = pc_jnz * dst - t0; + constraints[T1] = t0 * res - t1; } fn compute_opcode_constraints(constraints: &mut [Felt252], frame: &Frame) { - let curr = frame.get_row(0); + let curr = frame.get_evaluation_step(0); let one = Felt252::one(); - constraints[MUL_1] = curr[FRAME_MUL] - curr[FRAME_OP0] * curr[FRAME_OP1]; + let mul = curr.get_evaluation_element(0, FRAME_MUL); + let op0 = curr.get_evaluation_element(0, FRAME_OP0); + let op1 = curr.get_evaluation_element(0, FRAME_OP1); - constraints[MUL_2] = curr[F_RES_ADD] * (curr[FRAME_OP0] + curr[FRAME_OP1]) - + curr[F_RES_MUL] * curr[FRAME_MUL] - + (one - curr[F_RES_ADD] - curr[F_RES_MUL] - curr[F_PC_JNZ]) * curr[FRAME_OP1] - - (one - curr[F_PC_JNZ]) * curr[FRAME_RES]; + let res_add = curr.get_evaluation_element(0, F_RES_ADD); + let res_mul = curr.get_evaluation_element(0, F_RES_MUL); + let pc_jnz = curr.get_evaluation_element(0, F_PC_JNZ); + let res = curr.get_evaluation_element(0, FRAME_RES); - constraints[CALL_1] = curr[F_OPC_CALL] * (curr[FRAME_DST] - curr[FRAME_FP]); + let opc_call = curr.get_evaluation_element(0, F_OPC_CALL); + let dst = curr.get_evaluation_element(0, FRAME_DST); + let fp = curr.get_evaluation_element(0, FRAME_FP); + let pc = curr.get_evaluation_element(0, FRAME_PC); - constraints[CALL_2] = - curr[F_OPC_CALL] * (curr[FRAME_OP0] - (curr[FRAME_PC] + frame_inst_size(curr))); + let opc_aeq = curr.get_evaluation_element(0, F_OPC_AEQ); - constraints[ASSERT_EQ] = curr[F_OPC_AEQ] * (curr[FRAME_DST] - curr[FRAME_RES]); + constraints[MUL_1] = mul - op0 * op1; + + constraints[MUL_2] = + res_add * (op0 + op1) + res_mul * mul + (one - res_add - res_mul - pc_jnz) * op1 + - (one - pc_jnz) * res; + + constraints[CALL_1] = opc_call * (dst - fp); + + constraints[CALL_2] = opc_call * (op0 - (pc + frame_inst_size(curr))); + + constraints[ASSERT_EQ] = opc_aeq * (dst - res); } fn memory_is_increasing( @@ -1048,69 +1100,55 @@ fn memory_is_increasing( frame: &Frame, builtin_offset: usize, ) { - let curr = frame.get_row(0); - let next = frame.get_row(1); + let curr = frame.get_evaluation_step(0); + let next = frame.get_evaluation_step(1); let one = FieldElement::one(); - constraints[MEMORY_INCREASING_0] = (curr[MEMORY_ADDR_SORTED_0 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_1 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_1 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_0 - builtin_offset] - - one); - - constraints[MEMORY_INCREASING_1] = (curr[MEMORY_ADDR_SORTED_1 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_2 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_2 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_1 - builtin_offset] - - one); - - constraints[MEMORY_INCREASING_2] = (curr[MEMORY_ADDR_SORTED_2 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_3 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_3 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_2 - builtin_offset] - - one); - - constraints[MEMORY_INCREASING_3] = (curr[MEMORY_ADDR_SORTED_3 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_4 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_4 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_3 - builtin_offset] - - one); - - constraints[MEMORY_INCREASING_4] = (curr[MEMORY_ADDR_SORTED_4 - builtin_offset] - - next[MEMORY_ADDR_SORTED_0 - builtin_offset]) - * (next[MEMORY_ADDR_SORTED_0 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_4 - builtin_offset] - - one); - - constraints[MEMORY_CONSISTENCY_0] = (curr[MEMORY_VALUES_SORTED_0 - builtin_offset] - - curr[MEMORY_VALUES_SORTED_1 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_1 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_0 - builtin_offset] - - one); - - constraints[MEMORY_CONSISTENCY_1] = (curr[MEMORY_VALUES_SORTED_1 - builtin_offset] - - curr[MEMORY_VALUES_SORTED_2 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_2 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_1 - builtin_offset] - - one); - - constraints[MEMORY_CONSISTENCY_2] = (curr[MEMORY_VALUES_SORTED_2 - builtin_offset] - - curr[MEMORY_VALUES_SORTED_3 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_3 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_2 - builtin_offset] - - one); - - constraints[MEMORY_CONSISTENCY_3] = (curr[MEMORY_VALUES_SORTED_3 - builtin_offset] - - curr[MEMORY_VALUES_SORTED_4 - builtin_offset]) - * (curr[MEMORY_ADDR_SORTED_4 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_3 - builtin_offset] - - one); - - constraints[MEMORY_CONSISTENCY_4] = (curr[MEMORY_VALUES_SORTED_4 - builtin_offset] - - next[MEMORY_VALUES_SORTED_0 - builtin_offset]) - * (next[MEMORY_ADDR_SORTED_0 - builtin_offset] - - curr[MEMORY_ADDR_SORTED_4 - builtin_offset] - - one); + let mem_addr_sorted_0 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_0 - builtin_offset); + let mem_addr_sorted_1 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_1 - builtin_offset); + let mem_addr_sorted_2 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_2 - builtin_offset); + let mem_addr_sorted_3 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_3 - builtin_offset); + let mem_addr_sorted_4 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_4 - builtin_offset); + let next_mem_addr_sorted_0 = + next.get_evaluation_element(0, MEMORY_ADDR_SORTED_0 - builtin_offset); + + let mem_val_sorted_0 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_0 - builtin_offset); + let mem_val_sorted_1 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_1 - builtin_offset); + let mem_val_sorted_2 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_2 - builtin_offset); + let mem_val_sorted_3 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_3 - builtin_offset); + let mem_val_sorted_4 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_4 - builtin_offset); + let next_mem_val_sorted_0 = + next.get_evaluation_element(0, MEMORY_VALUES_SORTED_0 - builtin_offset); + + constraints[MEMORY_INCREASING_0] = + (mem_addr_sorted_0 - mem_addr_sorted_1) * (mem_addr_sorted_1 - mem_addr_sorted_0 - one); + + constraints[MEMORY_INCREASING_1] = + (mem_addr_sorted_1 - mem_addr_sorted_2) * (mem_addr_sorted_2 - mem_addr_sorted_1 - one); + + constraints[MEMORY_INCREASING_2] = + (mem_addr_sorted_2 - mem_addr_sorted_3) * (mem_addr_sorted_3 - mem_addr_sorted_2 - one); + + constraints[MEMORY_INCREASING_3] = + (mem_addr_sorted_3 - mem_addr_sorted_4) * (mem_addr_sorted_4 - mem_addr_sorted_3 - one); + + constraints[MEMORY_INCREASING_4] = (mem_addr_sorted_4 - next_mem_addr_sorted_0) + * (next_mem_addr_sorted_0 - mem_addr_sorted_4 - one); + + constraints[MEMORY_CONSISTENCY_0] = + (mem_val_sorted_0 - mem_val_sorted_1) * (mem_addr_sorted_1 - mem_addr_sorted_0 - one); + + constraints[MEMORY_CONSISTENCY_1] = + (mem_val_sorted_1 - mem_val_sorted_2) * (mem_addr_sorted_2 - mem_addr_sorted_1 - one); + + constraints[MEMORY_CONSISTENCY_2] = + (mem_val_sorted_2 - mem_val_sorted_3) * (mem_addr_sorted_3 - mem_addr_sorted_2 - one); + + constraints[MEMORY_CONSISTENCY_3] = + (mem_val_sorted_3 - mem_val_sorted_4) * (mem_addr_sorted_4 - mem_addr_sorted_3 - one); + + constraints[MEMORY_CONSISTENCY_4] = (mem_val_sorted_4 - next_mem_val_sorted_0) + * (next_mem_addr_sorted_0 - mem_addr_sorted_4 - one); } fn permutation_argument( @@ -1119,41 +1157,42 @@ fn permutation_argument( rap_challenges: &CairoRAPChallenges, builtin_offset: usize, ) { - let curr = frame.get_row(0); - let next = frame.get_row(1); + let curr = frame.get_evaluation_step(0); + let next = frame.get_evaluation_step(1); + let z = &rap_challenges.z_memory; let alpha = &rap_challenges.alpha_memory; - let p0 = &curr[PERMUTATION_ARGUMENT_COL_0 - builtin_offset]; - let p0_next = &next[PERMUTATION_ARGUMENT_COL_0 - builtin_offset]; - let p1 = &curr[PERMUTATION_ARGUMENT_COL_1 - builtin_offset]; - let p2 = &curr[PERMUTATION_ARGUMENT_COL_2 - builtin_offset]; - let p3 = &curr[PERMUTATION_ARGUMENT_COL_3 - builtin_offset]; - let p4 = &curr[PERMUTATION_ARGUMENT_COL_4 - builtin_offset]; - - let ap0_next = &next[MEMORY_ADDR_SORTED_0 - builtin_offset]; - let ap1 = &curr[MEMORY_ADDR_SORTED_1 - builtin_offset]; - let ap2 = &curr[MEMORY_ADDR_SORTED_2 - builtin_offset]; - let ap3 = &curr[MEMORY_ADDR_SORTED_3 - builtin_offset]; - let ap4 = &curr[MEMORY_ADDR_SORTED_4 - builtin_offset]; - - let vp0_next = &next[MEMORY_VALUES_SORTED_0 - builtin_offset]; - let vp1 = &curr[MEMORY_VALUES_SORTED_1 - builtin_offset]; - let vp2 = &curr[MEMORY_VALUES_SORTED_2 - builtin_offset]; - let vp3 = &curr[MEMORY_VALUES_SORTED_3 - builtin_offset]; - let vp4 = &curr[MEMORY_VALUES_SORTED_4 - builtin_offset]; - - let a0_next = &next[FRAME_PC]; - let a1 = &curr[FRAME_DST_ADDR]; - let a2 = &curr[FRAME_OP0_ADDR]; - let a3 = &curr[FRAME_OP1_ADDR]; - let a4 = &curr[EXTRA_ADDR]; - - let v0_next = &next[FRAME_INST]; - let v1 = &curr[FRAME_DST]; - let v2 = &curr[FRAME_OP0]; - let v3 = &curr[FRAME_OP1]; - let v4 = &curr[EXTRA_VAL]; + let p0 = curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_0 - builtin_offset); + let next_p0 = next.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_0 - builtin_offset); + let p1 = curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_1 - builtin_offset); + let p2 = curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_2 - builtin_offset); + let p3 = curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_3 - builtin_offset); + let p4 = curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_COL_4 - builtin_offset); + + let next_ap0 = next.get_evaluation_element(0, MEMORY_ADDR_SORTED_0 - builtin_offset); + let ap1 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_1 - builtin_offset); + let ap2 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_2 - builtin_offset); + let ap3 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_3 - builtin_offset); + let ap4 = curr.get_evaluation_element(0, MEMORY_ADDR_SORTED_4 - builtin_offset); + + let next_vp0 = next.get_evaluation_element(0, MEMORY_VALUES_SORTED_0 - builtin_offset); + let vp1 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_1 - builtin_offset); + let vp2 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_2 - builtin_offset); + let vp3 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_3 - builtin_offset); + let vp4 = curr.get_evaluation_element(0, MEMORY_VALUES_SORTED_4 - builtin_offset); + + let next_a0 = next.get_evaluation_element(0, FRAME_PC); + let a1 = curr.get_evaluation_element(0, FRAME_DST_ADDR); + let a2 = curr.get_evaluation_element(0, FRAME_OP0_ADDR); + let a3 = curr.get_evaluation_element(0, FRAME_OP1_ADDR); + let a4 = curr.get_evaluation_element(0, EXTRA_ADDR); + + let next_v0 = next.get_evaluation_element(0, FRAME_INST); + let v1 = curr.get_evaluation_element(0, FRAME_DST); + let v2 = curr.get_evaluation_element(0, FRAME_OP0); + let v3 = curr.get_evaluation_element(0, FRAME_OP1); + let v4 = curr.get_evaluation_element(0, EXTRA_VAL); constraints[PERMUTATION_ARGUMENT_0] = (z - (ap1 + alpha * vp1)) * p1 - (z - (a1 + alpha * v1)) * p0; @@ -1164,7 +1203,7 @@ fn permutation_argument( constraints[PERMUTATION_ARGUMENT_3] = (z - (ap4 + alpha * vp4)) * p4 - (z - (a4 + alpha * v4)) * p3; constraints[PERMUTATION_ARGUMENT_4] = - (z - (ap0_next + alpha * vp0_next)) * p0_next - (z - (a0_next + alpha * v0_next)) * p4; + (z - (next_ap0 + alpha * next_vp0)) * next_p0 - (z - (next_a0 + alpha * next_v0)) * p4; } fn permutation_argument_range_check( @@ -1173,77 +1212,83 @@ fn permutation_argument_range_check( rap_challenges: &CairoRAPChallenges, builtin_offset: usize, ) { - let curr = frame.get_row(0); - let next = frame.get_row(1); + let curr = frame.get_evaluation_step(0); + let next = frame.get_evaluation_step(1); let one = FieldElement::one(); let z = &rap_challenges.z_range_check; - constraints[RANGE_CHECK_INCREASING_0] = (curr[RANGE_CHECK_COL_1 - builtin_offset] - - curr[RANGE_CHECK_COL_2 - builtin_offset]) - * (curr[RANGE_CHECK_COL_2 - builtin_offset] - - curr[RANGE_CHECK_COL_1 - builtin_offset] - - one); - constraints[RANGE_CHECK_INCREASING_1] = (curr[RANGE_CHECK_COL_2 - builtin_offset] - - curr[RANGE_CHECK_COL_3 - builtin_offset]) - * (curr[RANGE_CHECK_COL_3 - builtin_offset] - - curr[RANGE_CHECK_COL_2 - builtin_offset] - - one); - constraints[RANGE_CHECK_INCREASING_2] = (curr[RANGE_CHECK_COL_3 - builtin_offset] - - curr[RANGE_CHECK_COL_4 - builtin_offset]) - * (curr[RANGE_CHECK_COL_4 - builtin_offset] - - curr[RANGE_CHECK_COL_3 - builtin_offset] - - one); - constraints[RANGE_CHECK_INCREASING_3] = (curr[RANGE_CHECK_COL_4 - builtin_offset] - - next[RANGE_CHECK_COL_1 - builtin_offset]) - * (next[RANGE_CHECK_COL_1 - builtin_offset] - - curr[RANGE_CHECK_COL_4 - builtin_offset] - - one); - - let p0 = curr[PERMUTATION_ARGUMENT_RANGE_CHECK_COL_1 - builtin_offset]; - let p0_next = next[PERMUTATION_ARGUMENT_RANGE_CHECK_COL_1 - builtin_offset]; - let p1 = curr[PERMUTATION_ARGUMENT_RANGE_CHECK_COL_2 - builtin_offset]; - let p2 = curr[PERMUTATION_ARGUMENT_RANGE_CHECK_COL_3 - builtin_offset]; - let p3 = curr[PERMUTATION_ARGUMENT_RANGE_CHECK_COL_4 - builtin_offset]; - - let ap0_next = next[RANGE_CHECK_COL_1 - builtin_offset]; - let ap1 = curr[RANGE_CHECK_COL_2 - builtin_offset]; - let ap2 = curr[RANGE_CHECK_COL_3 - builtin_offset]; - let ap3 = curr[RANGE_CHECK_COL_4 - builtin_offset]; - - let a0_next = next[OFF_DST]; - let a1 = curr[OFF_OP0]; - let a2 = curr[OFF_OP1]; - let a3 = curr[RC_HOLES]; + let rc_col_1 = curr.get_evaluation_element(0, RANGE_CHECK_COL_1 - builtin_offset); + let rc_col_2 = curr.get_evaluation_element(0, RANGE_CHECK_COL_2 - builtin_offset); + let rc_col_3 = curr.get_evaluation_element(0, RANGE_CHECK_COL_3 - builtin_offset); + let rc_col_4 = curr.get_evaluation_element(0, RANGE_CHECK_COL_4 - builtin_offset); + let next_rc_col_1 = next.get_evaluation_element(0, RANGE_CHECK_COL_1 - builtin_offset); + + constraints[RANGE_CHECK_INCREASING_0] = (rc_col_1 - rc_col_2) * (rc_col_2 - rc_col_1 - one); + constraints[RANGE_CHECK_INCREASING_1] = (rc_col_2 - rc_col_3) * (rc_col_3 - rc_col_2 - one); + constraints[RANGE_CHECK_INCREASING_2] = (rc_col_3 - rc_col_4) * (rc_col_4 - rc_col_3 - one); + constraints[RANGE_CHECK_INCREASING_3] = + (rc_col_4 - next_rc_col_1) * (next_rc_col_1 - rc_col_4 - one); + + let p0 = + curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_RANGE_CHECK_COL_1 - builtin_offset); + let next_p0 = + next.get_evaluation_element(0, PERMUTATION_ARGUMENT_RANGE_CHECK_COL_1 - builtin_offset); + let p1 = + curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_RANGE_CHECK_COL_2 - builtin_offset); + let p2 = + curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_RANGE_CHECK_COL_3 - builtin_offset); + let p3 = + curr.get_evaluation_element(0, PERMUTATION_ARGUMENT_RANGE_CHECK_COL_4 - builtin_offset); + + let next_ap0 = next.get_evaluation_element(0, RANGE_CHECK_COL_1 - builtin_offset); + let ap1 = curr.get_evaluation_element(0, RANGE_CHECK_COL_2 - builtin_offset); + let ap2 = curr.get_evaluation_element(0, RANGE_CHECK_COL_3 - builtin_offset); + let ap3 = curr.get_evaluation_element(0, RANGE_CHECK_COL_4 - builtin_offset); + + let a0_next = next.get_evaluation_element(0, OFF_DST); + let a1 = curr.get_evaluation_element(0, OFF_OP0); + let a2 = curr.get_evaluation_element(0, OFF_OP1); + let a3 = curr.get_evaluation_element(0, RC_HOLES); constraints[RANGE_CHECK_0] = (z - ap1) * p1 - (z - a1) * p0; constraints[RANGE_CHECK_1] = (z - ap2) * p2 - (z - a2) * p1; constraints[RANGE_CHECK_2] = (z - ap3) * p3 - (z - a3) * p2; - constraints[RANGE_CHECK_3] = (z - ap0_next) * p0_next - (z - a0_next) * p3; + constraints[RANGE_CHECK_3] = (z - next_ap0) * next_p0 - (z - a0_next) * p3; } -fn frame_inst_size(frame_row: &[Felt252]) -> Felt252 { - frame_row[F_OP_1_VAL] + Felt252::one() +fn frame_inst_size(step: &StepView) -> Felt252 { + let op1_val = step.get_evaluation_element(0, F_OP_1_VAL); + op1_val + Felt252::one() } fn range_check_builtin( constraints: &mut [FieldElement], frame: &Frame, ) { - let curr = frame.get_row(0); + let curr = frame.get_evaluation_step(0); constraints[RANGE_CHECK_BUILTIN] = evaluate_range_check_builtin_constraint(curr) } -fn evaluate_range_check_builtin_constraint(curr: &[Felt252]) -> Felt252 { - curr[RC_0] - + curr[RC_1] * Felt252::from_hex("10000").unwrap() - + curr[RC_2] * Felt252::from_hex("100000000").unwrap() - + curr[RC_3] * Felt252::from_hex("1000000000000").unwrap() - + curr[RC_4] * Felt252::from_hex("10000000000000000").unwrap() - + curr[RC_5] * Felt252::from_hex("100000000000000000000").unwrap() - + curr[RC_6] * Felt252::from_hex("1000000000000000000000000").unwrap() - + curr[RC_7] * Felt252::from_hex("10000000000000000000000000000").unwrap() - - curr[RC_VALUE] +fn evaluate_range_check_builtin_constraint(step: &StepView) -> Felt252 { + let rc_0 = step.get_evaluation_element(0, RC_0); + let rc_1 = step.get_evaluation_element(0, RC_1); + let rc_2 = step.get_evaluation_element(0, RC_2); + let rc_3 = step.get_evaluation_element(0, RC_3); + let rc_4 = step.get_evaluation_element(0, RC_4); + let rc_5 = step.get_evaluation_element(0, RC_5); + let rc_6 = step.get_evaluation_element(0, RC_6); + let rc_7 = step.get_evaluation_element(0, RC_7); + let rc_value = step.get_evaluation_element(0, RC_VALUE); + + rc_0 + rc_1 * Felt252::from_hex_unchecked("10000") + + rc_2 * Felt252::from_hex_unchecked("100000000") + + rc_3 * Felt252::from_hex_unchecked("1000000000000") + + rc_4 * Felt252::from_hex_unchecked("10000000000000000") + + rc_5 * Felt252::from_hex_unchecked("100000000000000000000") + + rc_6 * Felt252::from_hex_unchecked("1000000000000000000000000") + + rc_7 * Felt252::from_hex_unchecked("10000000000000000000000000000") + - rc_value } /// Wrapper function for generating Cairo proofs without the need to specify @@ -1284,30 +1329,6 @@ mod test { use super::*; use lambdaworks_math::field::element::FieldElement; - #[test] - fn range_check_eval_works() { - let mut row: Vec = Vec::new(); - - for _ in 0..61 { - row.push(Felt252::zero()); - } - - row[super::RC_0] = Felt252::one(); - row[super::RC_1] = Felt252::one(); - row[super::RC_2] = Felt252::one(); - row[super::RC_3] = Felt252::one(); - row[super::RC_4] = Felt252::one(); - row[super::RC_5] = Felt252::one(); - row[super::RC_6] = Felt252::one(); - row[super::RC_7] = Felt252::one(); - - row[super::RC_VALUE] = Felt252::from_hex("00010001000100010001000100010001").unwrap(); - assert_eq!( - evaluate_range_check_builtin_constraint(&row), - Felt252::zero() - ); - } - #[test] fn test_build_auxiliary_trace_add_program_in_public_input_section_works() { let dummy_public_input = PublicInputs { diff --git a/provers/cairo/src/execution_trace.rs b/provers/cairo/src/execution_trace.rs index 66058d28d..b6fd22b9e 100644 --- a/provers/cairo/src/execution_trace.rs +++ b/provers/cairo/src/execution_trace.rs @@ -292,7 +292,7 @@ pub fn build_cairo_execution_trace( add_rc_builtin_columns(&mut trace_cols, range_check_builtin_range.clone(), memory); } - TraceTable::from_columns(trace_cols) + TraceTable::from_columns(trace_cols, 1) } // Build range-check builtin columns: rc_0, rc_1, ... , rc_7, rc_value @@ -620,7 +620,7 @@ mod test { FieldElement::from(7), FieldElement::from(7), ]; - let table = TraceTable::::from_columns(columns); + let table = TraceTable::::from_columns(columns, 1); let (col, rc_min, rc_max) = get_rc_holes(&table, &[0, 1, 2]); assert_eq!(col, expected_col); @@ -635,7 +635,10 @@ mod test { let data = row.repeat(8); let table = Table::new(data, 36); - let mut main_trace = TraceTable:: { table }; + let mut main_trace = TraceTable:: { + table, + step_size: 1, + }; let rc_holes = vec![ Felt252::from(1), @@ -736,7 +739,7 @@ mod test { trace_cols[FRAME_DST_ADDR][1] = Felt252::from(9); trace_cols[FRAME_OP0_ADDR][1] = Felt252::from(10); trace_cols[FRAME_OP1_ADDR][1] = Felt252::from(11); - let mut trace = TraceTable::from_columns(trace_cols); + let mut trace = TraceTable::from_columns(trace_cols, 1); let memory_holes = vec![Felt252::from(4), Felt252::from(7), Felt252::from(8)]; fill_memory_holes(&mut trace, &memory_holes); diff --git a/provers/cairo/src/tests/integration_tests.rs b/provers/cairo/src/tests/integration_tests.rs index 87d499a39..667f186a2 100644 --- a/provers/cairo/src/tests/integration_tests.rs +++ b/provers/cairo/src/tests/integration_tests.rs @@ -175,7 +175,7 @@ fn test_verifier_rejects_proof_with_changed_range_check_value() { last_column[0] = malicious_rc_value; malicious_trace_columns[n_cols - 1] = last_column; - let malicious_trace = TraceTable::from_columns(malicious_trace_columns); + let malicious_trace = TraceTable::from_columns(malicious_trace_columns, 1); let proof = generate_cairo_proof(&malicious_trace, &pub_inputs, &proof_options).unwrap(); assert!(!verify_cairo_proof(&proof, &pub_inputs, &proof_options)); } @@ -243,7 +243,7 @@ fn test_verifier_rejects_proof_with_changed_output() { output_value_column[output_row_idx] = malicious_output_value; malicious_trace_columns[output_col_idx + 4] = output_value_column; - let malicious_trace = TraceTable::from_columns(malicious_trace_columns); + let malicious_trace = TraceTable::from_columns(malicious_trace_columns, 1); let proof = generate_cairo_proof(&malicious_trace, &pub_inputs, &proof_options).unwrap(); assert!(!verify_cairo_proof(&proof, &pub_inputs, &proof_options)); } diff --git a/provers/stark/src/debug.rs b/provers/stark/src/debug.rs index d95ca8dfe..a3fd6ee92 100644 --- a/provers/stark/src/debug.rs +++ b/provers/stark/src/debug.rs @@ -28,7 +28,7 @@ pub fn validate_trace>( }) .collect(); - let trace = TraceTable::from_columns(trace_columns); + let trace = TraceTable::from_columns(trace_columns, A::STEP_SIZE); // --------- VALIDATE BOUNDARY CONSTRAINTS ------------ air.boundary_constraints(rap_challenges) @@ -40,7 +40,7 @@ pub fn validate_trace>( let boundary_value = constraint.value.clone(); let trace_value = trace.get(step, col); - if boundary_value != trace_value { + if &boundary_value != trace_value { ret = false; error!("Boundary constraint inconsistency - Expected value {} in step {} and column {}, found: {}", boundary_value.representative(), step, col, trace_value.representative()); } @@ -57,7 +57,7 @@ pub fn validate_trace>( .collect(); // Iterate over trace and compute transitions - for step in 0..trace.n_rows() { + for step in 0..trace.num_steps() { let frame = Frame::read_from_trace(&trace, step, 1, &air.context().transition_offsets); let evaluations = air.compute_transition(&frame, rap_challenges); diff --git a/provers/stark/src/examples/dummy_air.rs b/provers/stark/src/examples/dummy_air.rs index 0b08b7c90..3a20c6f9a 100644 --- a/provers/stark/src/examples/dummy_air.rs +++ b/provers/stark/src/examples/dummy_air.rs @@ -24,6 +24,8 @@ impl AIR for DummyAIR { type RAPChallenges = (); type PublicInputs = (); + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, _pub_inputs: &Self::PublicInputs, @@ -63,13 +65,18 @@ impl AIR for DummyAIR { frame: &Frame, _rap_challenges: &Self::RAPChallenges, ) -> Vec> { - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); - let third_row = frame.get_row(2); + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + let third_step = frame.get_evaluation_step(2); + + let flag = first_step.get_evaluation_element(0, 0); + let a0 = first_step.get_evaluation_element(0, 1); + let a1 = second_step.get_evaluation_element(0, 1); + let a2 = third_step.get_evaluation_element(0, 1); - let f_constraint = first_row[0] * (first_row[0] - FieldElement::one()); + let f_constraint = flag * (flag - FieldElement::one()); - let fib_constraint = third_row[1] - second_row[1] - first_row[1]; + let fib_constraint = a2 - a1 - a0; vec![f_constraint, fib_constraint] } @@ -118,5 +125,5 @@ pub fn dummy_trace(trace_length: usize) -> TraceTable { ret.push(ret[i - 1].clone() + ret[i - 2].clone()); } - TraceTable::from_columns(vec![vec![FieldElement::::one(); trace_length], ret]) + TraceTable::from_columns(vec![vec![FieldElement::::one(); trace_length], ret], 1) } diff --git a/provers/stark/src/examples/fibonacci_2_cols_shifted.rs b/provers/stark/src/examples/fibonacci_2_cols_shifted.rs index e5482bd19..c5401dbb4 100644 --- a/provers/stark/src/examples/fibonacci_2_cols_shifted.rs +++ b/provers/stark/src/examples/fibonacci_2_cols_shifted.rs @@ -56,6 +56,8 @@ where type RAPChallenges = (); type PublicInputs = PublicInputs; + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, @@ -97,11 +99,17 @@ where frame: &Frame, _rap_challenges: &Self::RAPChallenges, ) -> Vec> { - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); + let first_row = frame.get_evaluation_step(0); + let second_row = frame.get_evaluation_step(1); + + let a0_0 = first_row.get_evaluation_element(0, 0); + let a0_1 = first_row.get_evaluation_element(0, 1); + + let a1_0 = second_row.get_evaluation_element(0, 0); + let a1_1 = second_row.get_evaluation_element(0, 1); - let first_transition = &second_row[0] - &first_row[1]; - let second_transition = &second_row[1] - &first_row[0] - &first_row[1]; + let first_transition = a1_0 - a0_1; + let second_transition = a1_1 - a0_0 - a0_1; vec![first_transition, second_transition] } @@ -156,7 +164,7 @@ pub fn compute_trace( col1.push(y.clone()); } - TraceTable::from_columns(vec![col0, col1]) + TraceTable::from_columns(vec![col0, col1], 1) } #[cfg(test)] diff --git a/provers/stark/src/examples/fibonacci_2_columns.rs b/provers/stark/src/examples/fibonacci_2_columns.rs index 968810743..ca129fec3 100644 --- a/provers/stark/src/examples/fibonacci_2_columns.rs +++ b/provers/stark/src/examples/fibonacci_2_columns.rs @@ -32,6 +32,8 @@ where type RAPChallenges = (); type PublicInputs = FibonacciPublicInputs; + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, @@ -73,14 +75,19 @@ where frame: &Frame, _rap_challenges: &Self::RAPChallenges, ) -> Vec> { - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); // constraints of Fibonacci sequence (2 terms per step): // s_{0, i+1} = s_{0, i} + s_{1, i} // s_{1, i+1} = s_{1, i} + s_{0, i+1} - let first_transition = &second_row[0] - &first_row[0] - &first_row[1]; - let second_transition = &second_row[1] - &first_row[1] - &second_row[0]; + let s0_0 = first_step.get_evaluation_element(0, 0); + let s0_1 = first_step.get_evaluation_element(0, 1); + let s1_0 = second_step.get_evaluation_element(0, 0); + let s1_1 = second_step.get_evaluation_element(0, 1); + + let first_transition = s1_0 - s0_0 - s0_1; + let second_transition = s1_1 - s0_1 - s1_0; vec![first_transition, second_transition] } @@ -132,5 +139,5 @@ pub fn compute_trace( ret2.push(new_val + ret2[i - 1].clone()); } - TraceTable::from_columns(vec![ret1, ret2]) + TraceTable::from_columns(vec![ret1, ret2], 1) } diff --git a/provers/stark/src/examples/fibonacci_rap.rs b/provers/stark/src/examples/fibonacci_rap.rs index 880241759..d646a866a 100644 --- a/provers/stark/src/examples/fibonacci_rap.rs +++ b/provers/stark/src/examples/fibonacci_rap.rs @@ -45,6 +45,8 @@ where type RAPChallenges = FieldElement; type PublicInputs = FibonacciRAPPublicInputs; + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, @@ -92,7 +94,7 @@ where aux_col.push(z_i * n_p_term.div(p_term)); } } - TraceTable::from_columns(vec![aux_col]) + TraceTable::from_columns(vec![aux_col], 1) } fn build_rap_challenges( @@ -112,19 +114,22 @@ where gamma: &Self::RAPChallenges, ) -> Vec> { // Main constraints - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); - let third_row = frame.get_row(2); + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + let third_step = frame.get_evaluation_step(2); + + let a0 = first_step.get_evaluation_element(0, 0); + let a1 = second_step.get_evaluation_element(0, 0); + let a2 = third_step.get_evaluation_element(0, 0); - let mut constraints = - vec![third_row[0].clone() - second_row[0].clone() - first_row[0].clone()]; + let mut constraints = vec![a2 - a1 - a0]; // Auxiliary constraints - let z_i = &frame.get_row(0)[2]; - let z_i_plus_one = &frame.get_row(1)[2]; + let z_i = first_step.get_evaluation_element(0, 2); + let z_i_plus_one = second_step.get_evaluation_element(0, 2); - let a_i = &frame.get_row(0)[0]; - let b_i = &frame.get_row(0)[1]; + let a_i = first_step.get_evaluation_element(0, 0); + let b_i = first_step.get_evaluation_element(0, 1); let eval = z_i_plus_one * (b_i + gamma) - z_i * (a_i + gamma); @@ -186,7 +191,7 @@ pub fn fibonacci_rap_trace( let mut trace_cols = vec![fib_seq, fib_permuted]; resize_to_next_power_of_two(&mut trace_cols); - TraceTable::from_columns(trace_cols) + TraceTable::from_columns(trace_cols, 1) } #[cfg(test)] diff --git a/provers/stark/src/examples/quadratic_air.rs b/provers/stark/src/examples/quadratic_air.rs index 68448a631..cc5edb5f8 100644 --- a/provers/stark/src/examples/quadratic_air.rs +++ b/provers/stark/src/examples/quadratic_air.rs @@ -36,6 +36,8 @@ where type RAPChallenges = (); type PublicInputs = QuadraticPublicInputs; + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, @@ -77,10 +79,13 @@ where frame: &Frame, _rap_challenges: &Self::RAPChallenges, ) -> Vec> { - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + + let x = first_step.get_evaluation_element(0, 0); + let x_squared = second_step.get_evaluation_element(0, 0); - vec![&second_row[0] - &first_row[0] * &first_row[0]] + vec![x_squared - x * x] } fn number_auxiliary_rap_columns(&self) -> usize { @@ -125,5 +130,5 @@ pub fn quadratic_trace( ret.push(ret[i - 1].clone() * ret[i - 1].clone()); } - TraceTable::from_columns(vec![ret]) + TraceTable::from_columns(vec![ret], 1) } diff --git a/provers/stark/src/examples/simple_fibonacci.rs b/provers/stark/src/examples/simple_fibonacci.rs index fe8cef15a..15a54a466 100644 --- a/provers/stark/src/examples/simple_fibonacci.rs +++ b/provers/stark/src/examples/simple_fibonacci.rs @@ -37,6 +37,8 @@ where type RAPChallenges = (); type PublicInputs = FibonacciPublicInputs; + const STEP_SIZE: usize = 1; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, @@ -82,11 +84,15 @@ where frame: &Frame, _rap_challenges: &Self::RAPChallenges, ) -> Vec> { - let first_row = frame.get_row(0); - let second_row = frame.get_row(1); - let third_row = frame.get_row(2); + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + let third_step = frame.get_evaluation_step(2); + + let a0 = first_step.get_evaluation_element(0, 0); + let a1 = second_step.get_evaluation_element(0, 0); + let a2 = third_step.get_evaluation_element(0, 0); - vec![third_row[0].clone() - second_row[0].clone() - first_row[0].clone()] + vec![a2 - a1 - a0] } fn boundary_constraints( @@ -129,5 +135,5 @@ pub fn fibonacci_trace( ret.push(ret[i - 1].clone() + ret[i - 2].clone()); } - TraceTable::from_columns(vec![ret]) + TraceTable::from_columns(vec![ret], 1) } diff --git a/provers/stark/src/frame.rs b/provers/stark/src/frame.rs index ba5a06c84..6b3442c40 100644 --- a/provers/stark/src/frame.rs +++ b/provers/stark/src/frame.rs @@ -1,58 +1,41 @@ use super::trace::TraceTable; -use crate::table::Table; -use lambdaworks_math::field::{element::FieldElement, traits::IsFFTField}; +use crate::trace::StepView; +use lambdaworks_math::field::traits::IsFFTField; +/// A frame represents a collection of trace steps. +/// The collected steps are all the necessary steps for +/// all transition costraints over a trace to be evaluated. #[derive(Clone, Debug, PartialEq)] -pub struct Frame { - table: Table, +pub struct Frame<'t, F: IsFFTField> { + steps: Vec>, } -impl Frame { - pub fn new(data: Vec>, row_width: usize) -> Self { - let table = Table::new(data, row_width); - Self { table } +impl<'t, F: IsFFTField> Frame<'t, F> { + pub fn new(steps: Vec>) -> Self { + Self { steps } } - pub fn n_rows(&self) -> usize { - self.table.height - } - - pub fn n_cols(&self) -> usize { - self.table.width - } - - pub fn get_row(&self, row_idx: usize) -> &[FieldElement] { - self.table.get_row(row_idx) - } - - pub fn get_row_mut(&mut self, row_idx: usize) -> &mut [FieldElement] { - self.table.get_row_mut(row_idx) + pub fn get_evaluation_step(&self, step: usize) -> &StepView { + &self.steps[step] } pub fn read_from_trace( - trace: &TraceTable, + trace: &'t TraceTable, step: usize, blowup: u8, offsets: &[usize], ) -> Self { // Get trace length to apply module with it when getting elements of // the frame from the trace. - let trace_steps = trace.n_rows(); - let data = offsets + let trace_steps = trace.num_steps(); + + let steps = offsets .iter() - .flat_map(|frame_row_idx| { - trace - .get_row((step + (frame_row_idx * blowup as usize)) % trace_steps) - .to_vec() + .map(|eval_offset| { + trace.step_view((step + (eval_offset * blowup as usize)) % trace_steps) }) .collect(); - Self::new(data, trace.table.width) - } -} - -impl From<&Table> for Frame { - fn from(value: &Table) -> Self { - Self::new(value.data.clone(), value.width) + Self::new(steps) } } diff --git a/provers/stark/src/prover.rs b/provers/stark/src/prover.rs index 48cd72c72..4363d7e03 100644 --- a/provers/stark/src/prover.rs +++ b/provers/stark/src/prover.rs @@ -113,7 +113,7 @@ pub trait IsStarkProver { } #[allow(clippy::type_complexity)] - fn interpolate_and_commit( + fn interpolate_and_commit( trace: &TraceTable, domain: &Domain, transcript: &mut impl IsStarkTranscript, @@ -124,6 +124,7 @@ pub trait IsStarkProver { Commitment, ) where + A: AIR, FieldElement: Serializable + Send + Sync, { let trace_polys = trace.compute_trace_polys(); @@ -138,7 +139,7 @@ pub trait IsStarkProver { } // Compute commitments [t_j]. - let lde_trace = TraceTable::from_columns(lde_trace_permuted); + let lde_trace = TraceTable::from_columns(lde_trace_permuted, A::STEP_SIZE); let (lde_trace_merkle_tree, lde_trace_merkle_root) = Self::batch_commit(&lde_trace.rows()); // >>>> Send commitments: [tâ±¼] @@ -177,17 +178,18 @@ pub trait IsStarkProver { .unwrap() } - fn round_1_randomized_air_with_preprocessing>( + fn round_1_randomized_air_with_preprocessing( air: &A, main_trace: &TraceTable, domain: &Domain, transcript: &mut impl IsStarkTranscript, ) -> Result, ProvingError> where + A: AIR, FieldElement: Serializable + Send + Sync, { let (mut trace_polys, mut evaluations, main_merkle_tree, main_merkle_root) = - Self::interpolate_and_commit(main_trace, domain, transcript); + Self::interpolate_and_commit::(main_trace, domain, transcript); let rap_challenges = air.build_rap_challenges(transcript); @@ -198,14 +200,14 @@ pub trait IsStarkProver { if !aux_trace.is_empty() { // Check that this is valid for interpolation let (aux_trace_polys, aux_trace_polys_evaluations, aux_merkle_tree, aux_merkle_root) = - Self::interpolate_and_commit(&aux_trace, domain, transcript); + Self::interpolate_and_commit::(&aux_trace, domain, transcript); trace_polys.extend_from_slice(&aux_trace_polys); evaluations.extend_from_slice(&aux_trace_polys_evaluations); lde_trace_merkle_trees.push(aux_merkle_tree); lde_trace_merkle_roots.push(aux_merkle_root); } - let lde_trace = TraceTable::from_columns(evaluations); + let lde_trace = TraceTable::from_columns(evaluations, A::STEP_SIZE); Ok(Round1 { trace_polys, diff --git a/provers/stark/src/table.rs b/provers/stark/src/table.rs index 8047a1ce8..2d584abf1 100644 --- a/provers/stark/src/table.rs +++ b/provers/stark/src/table.rs @@ -1,5 +1,7 @@ use lambdaworks_math::field::{element::FieldElement, traits::IsFFTField}; +use crate::{frame::Frame, trace::StepView}; + /// A two-dimensional Table holding field elements, arranged in a row-major order. /// This is the basic underlying data structure used for any two-dimensional component in the /// the STARK protocol implementation, such as the `TraceTable` and the `EvaluationFrame`. @@ -12,7 +14,7 @@ pub struct Table { pub height: usize, } -impl Table { +impl<'t, F: IsFFTField> Table { /// Crates a new Table instance from a one-dimensional array in row major order /// and the intended width of the table. pub fn new(data: Vec>, width: usize) -> Self { @@ -76,6 +78,20 @@ impl Table { &mut self.data[row_offset..row_offset + n_cols] } + /// Given a row index and a number of rows, returns a view of a subset of contiguous rows + /// of the table, starting from that index. + pub fn table_view(&'t self, from_idx: usize, num_rows: usize) -> TableView<'t, F> { + let from_offset = from_idx * self.width; + let data = &self.data[from_offset..from_offset + self.width * num_rows]; + + TableView { + data, + table_row_idx: from_idx, + width: self.width, + height: num_rows, + } + } + /// Given a slice of field elements representing a row, appends it to /// the end of the table. pub fn append_row(&mut self, row: &[FieldElement]) { @@ -102,8 +118,70 @@ impl Table { } /// Given row and column indexes, returns the stored field element in that position of the table. - pub fn get(&self, row: usize, col: usize) -> FieldElement { + pub fn get(&self, row: usize, col: usize) -> &FieldElement { let idx = row * self.width + col; - self.data[idx].clone() + &self.data[idx] + } + + /// Given a step size, converts the given table into a `Frame`. + pub fn into_frame(&'t self, step_size: usize) -> Frame<'t, F> { + debug_assert!(self.height % step_size == 0); + let steps = (0..self.height) + .step_by(step_size) + .enumerate() + .map(|(step_idx, row_idx)| { + let table_view = self.table_view(row_idx, step_size); + StepView::new(table_view, step_idx) + }) + .collect(); + + Frame::new(steps) + } +} + +/// A view of a contiguos subset of rows of a table. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TableView<'t, F: IsFFTField> { + pub data: &'t [FieldElement], + pub table_row_idx: usize, + pub width: usize, + pub height: usize, +} + +impl<'t, F: IsFFTField> TableView<'t, F> { + pub fn new( + data: &'t [FieldElement], + table_row_idx: usize, + width: usize, + height: usize, + ) -> Self { + Self { + data, + width, + table_row_idx, + height, + } + } + + pub fn get(&self, row: usize, col: usize) -> &FieldElement { + let idx = row * self.width + col; + &self.data[idx] + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::Felt252; + + #[test] + fn get_rows_slice_works() { + let data: Vec = (0..=11).map(Felt252::from).collect(); + let table = Table::new(data, 3); + + let slice = table.table_view(1, 2); + let expected_data: Vec = (3..=8).map(Felt252::from).collect(); + + assert_eq!(slice.data, expected_data); } } diff --git a/provers/stark/src/trace.rs b/provers/stark/src/trace.rs index 3dfe52129..044feb686 100644 --- a/provers/stark/src/trace.rs +++ b/provers/stark/src/trace.rs @@ -1,4 +1,4 @@ -use crate::table::Table; +use crate::table::{Table, TableView}; use lambdaworks_math::fft::errors::FFTError; use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::{ @@ -13,25 +13,27 @@ use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; /// /// For the moment it is mostly a wrapper around the `Table` struct. It is a /// layer above the raw two-dimensional table, with functionality relevant to the -/// STARK protocol. +/// STARK protocol, such as the step size (number of consecutive rows of the table) +/// of the computation being proven. #[derive(Clone, Default, Debug, PartialEq, Eq)] pub struct TraceTable { pub table: Table, + pub step_size: usize, } -impl TraceTable { - pub fn new(data: Vec>, n_columns: usize) -> Self { +impl<'t, F: IsFFTField> TraceTable { + pub fn new(data: Vec>, n_columns: usize, step_size: usize) -> Self { let table = Table::new(data, n_columns); - Self { table } + Self { table, step_size } } - pub fn from_columns(columns: Vec>>) -> Self { + pub fn from_columns(columns: Vec>>, step_size: usize) -> Self { let table = Table::from_columns(columns); - Self { table } + Self { table, step_size } } pub fn empty() -> Self { - Self::new(Vec::new(), 0) + Self::new(Vec::new(), 0, 0) } pub fn is_empty(&self) -> bool { @@ -42,6 +44,28 @@ impl TraceTable { self.table.height } + pub fn num_steps(&self) -> usize { + debug_assert!((self.table.height % self.step_size) == 0); + self.table.height / self.step_size + } + + /// Given a particular step of the computation represented on the trace, + /// returns the row of the underlying table. + pub fn step_to_row(&self, step: usize) -> usize { + self.step_size * step + } + + /// Given a step index, return the step view of the trace for that index + pub fn step_view(&'t self, step_idx: usize) -> StepView<'t, F> { + let row_idx = self.step_to_row(step_idx); + let table_view = self.table.table_view(row_idx, self.step_size); + + StepView { + table_view, + step_idx, + } + } + pub fn n_cols(&self) -> usize { self.table.width } @@ -83,7 +107,7 @@ impl TraceTable { } /// Given a row and a column index, gives stored value in that position - pub fn get(&self, row: usize, col: usize) -> FieldElement { + pub fn get(&self, row: usize, col: usize) -> &FieldElement { self.table.get(row, col) } @@ -102,19 +126,6 @@ impl TraceTable { .unwrap() } - pub fn concatenate(&self, new_cols: Vec>, n_cols: usize) -> Self { - let mut data = Vec::new(); - let mut i = 0; - for row_index in (0..self.table.data.len()).step_by(self.table.width) { - data.append(&mut self.table.data[row_index..row_index + self.table.width].to_vec()); - data.append(&mut new_cols[i..(i + n_cols)].to_vec()); - i += n_cols; - } - - let table = Table::new(data, self.n_cols() + n_cols); - Self { table } - } - /// Given the padding length, appends the last row of the trace table /// that many times. /// This is useful for example when the desired trace length should be power @@ -150,6 +161,32 @@ impl TraceTable { } } +/// A view into a step of the trace. In general, a step over the trace +/// can be thought as a fixed size subset of trace rows +/// +/// The main purpose of this data structure is to have a way to +/// access the steps in a trace, in order to grab elements to calculate +/// constraint evaluations. +#[derive(Debug, Clone, PartialEq)] +pub struct StepView<'t, F: IsFFTField> { + pub table_view: TableView<'t, F>, + pub step_idx: usize, +} + +impl<'t, F: IsFFTField> StepView<'t, F> { + pub fn new(table_view: TableView<'t, F>, step_idx: usize) -> Self { + StepView { + table_view, + step_idx, + } + } + + /// Gets the evaluation element specified by `row_idx` and `col_idx` of this step + pub fn get_evaluation_element(&self, row_idx: usize, col_idx: usize) -> &FieldElement { + self.table_view.get(row_idx, col_idx) + } +} + /// Given a slice of trace polynomials, an evaluation point `x`, the frame offsets /// corresponding to the computation of the transitions, and a primitive root, /// outputs the trace evaluations of each trace polynomial over the values used to @@ -185,29 +222,9 @@ mod test { let col_1 = vec![FE::from(1), FE::from(2), FE::from(5), FE::from(13)]; let col_2 = vec![FE::from(1), FE::from(3), FE::from(8), FE::from(21)]; - let trace_table = TraceTable::from_columns(vec![col_1.clone(), col_2.clone()]); + let trace_table = TraceTable::from_columns(vec![col_1.clone(), col_2.clone()], 1); let res_cols = trace_table.columns(); assert_eq!(res_cols, vec![col_1, col_2]); } - - #[test] - fn test_concatenate_works() { - let table1_columns = vec![vec![FE::new(7), FE::new(8), FE::new(9)]]; - let new_columns = vec![ - FE::new(1), - FE::new(2), - FE::new(3), - FE::new(4), - FE::new(5), - FE::new(6), - ]; - let expected_table = TraceTable::from_columns(vec![ - vec![FE::new(7), FE::new(8), FE::new(9)], - vec![FE::new(1), FE::new(3), FE::new(5)], - vec![FE::new(2), FE::new(4), FE::new(6)], - ]); - let table1 = TraceTable::from_columns(table1_columns); - assert_eq!(table1.concatenate(new_columns, 2), expected_table) - } } diff --git a/provers/stark/src/traits.rs b/provers/stark/src/traits.rs index 251b89ea3..21d3cc476 100644 --- a/provers/stark/src/traits.rs +++ b/provers/stark/src/traits.rs @@ -18,6 +18,8 @@ pub trait AIR: Clone { type RAPChallenges; type PublicInputs; + const STEP_SIZE: usize; + fn new( trace_length: usize, pub_inputs: &Self::PublicInputs, diff --git a/provers/stark/src/verifier.rs b/provers/stark/src/verifier.rs index 11a4feb7a..ec94a1f0b 100644 --- a/provers/stark/src/verifier.rs +++ b/provers/stark/src/verifier.rs @@ -250,7 +250,7 @@ pub trait IsStarkVerifier { .fold(FieldElement::::zero(), |acc, x| acc + x); let transition_ood_frame_evaluations = air.compute_transition( - &(&proof.trace_ood_evaluations).into(), + &(proof.trace_ood_evaluations).into_frame(A::STEP_SIZE), &challenges.rap_challenges, );