From e443c9378dbaccc7d85c7f8dfac846386306232f Mon Sep 17 00:00:00 2001 From: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> Date: Thu, 5 Dec 2024 15:35:23 -0300 Subject: [PATCH 1/2] Stark continuous read-only memory example (#940) * create file * continuity and single value constraint * imp air * permutation constraint * evaluate function for SingleValueConstraint * add last element constraint * add public inputs * add sort function for the trace * add integration test * fix clippy * fix constraints * add documentation * handle possible panic * rename variables * fix doc --------- Co-authored-by: Nicole Co-authored-by: jotabulacios Co-authored-by: Nicole Graus --- provers/stark/src/examples/mod.rs | 1 + .../stark/src/examples/read_only_memory.rs | 433 ++++++++++++++++++ provers/stark/src/tests/integration_tests.rs | 47 ++ 3 files changed, 481 insertions(+) create mode 100644 provers/stark/src/examples/read_only_memory.rs diff --git a/provers/stark/src/examples/mod.rs b/provers/stark/src/examples/mod.rs index 6a8949f7a..ba4f6586e 100644 --- a/provers/stark/src/examples/mod.rs +++ b/provers/stark/src/examples/mod.rs @@ -4,5 +4,6 @@ pub mod fibonacci_2_cols_shifted; pub mod fibonacci_2_columns; pub mod fibonacci_rap; pub mod quadratic_air; +pub mod read_only_memory; pub mod simple_fibonacci; pub mod simple_periodic_cols; diff --git a/provers/stark/src/examples/read_only_memory.rs b/provers/stark/src/examples/read_only_memory.rs new file mode 100644 index 000000000..8b5b01b07 --- /dev/null +++ b/provers/stark/src/examples/read_only_memory.rs @@ -0,0 +1,433 @@ +use std::marker::PhantomData; + +use crate::{ + constraints::{ + boundary::{BoundaryConstraint, BoundaryConstraints}, + transition::TransitionConstraint, + }, + context::AirContext, + frame::Frame, + proof::options::ProofOptions, + trace::TraceTable, + traits::AIR, +}; +use lambdaworks_crypto::fiat_shamir::is_transcript::IsTranscript; +use lambdaworks_math::field::traits::IsPrimeField; +use lambdaworks_math::{ + field::{element::FieldElement, traits::IsFFTField}, + traits::ByteConversion, +}; + +/// This condition ensures the continuity in a read-only memory structure, preserving strict ordering. +/// Equation based on Cairo Whitepaper section 9.7.2 +#[derive(Clone)] +struct ContinuityConstraint { + phantom: PhantomData, +} + +impl ContinuityConstraint { + pub fn new() -> Self { + Self { + phantom: PhantomData, + } + } +} + +impl TransitionConstraint for ContinuityConstraint +where + F: IsFFTField + Send + Sync, +{ + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + 0 + } + + fn end_exemptions(&self) -> usize { + // NOTE: We are assuming that the trace has as length a power of 2. + 1 + } + + fn evaluate( + &self, + frame: &Frame, + transition_evaluations: &mut [FieldElement], + _periodic_values: &[FieldElement], + _rap_challenges: &[FieldElement], + ) { + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + + let a_sorted_0 = first_step.get_main_evaluation_element(0, 2); + let a_sorted_1 = second_step.get_main_evaluation_element(0, 2); + // (a'_{i+1} - a'_i)(a'_{i+1} - a'_i - 1) = 0 where a' is the sorted address + let res = (a_sorted_1 - a_sorted_0) * (a_sorted_1 - a_sorted_0 - FieldElement::::one()); + + // The eval always exists, except if the constraint idx were incorrectly defined. + if let Some(eval) = transition_evaluations.get_mut(self.constraint_idx()) { + *eval = res; + } + } +} +/// Transition constraint that ensures that same addresses have same values, making the memory read-only. +/// Equation based on Cairo Whitepaper section 9.7.2 +#[derive(Clone)] +struct SingleValueConstraint { + phantom: PhantomData, +} + +impl SingleValueConstraint { + pub fn new() -> Self { + Self { + phantom: PhantomData, + } + } +} + +impl TransitionConstraint for SingleValueConstraint +where + F: IsFFTField + Send + Sync, +{ + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + 1 + } + + fn end_exemptions(&self) -> usize { + // NOTE: We are assuming that the trace has as length a power of 2. + 1 + } + + fn evaluate( + &self, + frame: &Frame, + transition_evaluations: &mut [FieldElement], + _periodic_values: &[FieldElement], + _rap_challenges: &[FieldElement], + ) { + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + + let a_sorted0 = first_step.get_main_evaluation_element(0, 2); + let a_sorted1 = second_step.get_main_evaluation_element(0, 2); + let v_sorted0 = first_step.get_main_evaluation_element(0, 3); + let v_sorted1 = second_step.get_main_evaluation_element(0, 3); + // (v'_{i+1} - v'_i) * (a'_{i+1} - a'_i - 1) = 0 + let res = (v_sorted1 - v_sorted0) * (a_sorted1 - a_sorted0 - FieldElement::::one()); + + // The eval always exists, except if the constraint idx were incorrectly defined. + if let Some(eval) = transition_evaluations.get_mut(self.constraint_idx()) { + *eval = res; + } + } +} +/// Permutation constraint ensures that the values are permuted in the memory. +/// Equation based on Cairo Whitepaper section 9.7.2 +#[derive(Clone)] +struct PermutationConstraint { + phantom: PhantomData, +} + +impl PermutationConstraint { + pub fn new() -> Self { + Self { + phantom: PhantomData, + } + } +} + +impl TransitionConstraint for PermutationConstraint +where + F: IsFFTField + Send + Sync, +{ + fn degree(&self) -> usize { + 2 + } + + fn constraint_idx(&self) -> usize { + 2 + } + + fn end_exemptions(&self) -> usize { + 1 + } + + fn evaluate( + &self, + frame: &Frame, + transition_evaluations: &mut [FieldElement], + _periodic_values: &[FieldElement], + rap_challenges: &[FieldElement], + ) { + let first_step = frame.get_evaluation_step(0); + let second_step = frame.get_evaluation_step(1); + + // Auxiliary constraints + let p0 = first_step.get_aux_evaluation_element(0, 0); + let p1 = second_step.get_aux_evaluation_element(0, 0); + let z = &rap_challenges[0]; + let alpha = &rap_challenges[1]; + let a1 = second_step.get_main_evaluation_element(0, 0); + let v1 = second_step.get_main_evaluation_element(0, 1); + let a_sorted_1 = second_step.get_main_evaluation_element(0, 2); + let v_sorted_1 = second_step.get_main_evaluation_element(0, 3); + // (z - (a'_{i+1} + α * v'_{i+1})) * p_{i+1} = (z - (a_{i+1} + α * v_{i+1})) * p_i + let res = (z - (a_sorted_1 + alpha * v_sorted_1)) * p1 - (z - (a1 + alpha * v1)) * p0; + + // The eval always exists, except if the constraint idx were incorrectly defined. + if let Some(eval) = transition_evaluations.get_mut(self.constraint_idx()) { + *eval = res; + } + } +} + +pub struct ReadOnlyRAP +where + F: IsFFTField, +{ + context: AirContext, + trace_length: usize, + pub_inputs: ReadOnlyPublicInputs, + transition_constraints: Vec>>, +} + +#[derive(Clone, Debug)] +pub struct ReadOnlyPublicInputs +where + F: IsFFTField, +{ + pub a0: FieldElement, + pub v0: FieldElement, + pub a_sorted0: FieldElement, + pub v_sorted0: FieldElement, +} + +impl AIR for ReadOnlyRAP +where + F: IsFFTField + Send + Sync + 'static, + FieldElement: ByteConversion, +{ + type Field = F; + type FieldExtension = F; + type PublicInputs = ReadOnlyPublicInputs; + + const STEP_SIZE: usize = 1; + + fn new( + trace_length: usize, + pub_inputs: &Self::PublicInputs, + proof_options: &ProofOptions, + ) -> Self { + let transition_constraints: Vec< + Box>, + > = vec![ + Box::new(ContinuityConstraint::new()), + Box::new(SingleValueConstraint::new()), + Box::new(PermutationConstraint::new()), + ]; + + let context = AirContext { + proof_options: proof_options.clone(), + trace_columns: 5, + transition_offsets: vec![0, 1], + num_transition_constraints: transition_constraints.len(), + }; + + Self { + context, + trace_length, + pub_inputs: pub_inputs.clone(), + transition_constraints, + } + } + + fn build_auxiliary_trace( + &self, + trace: &mut TraceTable, + challenges: &[FieldElement], + ) { + let main_segment_cols = trace.columns_main(); + let a = &main_segment_cols[0]; + let v = &main_segment_cols[1]; + let a_sorted = &main_segment_cols[2]; + let v_sorted = &main_segment_cols[3]; + let z = &challenges[0]; + let alpha = &challenges[1]; + + let trace_len = trace.num_rows(); + + let mut aux_col = Vec::new(); + let num = z - (&a[0] + alpha * &v[0]); + let den = z - (&a_sorted[0] + alpha * &v_sorted[0]); + aux_col.push(num / den); + // Apply the same equation given in the permutation case to the rest of the trace + for i in 0..trace_len - 1 { + let num = (z - (&a[i + 1] + alpha * &v[i + 1])) * &aux_col[i]; + let den = z - (&a_sorted[i + 1] + alpha * &v_sorted[i + 1]); + aux_col.push(num / den); + } + + for (i, aux_elem) in aux_col.iter().enumerate().take(trace.num_rows()) { + trace.set_aux(i, 0, aux_elem.clone()) + } + } + + fn build_rap_challenges( + &self, + transcript: &mut impl IsTranscript, + ) -> Vec> { + vec![ + transcript.sample_field_element(), + transcript.sample_field_element(), + ] + } + + fn trace_layout(&self) -> (usize, usize) { + (4, 1) + } + + fn boundary_constraints( + &self, + rap_challenges: &[FieldElement], + ) -> BoundaryConstraints { + let a0 = &self.pub_inputs.a0; + let v0 = &self.pub_inputs.v0; + let a_sorted0 = &self.pub_inputs.a_sorted0; + let v_sorted0 = &self.pub_inputs.v_sorted0; + let z = &rap_challenges[0]; + let alpha = &rap_challenges[1]; + + // Main boundary constraints + let c1 = BoundaryConstraint::new_main(0, 0, a0.clone()); + let c2 = BoundaryConstraint::new_main(1, 0, v0.clone()); + let c3 = BoundaryConstraint::new_main(2, 0, a_sorted0.clone()); + let c4 = BoundaryConstraint::new_main(3, 0, v_sorted0.clone()); + + // Auxiliary boundary constraints + let num = z - (a0 + alpha * v0); + let den = z - (a_sorted0 + alpha * v_sorted0); + let p0_value = num / den; + + let c_aux1 = BoundaryConstraint::new_aux(0, 0, p0_value); + let c_aux2 = BoundaryConstraint::new_aux( + 0, + self.trace_length - 1, + FieldElement::::one(), + ); + + BoundaryConstraints::from_constraints(vec![c1, c2, c3, c4, c_aux1, c_aux2]) + } + + fn transition_constraints( + &self, + ) -> &Vec>> { + &self.transition_constraints + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn composition_poly_degree_bound(&self) -> usize { + self.trace_length() + } + + fn trace_length(&self) -> usize { + self.trace_length + } + + fn pub_inputs(&self) -> &Self::PublicInputs { + &self.pub_inputs + } + + fn compute_transition_verifier( + &self, + frame: &Frame, + periodic_values: &[FieldElement], + rap_challenges: &[FieldElement], + ) -> Vec> { + self.compute_transition_prover(frame, periodic_values, rap_challenges) + } +} + +/// Given the adress and value columns, it returns the trace table with 5 columns, which are: +/// Addres, Value, Adress Sorted, Value Sorted and a Column of Zeroes (where we'll insert the auxiliary colunn). +pub fn sort_rap_trace( + address: Vec>, + value: Vec>, +) -> TraceTable { + let mut address_value_pairs: Vec<_> = address.iter().zip(value.iter()).collect(); + + address_value_pairs.sort_by_key(|(addr, _)| addr.representative()); + + let (sorted_address, sorted_value): (Vec>, Vec>) = + address_value_pairs + .into_iter() + .map(|(addr, val)| (addr.clone(), val.clone())) + .unzip(); + let main_columns = vec![address.clone(), value.clone(), sorted_address, sorted_value]; + // create a vector with zeros of the same length as the main columns + let zero_vec = vec![FieldElement::::zero(); main_columns[0].len()]; + TraceTable::from_columns(main_columns, vec![zero_vec], 1) +} + +#[cfg(test)] +mod test { + use super::*; + use lambdaworks_math::field::fields::u64_prime_field::FE17; + + #[test] + fn test_sort_rap_trace() { + let address_col = vec![ + FE17::from(5), + FE17::from(2), + FE17::from(3), + FE17::from(4), + FE17::from(1), + FE17::from(6), + FE17::from(7), + FE17::from(8), + ]; + let value_col = vec![ + FE17::from(50), + FE17::from(20), + FE17::from(30), + FE17::from(40), + FE17::from(10), + FE17::from(60), + FE17::from(70), + FE17::from(80), + ]; + + let sorted_trace = sort_rap_trace(address_col.clone(), value_col.clone()); + + let expected_sorted_addresses = vec![ + FE17::from(1), + FE17::from(2), + FE17::from(3), + FE17::from(4), + FE17::from(5), + FE17::from(6), + FE17::from(7), + FE17::from(8), + ]; + let expected_sorted_values = vec![ + FE17::from(10), + FE17::from(20), + FE17::from(30), + FE17::from(40), + FE17::from(50), + FE17::from(60), + FE17::from(70), + FE17::from(80), + ]; + + assert_eq!(sorted_trace.columns_main()[2], expected_sorted_addresses); + assert_eq!(sorted_trace.columns_main()[3], expected_sorted_values); + } +} diff --git a/provers/stark/src/tests/integration_tests.rs b/provers/stark/src/tests/integration_tests.rs index c7f2f6a4c..7513caad0 100644 --- a/provers/stark/src/tests/integration_tests.rs +++ b/provers/stark/src/tests/integration_tests.rs @@ -10,6 +10,7 @@ use crate::{ fibonacci_2_columns::{self, Fibonacci2ColsAIR}, fibonacci_rap::{fibonacci_rap_trace, FibonacciRAP, FibonacciRAPPublicInputs}, quadratic_air::{self, QuadraticAIR, QuadraticPublicInputs}, + read_only_memory::{sort_rap_trace, ReadOnlyPublicInputs, ReadOnlyRAP}, simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}, simple_periodic_cols::{self, SimplePeriodicAIR, SimplePeriodicPublicInputs}, // simple_periodic_cols::{self, SimplePeriodicAIR, SimplePeriodicPublicInputs}, }, @@ -247,3 +248,49 @@ fn test_prove_bit_flags() { StoneProverTranscript::new(&[]), )); } + +#[test_log::test] +fn test_prove_read_only_memory() { + let address_col = vec![ + FieldElement::::from(3), // a0 + FieldElement::::from(2), // a1 + FieldElement::::from(2), // a2 + FieldElement::::from(3), // a3 + FieldElement::::from(4), // a4 + FieldElement::::from(5), // a5 + FieldElement::::from(1), // a6 + FieldElement::::from(3), // a7 + ]; + let value_col = vec![ + FieldElement::::from(10), // v0 + FieldElement::::from(5), // v1 + FieldElement::::from(5), // v2 + FieldElement::::from(10), // v3 + FieldElement::::from(25), // v4 + FieldElement::::from(25), // v5 + FieldElement::::from(7), // v6 + FieldElement::::from(10), // v7 + ]; + + let pub_inputs = ReadOnlyPublicInputs { + a0: FieldElement::::from(3), + v0: FieldElement::::from(10), + a_sorted0: FieldElement::::from(1), // a6 + v_sorted0: FieldElement::::from(7), // v6 + }; + let mut trace = sort_rap_trace(address_col, value_col); + let proof_options = ProofOptions::default_test_options(); + let proof = Prover::>::prove( + &mut trace, + &pub_inputs, + &proof_options, + StoneProverTranscript::new(&[]), + ) + .unwrap(); + assert!(Verifier::>::verify( + &proof, + &pub_inputs, + &proof_options, + StoneProverTranscript::new(&[]) + )); +} From 7b5a638d4ce81f380ea5f43a22be41ef9b2d7ff2 Mon Sep 17 00:00:00 2001 From: Nicole Graus Date: Wed, 11 Dec 2024 18:13:02 -0300 Subject: [PATCH 2/2] Baby bear extension (#942) * wip * add byte conversion for quartic * fft tests for baby bear quartic extension working * add test / add comments * fix typo * fix clippy * add test inv of zero error * fix fmt * fix clippy and doc * resolve PR comments * remove commented code --------- Co-authored-by: Joaquin Carletti Co-authored-by: Nicole Co-authored-by: jotabulacios Co-authored-by: jotabulacios <45471455+jotabulacios@users.noreply.github.com> Co-authored-by: Diego K <43053772+diegokingston@users.noreply.github.com> --- math/src/field/fields/fft_friendly/mod.rs | 2 + .../fields/fft_friendly/quartic_babybear.rs | 569 ++++++++++++++++++ 2 files changed, 571 insertions(+) create mode 100644 math/src/field/fields/fft_friendly/quartic_babybear.rs diff --git a/math/src/field/fields/fft_friendly/mod.rs b/math/src/field/fields/fft_friendly/mod.rs index 535b94ecc..7ba6a0943 100644 --- a/math/src/field/fields/fft_friendly/mod.rs +++ b/math/src/field/fields/fft_friendly/mod.rs @@ -2,6 +2,8 @@ pub mod babybear; /// Implemenation of the quadratic extension of the babybear field pub mod quadratic_babybear; +/// Implemenation of the quadric extension of the babybear field +pub mod quartic_babybear; /// Implementation of the prime field used in [Stark101](https://starkware.co/stark-101/) tutorial, p = 3 * 2^30 + 1 pub mod stark_101_prime_field; /// Implementation of two-adic prime field over 256 bit unsigned integers. diff --git a/math/src/field/fields/fft_friendly/quartic_babybear.rs b/math/src/field/fields/fft_friendly/quartic_babybear.rs new file mode 100644 index 000000000..361de0e0b --- /dev/null +++ b/math/src/field/fields/fft_friendly/quartic_babybear.rs @@ -0,0 +1,569 @@ +use crate::field::{ + element::FieldElement, + errors::FieldError, + fields::fft_friendly::babybear::Babybear31PrimeField, + traits::{IsFFTField, IsField, IsSubFieldOf}, +}; + +#[cfg(feature = "lambdaworks-serde-binary")] +use crate::traits::ByteConversion; + +/// We are implementig the extension of Baby Bear of degree 4 using the irreducible polynomial x^4 + 11. +/// BETA = 11 and -BETA = -11 is the non-residue. +pub const BETA: FieldElement = + FieldElement::::from_hex_unchecked("b"); + +#[derive(Clone, Debug)] +pub struct Degree4BabyBearExtensionField; + +impl IsField for Degree4BabyBearExtensionField { + type BaseType = [FieldElement; 4]; + + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] + &b[0], &a[1] + &b[1], &a[2] + &b[2], &a[3] + &b[3]] + } + + /// Result of multiplying two polynomials a = a0 + a1 * x + a2 * x^2 + a3 * x^3 and + /// b = b0 + b1 * x + b2 * x^2 + b3 * x^3 by applying distribution and taking + /// the remainder of the division by x^4 + 11. + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [ + &a[0] * &b[0] - BETA * (&a[1] * &b[3] + &a[3] * &b[1] + &a[2] * &b[2]), + &a[0] * &b[1] + &a[1] * &b[0] - BETA * (&a[2] * &b[3] + &a[3] * &b[2]), + &a[0] * &b[2] + &a[2] * &b[0] + &a[1] * &b[1] - BETA * (&a[3] * &b[3]), + &a[0] * &b[3] + &a[3] * &b[0] + &a[1] * &b[2] + &a[2] * &b[1], + ] + } + + fn square(a: &Self::BaseType) -> Self::BaseType { + [ + &a[0].square() - BETA * ((&a[1] * &a[3]).double() + &a[2].square()), + (&a[0] * &a[1] - BETA * (&a[2] * &a[3])).double(), + (&a[0] * &a[2]).double() + &a[1].square() - BETA * (&a[3].square()), + (&a[0] * &a[3] + &a[1] * &a[2]).double(), + ] + } + + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] - &b[0], &a[1] - &b[1], &a[2] - &b[2], &a[3] - &b[3]] + } + + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-&a[0], -&a[1], -&a[2], -&a[3]] + } + + /// Return te inverse of a fp4 element if exist. + /// This algorithm is inspired by Risc0 implementation: + /// + fn inv(a: &Self::BaseType) -> Result { + let mut b0 = &a[0] * &a[0] + BETA * (&a[1] * (&a[3] + &a[3]) - &a[2] * &a[2]); + let mut b2 = &a[0] * (&a[2] + &a[2]) - &a[1] * &a[1] + BETA * (&a[3] * &a[3]); + let c = &b0.square() + BETA * b2.square(); + let c_inv = c.inv()?; + b0 *= &c_inv; + b2 *= &c_inv; + Ok([ + &a[0] * &b0 + BETA * &a[2] * &b2, + -&a[1] * &b0 - BETA * &a[3] * &b2, + -&a[0] * &b2 + &a[2] * &b0, + &a[1] * &b2 - &a[3] * &b0, + ]) + } + + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + ::mul(a, &Self::inv(b).unwrap()) + } + + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3] + } + + fn zero() -> Self::BaseType { + Self::BaseType::default() + } + + fn one() -> Self::BaseType { + [ + FieldElement::one(), + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + ] + } + + fn from_u64(x: u64) -> Self::BaseType { + [ + FieldElement::from(x), + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + ] + } + + /// Takes as input an element of BaseType and returns the internal representation + /// of that element in the field. + /// Note: for this case this is simply the identity, because the components + /// already have correct representations. + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } + + fn double(a: &Self::BaseType) -> Self::BaseType { + ::add(a, a) + } + + fn pow(a: &Self::BaseType, mut exponent: T) -> Self::BaseType + where + T: crate::unsigned_integer::traits::IsUnsignedInteger, + { + let zero = T::from(0); + let one = T::from(1); + + if exponent == zero { + return Self::one(); + } + if exponent == one { + return a.clone(); + } + + let mut result = a.clone(); + + // Fast path for powers of 2 + while exponent & one == zero { + result = Self::square(&result); + exponent >>= 1; + if exponent == zero { + return result; + } + } + + let mut base = result.clone(); + exponent >>= 1; + + while exponent != zero { + base = Self::square(&base); + if exponent & one == one { + result = ::mul(&result, &base); + } + exponent >>= 1; + } + + result + } +} + +impl IsSubFieldOf for Babybear31PrimeField { + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::mul(a, b[0].value())); + let c1 = FieldElement::from_raw(::mul(a, b[1].value())); + let c2 = FieldElement::from_raw(::mul(a, b[2].value())); + let c3 = FieldElement::from_raw(::mul(a, b[3].value())); + + [c0, c1, c2, c3] + } + + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::add(a, b[0].value())); + let c1 = FieldElement::from_raw(*b[1].value()); + let c2 = FieldElement::from_raw(*b[2].value()); + let c3 = FieldElement::from_raw(*b[3].value()); + + [c0, c1, c2, c3] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree4BabyBearExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::sub(a, b[0].value())); + let c1 = FieldElement::from_raw(::neg(b[1].value())); + let c2 = FieldElement::from_raw(::neg(b[2].value())); + let c3 = FieldElement::from_raw(::neg(b[3].value())); + [c0, c1, c2, c3] + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [ + FieldElement::from_raw(a), + FieldElement::zero(), + FieldElement::zero(), + FieldElement::zero(), + ] + } + + #[cfg(feature = "alloc")] + fn to_subfield_vec( + b: ::BaseType, + ) -> alloc::vec::Vec { + b.into_iter().map(|x| x.to_raw()).collect() + } +} + +#[cfg(feature = "lambdaworks-serde-binary")] +impl ByteConversion for [FieldElement; 4] { + #[cfg(feature = "alloc")] + fn to_bytes_be(&self) -> alloc::vec::Vec { + let mut byte_slice = ByteConversion::to_bytes_be(&self[0]); + byte_slice.extend(ByteConversion::to_bytes_be(&self[1])); + byte_slice.extend(ByteConversion::to_bytes_be(&self[2])); + byte_slice.extend(ByteConversion::to_bytes_be(&self[3])); + byte_slice + } + + #[cfg(feature = "alloc")] + fn to_bytes_le(&self) -> alloc::vec::Vec { + let mut byte_slice = ByteConversion::to_bytes_le(&self[0]); + byte_slice.extend(ByteConversion::to_bytes_le(&self[1])); + byte_slice.extend(ByteConversion::to_bytes_le(&self[2])); + byte_slice.extend(ByteConversion::to_bytes_le(&self[3])); + byte_slice + } + + fn from_bytes_be(bytes: &[u8]) -> Result + where + Self: Sized, + { + const BYTES_PER_FIELD: usize = 64; + + let x0 = FieldElement::from_bytes_be(&bytes[0..BYTES_PER_FIELD])?; + let x1 = FieldElement::from_bytes_be(&bytes[BYTES_PER_FIELD..BYTES_PER_FIELD * 2])?; + let x2 = FieldElement::from_bytes_be(&bytes[BYTES_PER_FIELD * 2..BYTES_PER_FIELD * 3])?; + let x3 = FieldElement::from_bytes_be(&bytes[BYTES_PER_FIELD * 3..BYTES_PER_FIELD * 4])?; + + Ok([x0, x1, x2, x3]) + } + + fn from_bytes_le(bytes: &[u8]) -> Result + where + Self: Sized, + { + const BYTES_PER_FIELD: usize = 64; + + let x0 = FieldElement::from_bytes_le(&bytes[0..BYTES_PER_FIELD])?; + let x1 = FieldElement::from_bytes_le(&bytes[BYTES_PER_FIELD..BYTES_PER_FIELD * 2])?; + let x2 = FieldElement::from_bytes_le(&bytes[BYTES_PER_FIELD * 2..BYTES_PER_FIELD * 3])?; + let x3 = FieldElement::from_bytes_le(&bytes[BYTES_PER_FIELD * 3..BYTES_PER_FIELD * 4])?; + + Ok([x0, x1, x2, x3]) + } +} + +impl IsFFTField for Degree4BabyBearExtensionField { + const TWO_ADICITY: u64 = 29; + const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType = [ + FieldElement::from_hex_unchecked("0"), + FieldElement::from_hex_unchecked("0"), + FieldElement::from_hex_unchecked("0"), + FieldElement::from_hex_unchecked("771F1C8"), + ]; +} + +#[cfg(test)] +mod tests { + use super::*; + + type FpE = FieldElement; + type Fp4E = FieldElement; + + #[test] + fn test_add() { + let a = Fp4E::new([FpE::from(0), FpE::from(1), FpE::from(2), FpE::from(3)]); + let b = Fp4E::new([-FpE::from(2), FpE::from(4), FpE::from(6), -FpE::from(8)]); + let expected_result = Fp4E::new([ + FpE::from(0) - FpE::from(2), + FpE::from(1) + FpE::from(4), + FpE::from(2) + FpE::from(6), + FpE::from(3) - FpE::from(8), + ]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_sub() { + let a = Fp4E::new([FpE::from(0), FpE::from(1), FpE::from(2), FpE::from(3)]); + let b = Fp4E::new([-FpE::from(2), FpE::from(4), FpE::from(6), -FpE::from(8)]); + let expected_result = Fp4E::new([ + FpE::from(0) + FpE::from(2), + FpE::from(1) - FpE::from(4), + FpE::from(2) - FpE::from(6), + FpE::from(3) + FpE::from(8), + ]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_mul_by_0() { + let a = Fp4E::new([FpE::from(4), FpE::from(1), FpE::from(2), FpE::from(3)]); + let b = Fp4E::new([FpE::zero(), FpE::zero(), FpE::zero(), FpE::zero()]); + assert_eq!(&a * &b, b); + } + + #[test] + fn test_mul_by_1() { + let a = Fp4E::new([FpE::from(4), FpE::from(1), FpE::from(2), FpE::from(3)]); + let b = Fp4E::new([FpE::one(), FpE::zero(), FpE::zero(), FpE::zero()]); + assert_eq!(&a * b, a); + } + + #[test] + fn test_mul() { + let a = Fp4E::new([FpE::from(0), FpE::from(1), FpE::from(2), FpE::from(3)]); + let b = Fp4E::new([FpE::from(2), FpE::from(4), FpE::from(6), FpE::from(8)]); + let expected_result = Fp4E::new([ + -FpE::from(352), + -FpE::from(372), + -FpE::from(256), + FpE::from(20), + ]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_pow() { + let a = Fp4E::new([FpE::from(0), FpE::from(1), FpE::from(2), FpE::from(3)]); + let expected_result = &a * &a * &a; + assert_eq!(a.pow(3u64), expected_result); + } + + #[test] + fn test_inv_of_one_is_one() { + let a = Fp4E::one(); + assert_eq!(a.inv().unwrap(), a); + } + + #[test] + fn test_inv_of_zero_error() { + let result = Fp4E::zero().inv(); + assert!(result.is_err()); + } + + #[test] + fn test_mul_by_inv_is_identity() { + let a = Fp4E::from(123456); + assert_eq!(&a * a.inv().unwrap(), Fp4E::one()); + } + + #[test] + fn test_mul_as_subfield() { + let a = FpE::from(2); + let b = Fp4E::new([FpE::from(2), FpE::from(4), FpE::from(6), FpE::from(8)]); + let expected_result = Fp4E::new([ + FpE::from(2) * FpE::from(2), + FpE::from(4) * FpE::from(2), + FpE::from(6) * FpE::from(2), + FpE::from(8) * FpE::from(2), + ]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_double_equals_sum_two_times() { + let a = Fp4E::new([FpE::from(2), FpE::from(4), FpE::from(6), FpE::from(8)]); + + assert_eq!(a.double(), &a + &a); + } + + #[test] + fn test_mul_group_generator_pow_order_is_one() { + let generator = Fp4E::new([FpE::from(8), FpE::from(1), FpE::zero(), FpE::zero()]); + let extension_order: u128 = 2013265921_u128.pow(4); + assert_eq!(generator.pow(extension_order), generator); + } + + #[test] + fn test_two_adic_primitve_root_of_unity() { + let generator = Fp4E::new(Degree4BabyBearExtensionField::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY); + assert_eq!( + generator.pow(2u64.pow(Degree4BabyBearExtensionField::TWO_ADICITY as u32)), + Fp4E::one() + ); + } + + #[cfg(all(feature = "std", not(feature = "instruments")))] + mod test_babybear_31_fft { + use super::*; + #[cfg(not(any(feature = "metal", feature = "cuda")))] + use crate::fft::cpu::roots_of_unity::{ + get_powers_of_primitive_root, get_powers_of_primitive_root_coset, + }; + #[cfg(not(any(feature = "metal", feature = "cuda")))] + use crate::field::element::FieldElement; + #[cfg(not(any(feature = "metal", feature = "cuda")))] + use crate::field::traits::{IsFFTField, RootsConfig}; + use crate::polynomial::Polynomial; + use proptest::{collection, prelude::*, std_facade::Vec}; + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_evaluation( + poly: Polynomial>, + ) -> (Vec>, Vec>) { + let len = poly.coeff_len().next_power_of_two(); + let order = len.trailing_zeros(); + let twiddles = + get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap(); + + let fft_eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let naive_eval = poly.evaluate_slice(&twiddles); + + (fft_eval, naive_eval) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_coset_and_naive_evaluation( + poly: Polynomial>, + offset: FieldElement, + blowup_factor: usize, + ) -> (Vec>, Vec>) { + let len = poly.coeff_len().next_power_of_two(); + let order = (len * blowup_factor).trailing_zeros(); + let twiddles = + get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset) + .unwrap(); + + let fft_eval = + Polynomial::evaluate_offset_fft::(&poly, blowup_factor, None, &offset).unwrap(); + let naive_eval = poly.evaluate_slice(&twiddles); + + (fft_eval, naive_eval) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_interpolate( + fft_evals: &[FieldElement], + ) -> (Polynomial>, Polynomial>) { + let order = fft_evals.len().trailing_zeros() as u64; + let twiddles = + get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap(); + + let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_fft::(fft_evals).unwrap(); + + (fft_poly, naive_poly) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_coset_interpolate( + fft_evals: &[FieldElement], + offset: &FieldElement, + ) -> (Polynomial>, Polynomial>) { + let order = fft_evals.len().trailing_zeros() as u64; + let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap(); + + let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap(); + + (fft_poly, naive_poly) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_interpolate_and_evaluate( + poly: Polynomial>, + ) -> (Polynomial>, Polynomial>) { + let eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let new_poly = Polynomial::interpolate_fft::(&eval).unwrap(); + + (poly, new_poly) + } + + prop_compose! { + fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp } + // max_exp cannot be multiple of the bits that represent a usize, generally 64 or 32. + // also it can't exceed the test field's two-adicity. + } + prop_compose! { + fn field_element()(coeffs in [any::(); 4]) -> Fp4E { + Fp4E::new([ + FpE::from(coeffs[0]), + FpE::from(coeffs[1]), + FpE::from(coeffs[2]), + FpE::from(coeffs[3])] + ) + } + } + prop_compose! { + fn offset()(num in field_element(), factor in any::()) -> Fp4E { num.pow(factor) } + } + + prop_compose! { + fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec { + vec + } + } + prop_compose! { + fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1< Vec { + vec + } + } + prop_compose! { + fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial { + Polynomial::new(&coeffs) + } + } + prop_compose! { + fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial { + Polynomial::new(&coeffs) + } + } + + proptest! { + // Property-based test that ensures FFT eval. gives same result as a naive polynomial evaluation. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_matches_naive_evaluation(poly in poly(8)) { + let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly); + prop_assert_eq!(fft_eval, naive_eval); + } + + // Property-based test that ensures FFT eval. with coset gives same result as a naive polynomial evaluation. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) { + let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor); + prop_assert_eq!(fft_eval, naive_eval); + } + + // Property-based test that ensures FFT interpolation is the same as naive.. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4) + .prop_filter("Avoid polynomials of size not power of two", + |evals| evals.len().is_power_of_two())) { + let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals); + prop_assert_eq!(fft_poly, naive_poly); + } + + // Property-based test that ensures FFT interpolation with an offset is the same as naive. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4) + .prop_filter("Avoid polynomials of size not power of two", + |evals| evals.len().is_power_of_two())) { + let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset); + prop_assert_eq!(fft_poly, naive_poly); + } + + // Property-based test that ensures interpolation is the inverse operation of evaluation. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_is_inverse_of_evaluate( + poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) { + let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly); + prop_assert_eq!(poly, new_poly); + } + } + } +}