Skip to content

Commit

Permalink
refactor: flag trace elements and constraints (#680)
Browse files Browse the repository at this point in the history
* Remove range check and builtin functionality

* Remove useless tests

* Apply clippy suggestions

* Remove useless commented code

* Remove legacy code

* Remove legacy commented code

* Start work over change of constraints

* Fix flag constraints

* add inline directive to into_bit_flag function

* Remove commented code

* remove more commented code

* Refactor flags representation functions

* Add documentation to function

* Add some comments

* Fix documentation

* Remove commented code

* Update hardcoded proof

* Fix clippy
  • Loading branch information
entropidelic authored Nov 17, 2023
1 parent a821fae commit 8677b06
Show file tree
Hide file tree
Showing 5 changed files with 1,347 additions and 1,260 deletions.
70 changes: 40 additions & 30 deletions provers/cairo/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ fn generate_memory_permutation_argument_column(
})
.collect::<Vec<Felt252>>()
}

fn generate_range_check_permutation_argument_column(
offset_column_original: &[Felt252],
offset_column_sorted: &[Felt252],
Expand Down Expand Up @@ -881,30 +882,31 @@ fn compute_instr_constraints(constraints: &mut [Felt252], frame: &Frame<Stark252
// These constraints are only applied over elements of the same row.
let curr = frame.get_evaluation_step(0);

// Bit-prefixes constraints.
// See section 9.4 of Cairo whitepaper https://eprint.iacr.org/2021/1063.pdf.
let flags: Vec<&Felt252> = (0..16)
.map(|col_idx| curr.get_evaluation_element(0, col_idx))
.collect();

// Bit constraints
for (i, flag) in flags.clone().into_iter().enumerate() {
constraints[i] = match i {
0..=14 => flag * (flag - Felt252::one()),
15 => *flag,
let two = Felt252::from(2);
(0..15).for_each(|idx| {
constraints[idx] = match idx {
0..=14 => {
(flags[idx] - two * flags[idx + 1])
* (flags[idx] - two * flags[idx + 1] - Felt252::one())
}
15 => *flags[idx],
_ => panic!("Unknown flag offset"),
};
}
}
});

// Instruction unpacking
let two = Felt252::from(2);
let b16 = two.pow(16u32);
let b32 = two.pow(32u32);
let b48 = two.pow(48u32);

// Named like this to match the Cairo whitepaper's notation.
let f0_squiggle = flags
.into_iter()
.rev()
.fold(Felt252::zero(), |acc, flag| flag + two * acc);
let f0_squiggle = flags[0];

let off_dst = curr.get_evaluation_element(0, OFF_DST);
let off_op0 = curr.get_evaluation_element(0, OFF_OP0);
Expand All @@ -922,17 +924,17 @@ fn compute_operand_constraints(constraints: &mut [Felt252], frame: &Frame<Stark2
let fp = curr.get_evaluation_element(0, FRAME_FP);
let pc = curr.get_evaluation_element(0, FRAME_PC);

let dst_fp = curr.get_evaluation_element(0, F_DST_FP);
let dst_fp = into_bit_flag(curr, F_DST_FP);
let off_dst = curr.get_evaluation_element(0, OFF_DST);
let dst_addr = curr.get_evaluation_element(0, FRAME_DST_ADDR);

let op0_fp = curr.get_evaluation_element(0, F_OP_0_FP);
let op0_fp = into_bit_flag(curr, F_OP_0_FP);
let off_op0 = curr.get_evaluation_element(0, OFF_OP0);
let op0_addr = curr.get_evaluation_element(0, FRAME_OP0_ADDR);

let op1_val = curr.get_evaluation_element(0, F_OP_1_VAL);
let op1_ap = curr.get_evaluation_element(0, F_OP_1_AP);
let op1_fp = curr.get_evaluation_element(0, F_OP_1_FP);
let op1_val = into_bit_flag(curr, F_OP_1_VAL);
let op1_ap = into_bit_flag(curr, F_OP_1_AP);
let op1_fp = into_bit_flag(curr, F_OP_1_FP);
let op0 = curr.get_evaluation_element(0, FRAME_OP0);
let off_op1 = curr.get_evaluation_element(0, OFF_OP1);
let op1_addr = curr.get_evaluation_element(0, FRAME_OP1_ADDR);
Expand All @@ -952,6 +954,14 @@ fn compute_operand_constraints(constraints: &mut [Felt252], frame: &Frame<Stark2
- op1_addr;
}

/// Given a step and the index of the bit-prefix format flag, gives the bit representation
/// of that flag, needed for the evaluation of some constraints.
#[inline(always)]
fn into_bit_flag(step: &StepView<Stark252PrimeField>, element_idx: usize) -> Felt252 {
step.get_evaluation_element(0, element_idx)
- Felt252::from(2) * step.get_evaluation_element(0, element_idx + 1)
}

fn compute_register_constraints(constraints: &mut [Felt252], frame: &Frame<Stark252PrimeField>) {
let curr = frame.get_evaluation_step(0);
let next = frame.get_evaluation_step(1);
Expand All @@ -961,25 +971,25 @@ fn compute_register_constraints(constraints: &mut [Felt252], frame: &Frame<Stark

let ap = curr.get_evaluation_element(0, FRAME_AP);
let next_ap = next.get_evaluation_element(0, FRAME_AP);
let ap_add = curr.get_evaluation_element(0, F_AP_ADD);
let ap_add = into_bit_flag(curr, F_AP_ADD);
let res = curr.get_evaluation_element(0, FRAME_RES);
let ap_one = curr.get_evaluation_element(0, F_AP_ONE);
let opc_call = curr.get_evaluation_element(0, F_OPC_CALL);
let ap_one = into_bit_flag(curr, F_AP_ONE);

let opc_ret = curr.get_evaluation_element(0, F_OPC_RET);
let opc_ret = into_bit_flag(curr, F_OPC_RET);
let opc_call = into_bit_flag(curr, F_OPC_CALL);
let dst = curr.get_evaluation_element(0, FRAME_DST);
let fp = curr.get_evaluation_element(0, FRAME_FP);
let next_fp = next.get_evaluation_element(0, FRAME_FP);

let t1 = curr.get_evaluation_element(0, FRAME_T1);
let pc_jnz = curr.get_evaluation_element(0, F_PC_JNZ);
let pc_jnz = into_bit_flag(curr, F_PC_JNZ);
let pc = curr.get_evaluation_element(0, FRAME_PC);
let next_pc = next.get_evaluation_element(0, FRAME_PC);

let t0 = curr.get_evaluation_element(0, FRAME_T0);
let op1 = curr.get_evaluation_element(0, FRAME_OP1);
let pc_abs = curr.get_evaluation_element(0, F_PC_ABS);
let pc_rel = curr.get_evaluation_element(0, F_PC_REL);
let pc_abs = into_bit_flag(curr, F_PC_ABS);
let pc_rel = into_bit_flag(curr, F_PC_REL);

// ap and fp constraints
constraints[NEXT_AP] = ap + ap_add * res + ap_one + opc_call * two - next_ap;
Expand Down Expand Up @@ -1007,17 +1017,17 @@ fn compute_opcode_constraints(constraints: &mut [Felt252], frame: &Frame<Stark25
let op0 = curr.get_evaluation_element(0, FRAME_OP0);
let op1 = curr.get_evaluation_element(0, FRAME_OP1);

let res_add = curr.get_evaluation_element(0, F_RES_ADD);
let res_mul = curr.get_evaluation_element(0, F_RES_MUL);
let pc_jnz = curr.get_evaluation_element(0, F_PC_JNZ);
let res_add = into_bit_flag(curr, F_RES_ADD);
let res_mul = into_bit_flag(curr, F_RES_MUL);
let pc_jnz = into_bit_flag(curr, F_PC_JNZ);
let res = curr.get_evaluation_element(0, FRAME_RES);

let opc_call = curr.get_evaluation_element(0, F_OPC_CALL);
let opc_call = into_bit_flag(curr, F_OPC_CALL);
let dst = curr.get_evaluation_element(0, FRAME_DST);
let fp = curr.get_evaluation_element(0, FRAME_FP);
let pc = curr.get_evaluation_element(0, FRAME_PC);

let opc_aeq = curr.get_evaluation_element(0, F_OPC_AEQ);
let opc_aeq = into_bit_flag(curr, F_OPC_AEQ);

constraints[MUL_1] = mul - op0 * op1;

Expand Down Expand Up @@ -1181,7 +1191,7 @@ fn permutation_argument_range_check(
}

fn frame_inst_size(step: &StepView<Stark252PrimeField>) -> Felt252 {
let op1_val = step.get_evaluation_element(0, F_OP_1_VAL);
let op1_val = into_bit_flag(step, F_OP_1_VAL);
op1_val + Felt252::one()
}

Expand Down
95 changes: 78 additions & 17 deletions provers/cairo/src/decode/instruction_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ impl CairoInstructionFlags {
/// represented by field elements: Felt252::zero() for bit 0 and
/// Felt252::one() for bit 1.
#[rustfmt::skip]
pub fn to_trace_representation(&self) -> [Felt252; 16] {
let b0 = self.dst_reg.to_trace_representation();
let b1 = self.op0_reg.to_trace_representation();
let [b2, b3, b4] = self.op1_src.to_trace_representation();
let [b5, b6] = self.res_logic.to_trace_representation();
let [b7, b8, b9] = self.pc_update.to_trace_representation();
let [b10, b11] = self.ap_update.to_trace_representation();
let [b12, b13, b14] = self.opcode.to_trace_representation();
pub fn to_bit_representation(&self) -> [Felt252; 16] {
let b0 = self.dst_reg.to_bit_representation();
let b1 = self.op0_reg.to_bit_representation();
let [b2, b3, b4] = self.op1_src.to_bit_representation();
let [b5, b6] = self.res_logic.to_bit_representation();
let [b7, b8, b9] = self.pc_update.to_bit_representation();
let [b10, b11] = self.ap_update.to_bit_representation();
let [b12, b13, b14] = self.opcode.to_bit_representation();

// In the paper, a little-endian format for the bit flags is
// mentioned. That is why they are arranged in this way (section 4.4
Expand All @@ -76,6 +76,11 @@ impl CairoInstructionFlags {
Felt252::zero(),
]
}

pub fn to_trace_representation(&self) -> [Felt252; 16] {
let bit_flags = self.to_bit_representation();
to_bit_prefixes(bit_flags)
}
}

#[derive(Debug, PartialEq, Eq, Clone)]
Expand All @@ -85,7 +90,7 @@ pub enum Op0Reg {
}

impl Op0Reg {
pub fn to_trace_representation(&self) -> Felt252 {
pub fn to_bit_representation(&self) -> Felt252 {
match self {
Op0Reg::AP => Felt252::zero(),
Op0Reg::FP => Felt252::one(),
Expand Down Expand Up @@ -117,7 +122,7 @@ pub enum DstReg {
FP = 1,
}
impl DstReg {
pub fn to_trace_representation(&self) -> Felt252 {
pub fn to_bit_representation(&self) -> Felt252 {
match self {
DstReg::AP => Felt252::zero(),
DstReg::FP => Felt252::one(),
Expand Down Expand Up @@ -151,7 +156,7 @@ pub enum Op1Src {
}

impl Op1Src {
pub fn to_trace_representation(&self) -> [Felt252; 3] {
pub fn to_bit_representation(&self) -> [Felt252; 3] {
match self {
Op1Src::Op0 => [Felt252::zero(), Felt252::zero(), Felt252::zero()],
Op1Src::Imm => [Felt252::zero(), Felt252::zero(), Felt252::one()],
Expand Down Expand Up @@ -188,7 +193,7 @@ pub enum ResLogic {
}

impl ResLogic {
pub fn to_trace_representation(&self) -> [Felt252; 2] {
pub fn to_bit_representation(&self) -> [Felt252; 2] {
match self {
ResLogic::Op1 => [Felt252::zero(), Felt252::zero()],
ResLogic::Add => [Felt252::zero(), Felt252::one()],
Expand Down Expand Up @@ -225,7 +230,7 @@ pub enum PcUpdate {
}

impl PcUpdate {
pub fn to_trace_representation(&self) -> [Felt252; 3] {
pub fn to_bit_representation(&self) -> [Felt252; 3] {
match self {
PcUpdate::Regular => [Felt252::zero(), Felt252::zero(), Felt252::zero()],
PcUpdate::Jump => [Felt252::zero(), Felt252::zero(), Felt252::one()],
Expand Down Expand Up @@ -262,7 +267,7 @@ pub enum ApUpdate {
}

impl ApUpdate {
pub fn to_trace_representation(&self) -> [Felt252; 2] {
pub fn to_bit_representation(&self) -> [Felt252; 2] {
match self {
ApUpdate::Regular => [Felt252::zero(), Felt252::zero()],
ApUpdate::Add => [Felt252::zero(), Felt252::one()],
Expand Down Expand Up @@ -314,7 +319,7 @@ pub enum CairoOpcode {
}

impl CairoOpcode {
pub fn to_trace_representation(&self) -> [Felt252; 3] {
pub fn to_bit_representation(&self) -> [Felt252; 3] {
match self {
CairoOpcode::NOp => [Felt252::zero(), Felt252::zero(), Felt252::zero()],
CairoOpcode::Call => [Felt252::zero(), Felt252::zero(), Felt252::one()],
Expand All @@ -341,6 +346,28 @@ impl TryFrom<&Felt252> for CairoOpcode {
}
}

fn to_bit_prefixes(bit_array: [Felt252; 16]) -> [Felt252; 16] {
let two = Felt252::from(2);
(0..bit_array.len())
.map(|i| {
bit_array
.iter()
.enumerate()
.fold(Felt252::zero(), |acc, (j, flag_j)| {
let sum_term = if j < i {
Felt252::zero()
} else {
let exponent = j - i;
two.pow(exponent) * flag_j
};
acc + sum_term
})
})
.collect::<Vec<_>>()
.try_into()
.unwrap()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -671,7 +698,7 @@ mod tests {
}

#[test]
fn flags_trace_representation() {
fn flags_bit_representation() {
// Bit-trace representation for each flag:
// DstReg::FP = 1
// Op0Reg::FP = 1
Expand Down Expand Up @@ -703,8 +730,42 @@ mod tests {
Felt252::zero(),
];

let representation = flags.to_trace_representation();
let representation = flags.to_bit_representation();

assert_eq!(representation, expected_representation);
}

#[test]
fn to_bit_prefixes_all_zeros_works() {
let bit_array = [Felt252::zero(); 16];

let result = to_bit_prefixes(bit_array);
let expected = [Felt252::zero(); 16];

assert_eq!(result, expected);
}

#[test]
fn to_bit_prefixes_flag14_on_works() {
let mut bit_array = [Felt252::zero(); 16];
// We turn on only the 14th bit, which is the most significant.

bit_array[14] = Felt252::one();

let result = to_bit_prefixes(bit_array);
// The bit prefix value of f0 should be 16384, and the rest
// should be the bit prefixes of that number.
let expected: [Felt252; 16] = (0..16u32)
.map(|idx| {
if idx == 15 {
return Felt252::zero();
}
Felt252::from(16384) / Felt252::from(2).pow(idx)
})
.collect::<Vec<_>>()
.try_into()
.unwrap();

assert_eq!(result, expected);
}
}
Loading

0 comments on commit 8677b06

Please sign in to comment.