diff --git a/backend/src/plonky3/mod.rs b/backend/src/plonky3/mod.rs index 32a22e7383..a3fff30243 100644 --- a/backend/src/plonky3/mod.rs +++ b/backend/src/plonky3/mod.rs @@ -1,17 +1,59 @@ -use std::{io, path::PathBuf, sync::Arc}; +use std::{ + any::{Any, TypeId}, + io, + path::PathBuf, + sync::Arc, +}; use powdr_ast::analyzed::Analyzed; use powdr_executor::{ constant_evaluator::{get_uniquely_sized_cloned, VariablySizedColumn}, witgen::WitgenCallback, }; -use powdr_number::{FieldElement, GoldilocksField, LargeInt}; -use powdr_plonky3::Plonky3Prover; +use powdr_number::{BabyBearField, FieldElement, GoldilocksField}; +use powdr_plonky3::{Commitment, FieldElementMap, Plonky3Prover, ProverData}; use crate::{Backend, BackendFactory, BackendOptions, Error, Proof}; pub(crate) struct Factory; +fn try_create( + pil: &Arc>, + fixed: &Arc)>>, + verification_key: &mut Option<&mut dyn io::Read>, +) -> Option>> +where + ProverData: Send, + Commitment: Send, +{ + // We ensure that FInner and FOuter are the same type, so we can even safely + // transmute between them. + if TypeId::of::() != TypeId::of::() { + return None; + } + + let pil = (pil as &dyn Any) + .downcast_ref::>>() + .unwrap(); + let fixed = (fixed as &dyn Any) + .downcast_ref::)>>>() + .unwrap(); + + let mut p3 = Box::new(Plonky3Prover::new(pil.clone(), fixed.clone())); + + if let Some(verification_key) = verification_key { + p3.set_verifying_key(*verification_key); + } else { + p3.setup(); + } + + let p3: Box> = p3; + let p3 = Box::into_raw(p3); + + // This is safe because we know that FInner == FOuter. + Some(unsafe { Box::from_raw(p3 as *mut dyn Backend) }) +} + impl BackendFactory for Factory { fn create( &self, @@ -19,14 +61,10 @@ impl BackendFactory for Factory { fixed: Arc)>>, _output_dir: Option, setup: Option<&mut dyn io::Read>, - verification_key: Option<&mut dyn io::Read>, + mut verification_key: Option<&mut dyn io::Read>, verification_app_key: Option<&mut dyn io::Read>, _: BackendOptions, ) -> Result>, Error> { - if T::modulus().to_arbitrary_integer() != GoldilocksField::modulus().to_arbitrary_integer() - { - unimplemented!("plonky3 is only implemented for the Goldilocks field"); - } if setup.is_some() { return Err(Error::NoSetupAvailable); } @@ -41,19 +79,26 @@ impl BackendFactory for Factory { get_uniquely_sized_cloned(&fixed).map_err(|_| Error::NoVariableDegreeAvailable)?, ); - let mut p3 = Box::new(Plonky3Prover::new(pil, fixed)); - - if let Some(verification_key) = verification_key { - p3.set_verifying_key(verification_key); - } else { - p3.setup(); - } - - Ok(p3) + Ok( + if let Some(p3) = try_create::(&pil, &fixed, &mut verification_key) + { + p3 + } else if let Some(p3) = + try_create::(&pil, &fixed, &mut verification_key) + { + p3 + } else { + unimplemented!("unsupported field type: {:?}", TypeId::of::()) + }, + ) } } -impl Backend for Plonky3Prover { +impl Backend for Plonky3Prover +where + ProverData: Send, + Commitment: Send, +{ fn verify(&self, proof: &[u8], instances: &[Vec]) -> Result<(), Error> { Ok(self.verify(proof, instances)?) } diff --git a/number/Cargo.toml b/number/Cargo.toml index ce848f5fd2..7f74cf36b8 100644 --- a/number/Cargo.toml +++ b/number/Cargo.toml @@ -13,8 +13,8 @@ ark-bn254 = { version = "0.4.0", default-features = false, features = [ ] } ark-ff = "0.4.2" ark-serialize = "0.4.2" -p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "uni-stark-with-fixed" } -p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "uni-stark-with-fixed" } +p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } +p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } num-bigint = { version = "0.4.3", features = ["serde"] } num-traits = "0.2.15" csv = "1.3" diff --git a/number/src/baby_bear.rs b/number/src/baby_bear.rs index d1c4852869..cc79e1f219 100644 --- a/number/src/baby_bear.rs +++ b/number/src/baby_bear.rs @@ -48,6 +48,10 @@ impl BabyBearField { fn to_canonical_u32(self) -> u32 { self.0.as_canonical_u32() } + + pub fn into_inner(self) -> BabyBear { + self.0 + } } impl FieldElement for BabyBearField { diff --git a/pipeline/src/test_util.rs b/pipeline/src/test_util.rs index 90aa523ef7..78f8091b77 100644 --- a/pipeline/src/test_util.rs +++ b/pipeline/src/test_util.rs @@ -77,12 +77,6 @@ pub fn regular_test(file_name: &str, inputs: &[i32]) { pipeline_bb.compute_witness().unwrap(); } -pub fn regular_test_only_babybear(file_name: &str, inputs: &[i32]) { - let inputs_bb = inputs.iter().map(|x| BabyBearField::from(*x)).collect(); - let mut pipeline_bb = make_prepared_pipeline(file_name, inputs_bb, vec![]); - pipeline_bb.compute_witness().unwrap(); -} - pub fn regular_test_without_babybear(file_name: &str, inputs: &[i32]) { let inputs_gl = inputs.iter().map(|x| GoldilocksField::from(*x)).collect(); let pipeline_gl = make_prepared_pipeline(file_name, inputs_gl, vec![]); @@ -296,9 +290,9 @@ pub fn gen_halo2_proof(pipeline: Pipeline, backend: BackendVariant) pub fn gen_halo2_proof(_pipeline: Pipeline, _backend: BackendVariant) {} #[cfg(feature = "plonky3")] -pub fn test_plonky3_with_backend_variant( +pub fn test_plonky3_with_backend_variant( file_name: &str, - inputs: Vec, + inputs: Vec, backend: BackendVariant, ) { let backend = match backend { @@ -314,7 +308,7 @@ pub fn test_plonky3_with_backend_variant( // Generate a proof let proof = pipeline.compute_proof().cloned().unwrap(); - let publics: Vec = pipeline + let publics: Vec = pipeline .publics() .clone() .unwrap() @@ -341,10 +335,10 @@ pub fn test_plonky3_with_backend_variant( } #[cfg(not(feature = "plonky3"))] -pub fn test_plonky3_with_backend_variant(_: &str, _: Vec, _: BackendVariant) {} +pub fn test_plonky3_with_backend_variant(_: &str, _: Vec, _: BackendVariant) {} #[cfg(not(feature = "plonky3"))] -pub fn gen_plonky3_proof(_: &str, _: Vec) {} +pub fn gen_plonky3_proof(_: &str, _: Vec) {} /// Returns the analyzed PIL containing only the std library. pub fn std_analyzed() -> Analyzed { diff --git a/pipeline/tests/asm.rs b/pipeline/tests/asm.rs index 6ff2c8068c..d8efe23181 100644 --- a/pipeline/tests/asm.rs +++ b/pipeline/tests/asm.rs @@ -39,7 +39,11 @@ fn simple_sum_asm() { let f = "asm/simple_sum.asm"; let i = [16, 4, 1, 2, 8, 5]; regular_test(f, &i); - test_plonky3_with_backend_variant(f, slice_to_vec(&i), BackendVariant::Composite); + test_plonky3_with_backend_variant::( + f, + slice_to_vec(&i), + BackendVariant::Composite, + ); } #[test] diff --git a/pipeline/tests/pil.rs b/pipeline/tests/pil.rs index c21c93df51..3c50a0c0d6 100644 --- a/pipeline/tests/pil.rs +++ b/pipeline/tests/pil.rs @@ -87,7 +87,11 @@ fn permutation_with_selector() { fn fibonacci() { let f = "pil/fibonacci.pil"; regular_test(f, Default::default()); - test_plonky3_with_backend_variant(f, Default::default(), BackendVariant::Monolithic); + test_plonky3_with_backend_variant::( + f, + Default::default(), + BackendVariant::Monolithic, + ); } #[test] @@ -241,7 +245,11 @@ fn halo_without_lookup() { #[test] fn add() { let f = "pil/add.pil"; - test_plonky3_with_backend_variant(f, Default::default(), BackendVariant::Monolithic); + test_plonky3_with_backend_variant::( + f, + Default::default(), + BackendVariant::Monolithic, + ); } #[test] diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index c367899ccd..e77771cc59 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -6,8 +6,8 @@ use powdr_pil_analyzer::evaluator::Value; use powdr_pipeline::{ test_util::{ evaluate_function, evaluate_integer_function, execute_test_file, gen_estark_proof, - gen_halo2_proof, make_simple_prepared_pipeline, regular_test, regular_test_only_babybear, - std_analyzed, test_halo2, test_pilcom, BackendVariant, + gen_halo2_proof, make_simple_prepared_pipeline, regular_test, std_analyzed, test_halo2, + test_pilcom, test_plonky3_with_backend_variant, BackendVariant, }, Pipeline, }; @@ -186,14 +186,14 @@ fn binary_test() { #[ignore = "Too slow"] fn binary_bb_8_test() { let f = "std/binary_bb_test_8.asm"; - regular_test_only_babybear(f, &[]); + test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); } #[test] #[ignore = "Too slow"] fn binary_bb_16_test() { let f = "std/binary_bb_test_16.asm"; - regular_test_only_babybear(f, &[]); + test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Composite); } #[test] diff --git a/plonky3/Cargo.toml b/plonky3/Cargo.toml index 9734ba77f9..0f49b9b9f5 100644 --- a/plonky3/Cargo.toml +++ b/plonky3/Cargo.toml @@ -18,7 +18,7 @@ p3-matrix = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" p3-field = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-uni-stark = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-commit = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main", features = [ - "test-utils", + "test-utils", ] } p3-poseidon2 = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-poseidon = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } @@ -26,10 +26,11 @@ p3-fri = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } # We don't use p3-maybe-rayon directly, but it is a dependency of p3-uni-stark. # Activating the "parallel" feature gives us parallelism in the prover. p3-maybe-rayon = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main", features = [ - "parallel", + "parallel", ] } p3-mds = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-merkle-tree = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } +p3-baby-bear = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-goldilocks = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-symmetric = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } p3-dft = { git = "https://github.com/powdr-labs/Plonky3.git", branch = "main" } diff --git a/plonky3/src/baby_bear.rs b/plonky3/src/baby_bear.rs new file mode 100644 index 0000000000..bf6b0533eb --- /dev/null +++ b/plonky3/src/baby_bear.rs @@ -0,0 +1,109 @@ +//! The concrete parameters used in the prover +//! Inspired from [this example](https://github.com/Plonky3/Plonky3/blob/6a1b0710fdf85136d0fdd645b92933615867740a/keccak-air/examples/prove_baby_bear_poseidon2.rs) + +use lazy_static::lazy_static; + +use crate::params::{Challenger, FieldElementMap, Plonky3Field}; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear}; +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::{extension::BinomialExtensionField, Field}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral}; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::StarkConfig; + +use rand::{distributions::Standard, Rng, SeedableRng}; + +use powdr_number::BabyBearField; + +const D: u64 = 7; +// params directly taken from plonky3's poseidon2_round_numbers_128 function +// to guarentee 128-bit security. +const ROUNDS_F: usize = 8; +const ROUNDS_P: usize = 13; +const WIDTH: usize = 16; +type Perm = Poseidon2; + +const DEGREE: usize = 4; +type FriChallenge = BinomialExtensionField; + +const RATE: usize = 8; +const OUT: usize = 8; +type FriChallenger = DuplexChallenger; +type Hash = PaddingFreeSponge; + +const N: usize = 2; +const CHUNK: usize = 8; +type Compress = TruncatedPermutation; +const DIGEST_ELEMS: usize = 8; +type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + Hash, + Compress, + DIGEST_ELEMS, +>; + +type ChallengeMmcs = ExtensionMmcs; +type Dft = Radix2DitParallel; +type MyPcs = TwoAdicFriPcs; + +const FRI_LOG_BLOWUP: usize = 1; +const FRI_NUM_QUERIES: usize = 100; +const FRI_PROOF_OF_WORK_BITS: usize = 16; + +const RNG_SEED: u64 = 42; + +lazy_static! { + static ref PERM_BB: Perm = Perm::new( + ROUNDS_F, + rand_chacha::ChaCha8Rng::seed_from_u64(RNG_SEED) + .sample_iter(Standard) + .take(ROUNDS_F) + .collect::>(), + Poseidon2ExternalMatrixGeneral, + ROUNDS_P, + rand_chacha::ChaCha8Rng::seed_from_u64(RNG_SEED) + .sample_iter(Standard) + .take(ROUNDS_P) + .collect(), + DiffusionMatrixBabyBear::default() + ); +} + +impl FieldElementMap for BabyBearField { + type Config = StarkConfig; + fn into_p3_field(self) -> Plonky3Field { + self.into_inner() + } + + fn get_challenger() -> Challenger { + FriChallenger::new(PERM_BB.clone()) + } + + fn get_config() -> Self::Config { + let hash = Hash::new(PERM_BB.clone()); + + let compress = Compress::new(PERM_BB.clone()); + + let val_mmcs = ValMmcs::new(hash, compress); + + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + let dft = Dft {}; + + let fri_config = FriConfig { + log_blowup: FRI_LOG_BLOWUP, + num_queries: FRI_NUM_QUERIES, + proof_of_work_bits: FRI_PROOF_OF_WORK_BITS, + mmcs: challenge_mmcs, + }; + + let pcs = MyPcs::new(dft, val_mmcs, fri_config); + + Self::Config::new(pcs) + } +} diff --git a/plonky3/src/circuit_builder.rs b/plonky3/src/circuit_builder.rs index 9c47a905d0..5e9e7480f7 100644 --- a/plonky3/src/circuit_builder.rs +++ b/plonky3/src/circuit_builder.rs @@ -6,14 +6,10 @@ //! everywhere save for at row j is constructed to constrain s * (pub - x) on //! every row. -use std::{ - any::TypeId, - collections::{BTreeMap, HashSet}, -}; +use std::collections::{BTreeMap, HashSet}; +use crate::params::{Commitment, FieldElementMap, Plonky3Field, ProverData}; use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; -use p3_field::AbstractField; -use p3_goldilocks::Goldilocks; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use powdr_ast::analyzed::{ AlgebraicBinaryOperation, AlgebraicBinaryOperator, AlgebraicExpression, @@ -21,9 +17,7 @@ use powdr_ast::analyzed::{ PolynomialType, SelectedExpressions, }; use powdr_executor::witgen::WitgenCallback; -use powdr_number::{DegreeType, FieldElement, GoldilocksField, LargeInt}; - -pub type Val = p3_goldilocks::Goldilocks; +use powdr_number::{DegreeType, FieldElement}; /// A description of the constraint system. /// All of the data is derived from the analyzed PIL, but is materialized @@ -53,8 +47,11 @@ impl From<&Analyzed> for ConstraintSystem { } } } - -pub(crate) struct PowdrCircuit<'a, T> { +pub(crate) struct PowdrCircuit<'a, T: FieldElementMap> +where + ProverData: Send, + Commitment: Send, +{ /// The constraint system description constraint_system: ConstraintSystem, /// The value of the witness columns, if set @@ -63,11 +60,15 @@ pub(crate) struct PowdrCircuit<'a, T> { _witgen_callback: Option>, /// The matrix of preprocessed values, used in debug mode to check the constraints before proving #[cfg(debug_assertions)] - preprocessed: Option>, + preprocessed: Option>>, } -impl<'a, T: FieldElement> PowdrCircuit<'a, T> { - pub fn generate_trace_rows(&self) -> RowMajorMatrix { +impl<'a, T: FieldElementMap> PowdrCircuit<'a, T> +where + ProverData: Send, + Commitment: Send, +{ + pub fn generate_trace_rows(&self) -> RowMajorMatrix> { // an iterator over all columns, committed then fixed let witness = self.witness().iter(); let degrees = &self.constraint_system.degrees; @@ -81,7 +82,7 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { // witness values witness.clone().map(move |(_, v)| v[i as usize]) }) - .map(cast_to_goldilocks) + .map(|f| f.into_p3_field()) .collect() } 0 => { @@ -95,12 +96,11 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { } } -pub fn cast_to_goldilocks(v: T) -> Val { - assert_eq!(TypeId::of::(), TypeId::of::()); - Val::from_canonical_u64(v.to_integer().try_into_u64().unwrap()) -} - -impl<'a, T: FieldElement> PowdrCircuit<'a, T> { +impl<'a, T: FieldElementMap> PowdrCircuit<'a, T> +where + ProverData: Send, + Commitment: Send, +{ pub(crate) fn new(analyzed: &'a Analyzed) -> Self { if analyzed .definitions @@ -124,7 +124,7 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { } /// Calculates public values from generated witness values. - pub(crate) fn get_public_values(&self) -> Vec { + pub(crate) fn get_public_values(&self) -> Vec> { let witness = self .witness .as_ref() @@ -138,7 +138,7 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { .iter() .map(|(col_name, _, idx)| { let vals = *witness.get(&col_name).unwrap(); - cast_to_goldilocks(vals[*idx]) + vals[*idx].into_p3_field() }) .collect() } @@ -161,14 +161,14 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { #[cfg(debug_assertions)] pub(crate) fn with_preprocessed( mut self, - preprocessed_matrix: RowMajorMatrix, + preprocessed_matrix: RowMajorMatrix>, ) -> Self { self.preprocessed = Some(preprocessed_matrix); self } /// Conversion to plonky3 expression - fn to_plonky3_expr + AirBuilderWithPublicValues>( + fn to_plonky3_expr> + AirBuilderWithPublicValues>( &self, e: &AlgebraicExpression, main: &AB::M, @@ -205,7 +205,7 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { .get(id) .expect("Referenced public value does not exist")) .into(), - AlgebraicExpression::Number(n) => AB::Expr::from(cast_to_goldilocks(*n)), + AlgebraicExpression::Number(n) => AB::Expr::from(n.into_p3_field()), AlgebraicExpression::BinaryOperation(AlgebraicBinaryOperation { left, op, right }) => { let left = self.to_plonky3_expr::(left, main, fixed, publics); let right = self.to_plonky3_expr::(right, main, fixed, publics); @@ -237,7 +237,11 @@ impl<'a, T: FieldElement> PowdrCircuit<'a, T> { /// An extension of [Air] allowing access to the number of fixed columns -impl<'a, T: FieldElement> BaseAir for PowdrCircuit<'a, T> { +impl<'a, T: FieldElementMap> BaseAir> for PowdrCircuit<'a, T> +where + ProverData: Send, + Commitment: Send, +{ fn width(&self) -> usize { self.constraint_system.commitment_count } @@ -246,7 +250,7 @@ impl<'a, T: FieldElement> BaseAir for PowdrCircuit<'a, T> { self.constraint_system.constant_count + self.constraint_system.publics.len() } - fn preprocessed_trace(&self) -> Option> { + fn preprocessed_trace(&self) -> Option>> { #[cfg(debug_assertions)] { self.preprocessed.clone() @@ -256,8 +260,11 @@ impl<'a, T: FieldElement> BaseAir for PowdrCircuit<'a, T> { } } -impl<'a, T: FieldElement, AB: AirBuilderWithPublicValues + PairBuilder> Air - for PowdrCircuit<'a, T> +impl<'a, T: FieldElementMap, AB: AirBuilderWithPublicValues> + PairBuilder> + Air for PowdrCircuit<'a, T> +where + ProverData: Send, + Commitment: Send, { fn eval(&self, builder: &mut AB) { let main = builder.main(); diff --git a/plonky3/src/goldilocks.rs b/plonky3/src/goldilocks.rs new file mode 100644 index 0000000000..70bc98ad01 --- /dev/null +++ b/plonky3/src/goldilocks.rs @@ -0,0 +1,105 @@ +//! The concrete parameters used in the prover +//! Inspired from [this example](https://github.com/Plonky3/Plonky3/blob/6a1b0710fdf85136d0fdd645b92933615867740a/keccak-air/examples/prove_goldilocks_keccak.rs#L57) + +use lazy_static::lazy_static; + +use crate::params::{Challenger, FieldElementMap, Plonky3Field}; +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::{extension::BinomialExtensionField, AbstractField, Field}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon::Poseidon; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::StarkConfig; +use powdr_number::{FieldElement, GoldilocksField, LargeInt}; +use rand::{distributions::Standard, Rng, SeedableRng}; + +const DEGREE: usize = 2; +type FriChallenge = BinomialExtensionField; +const WIDTH: usize = 8; +const ALPHA: u64 = 7; +type Perm = Poseidon; + +const RATE: usize = 4; +const OUT: usize = 4; +type Hash = PaddingFreeSponge; + +const N: usize = 2; +const CHUNK: usize = 4; +type Compress = TruncatedPermutation; + +const DIGEST_ELEMS: usize = 4; +type ValMmcs = FieldMerkleTreeMmcs< + ::Packing, + ::Packing, + Hash, + Compress, + DIGEST_ELEMS, +>; + +pub type FriChallenger = DuplexChallenger; +type ChallengeMmcs = ExtensionMmcs; +type Dft = Radix2DitParallel; +type MyPcs = TwoAdicFriPcs; + +const HALF_NUM_FULL_ROUNDS: usize = 4; +const NUM_PARTIAL_ROUNDS: usize = 22; + +const FRI_LOG_BLOWUP: usize = 1; +const FRI_NUM_QUERIES: usize = 100; +const FRI_PROOF_OF_WORK_BITS: usize = 16; + +const NUM_ROUNDS: usize = 2 * HALF_NUM_FULL_ROUNDS + NUM_PARTIAL_ROUNDS; +const NUM_CONSTANTS: usize = WIDTH * NUM_ROUNDS; + +const RNG_SEED: u64 = 42; + +lazy_static! { + static ref PERM_GL: Perm = Perm::new( + HALF_NUM_FULL_ROUNDS, + NUM_PARTIAL_ROUNDS, + rand_chacha::ChaCha8Rng::seed_from_u64(RNG_SEED) + .sample_iter(Standard) + .take(NUM_CONSTANTS) + .collect(), + MdsMatrixGoldilocks, + ); +} + +impl FieldElementMap for GoldilocksField { + type Config = StarkConfig; + + fn into_p3_field(self) -> Plonky3Field { + Goldilocks::from_canonical_u64(self.to_integer().try_into_u64().unwrap()) + } + + fn get_challenger() -> Challenger { + FriChallenger::new(PERM_GL.clone()) + } + + fn get_config() -> Self::Config { + let hash = Hash::new(PERM_GL.clone()); + + let compress = Compress::new(PERM_GL.clone()); + + let val_mmcs = ValMmcs::new(hash, compress); + + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + let dft = Dft {}; + + let fri_config = FriConfig { + log_blowup: FRI_LOG_BLOWUP, + num_queries: FRI_NUM_QUERIES, + proof_of_work_bits: FRI_PROOF_OF_WORK_BITS, + mmcs: challenge_mmcs, + }; + + let pcs = MyPcs::new(dft, val_mmcs, fri_config); + + Self::Config::new(pcs) + } +} diff --git a/plonky3/src/lib.rs b/plonky3/src/lib.rs index 4bcb981bcf..364e741b91 100644 --- a/plonky3/src/lib.rs +++ b/plonky3/src/lib.rs @@ -1,5 +1,7 @@ +mod baby_bear; mod circuit_builder; +mod goldilocks; mod params; mod stark; - +pub use params::{Commitment, FieldElementMap, ProverData}; pub use stark::Plonky3Prover; diff --git a/plonky3/src/params.rs b/plonky3/src/params.rs index 8df2bd4053..ee34b80961 100644 --- a/plonky3/src/params.rs +++ b/plonky3/src/params.rs @@ -1,98 +1,29 @@ //! The concrete parameters used in the prover //! Inspired from [this example](https://github.com/Plonky3/Plonky3/blob/6a1b0710fdf85136d0fdd645b92933615867740a/keccak-air/examples/prove_goldilocks_poseidon.rs) -use lazy_static::lazy_static; +use p3_commit::PolynomialSpace; +use p3_uni_stark::StarkGenericConfig; +use powdr_number::FieldElement; -use p3_challenger::DuplexChallenger; -use p3_commit::ExtensionMmcs; -use p3_dft::Radix2DitParallel; -use p3_field::{extension::BinomialExtensionField, Field}; -use p3_fri::{FriConfig, TwoAdicFriPcs}; -use p3_goldilocks::MdsMatrixGoldilocks; -use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon::Poseidon; -use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; -use p3_uni_stark::StarkConfig; +pub type Plonky3Field = + < as p3_commit::Pcs, Challenger>>::Domain as PolynomialSpace>::Val; +pub type Pcs = <::Config as StarkGenericConfig>::Pcs; +pub type Challenge = <::Config as StarkGenericConfig>::Challenge; +pub type Challenger = <::Config as StarkGenericConfig>::Challenger; -use rand::{distributions::Standard, Rng, SeedableRng}; +pub type ProverData = as p3_commit::Pcs, Challenger>>::ProverData; +pub type Commitment = as p3_commit::Pcs, Challenger>>::Commitment; -use crate::circuit_builder::Val; +pub trait FieldElementMap: FieldElement +where + ProverData: Send, + Commitment: Send, +{ + type Config: StarkGenericConfig; -const D: usize = 2; -type Challenge = BinomialExtensionField; -const WIDTH: usize = 8; -const ALPHA: u64 = 7; -type Perm = Poseidon; + fn into_p3_field(self) -> Plonky3Field; -const RATE: usize = 4; -const OUT: usize = 4; -type Hash = PaddingFreeSponge; + fn get_challenger() -> Challenger; -const N: usize = 2; -const CHUNK: usize = 4; -type Compress = TruncatedPermutation; - -const DIGEST_ELEMS: usize = 4; -type ValMmcs = FieldMerkleTreeMmcs< - ::Packing, - ::Packing, - Hash, - Compress, - DIGEST_ELEMS, ->; -pub type Challenger = DuplexChallenger; -type ChallengeMmcs = ExtensionMmcs; -type Dft = Radix2DitParallel; -type MyPcs = TwoAdicFriPcs; -pub type Config = StarkConfig; - -const HALF_NUM_FULL_ROUNDS: usize = 4; -const NUM_PARTIAL_ROUNDS: usize = 22; - -const FRI_LOG_BLOWUP: usize = 1; -const FRI_NUM_QUERIES: usize = 100; -const FRI_PROOF_OF_WORK_BITS: usize = 16; - -const NUM_ROUNDS: usize = 2 * HALF_NUM_FULL_ROUNDS + NUM_PARTIAL_ROUNDS; -const NUM_CONSTANTS: usize = WIDTH * NUM_ROUNDS; - -const RNG_SEED: u64 = 42; - -lazy_static! { - static ref PERM: Perm = Perm::new( - HALF_NUM_FULL_ROUNDS, - NUM_PARTIAL_ROUNDS, - rand_chacha::ChaCha8Rng::seed_from_u64(RNG_SEED) - .sample_iter(Standard) - .take(NUM_CONSTANTS) - .collect(), - MdsMatrixGoldilocks, - ); -} - -pub fn get_challenger() -> Challenger { - Challenger::new(PERM.clone()) -} - -pub fn get_config() -> StarkConfig { - let hash = Hash::new(PERM.clone()); - - let compress = Compress::new(PERM.clone()); - - let val_mmcs = ValMmcs::new(hash, compress); - - let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); - - let dft = Dft {}; - - let fri_config = FriConfig { - log_blowup: FRI_LOG_BLOWUP, - num_queries: FRI_NUM_QUERIES, - proof_of_work_bits: FRI_PROOF_OF_WORK_BITS, - mmcs: challenge_mmcs, - }; - - let pcs = MyPcs::new(dft, val_mmcs, fri_config); - - Config::new(pcs) + fn get_config() -> Self::Config; } diff --git a/plonky3/src/stark.rs b/plonky3/src/stark.rs index d5ff6de73c..aa760aaf16 100644 --- a/plonky3/src/stark.rs +++ b/plonky3/src/stark.rs @@ -1,12 +1,10 @@ //! A plonky3 prover using FRI and Poseidon -use p3_goldilocks::Goldilocks; use p3_matrix::dense::RowMajorMatrix; use core::fmt; use std::sync::Arc; -use crate::params::Challenger; use powdr_ast::analyzed::Analyzed; use powdr_executor::witgen::WitgenCallback; @@ -14,21 +12,25 @@ use powdr_executor::witgen::WitgenCallback; use p3_uni_stark::{ prove_with_key, verify_with_key, Proof, StarkGenericConfig, StarkProvingKey, StarkVerifyingKey, }; -use powdr_number::{FieldElement, KnownField}; -use crate::circuit_builder::{cast_to_goldilocks, PowdrCircuit}; - -use crate::params::{get_challenger, get_config, Config}; +use crate::{ + circuit_builder::PowdrCircuit, + params::{Challenger, Commitment, FieldElementMap, Plonky3Field, ProverData}, +}; -pub struct Plonky3Prover { +pub struct Plonky3Prover +where + ProverData: Send, + Commitment: Send, +{ /// The analyzed PIL analyzed: Arc>, /// The value of the fixed columns fixed: Arc)>>, /// Proving key - proving_key: Option>, + proving_key: Option>, /// Verifying key - verifying_key: Option>, + verifying_key: Option>, } pub enum VerificationKeyExportError { @@ -43,7 +45,11 @@ impl fmt::Display for VerificationKeyExportError { } } -impl Plonky3Prover { +impl Plonky3Prover +where + ProverData: Send, + Commitment: Send, +{ pub fn new(analyzed: Arc>, fixed: Arc)>>) -> Self { Self { analyzed, @@ -68,7 +74,7 @@ impl Plonky3Prover { /// Returns preprocessed matrix based on the fixed inputs [`Plonky3Prover`]. /// This is used when running the setup phase - pub fn get_preprocessed_matrix(&self) -> RowMajorMatrix { + pub fn get_preprocessed_matrix(&self) -> RowMajorMatrix> { let publics = self .analyzed .get_publics() @@ -82,18 +88,18 @@ impl Plonky3Prover { .collect::)>>(); match self.fixed.len() + publics.len() { - 0 => RowMajorMatrix::new(Vec::::new(), 0), + 0 => RowMajorMatrix::new(Vec::>::new(), 0), _ => RowMajorMatrix::new( // write fixed row by row (0..self.analyzed.degree()) .flat_map(|i| { self.fixed .iter() - .map(move |(_, values)| cast_to_goldilocks(values[i as usize])) + .map(move |(_, values)| values[i as usize].into_p3_field()) .chain( publics .iter() - .map(move |(_, values)| cast_to_goldilocks(values[i as usize])), + .map(move |(_, values)| values[i as usize].into_p3_field()), ) }) .collect(), @@ -103,7 +109,11 @@ impl Plonky3Prover { } } -impl Plonky3Prover { +impl Plonky3Prover +where + ProverData: Send, + Commitment: Send, +{ pub fn setup(&mut self) { // get fixed columns let fixed = &self.fixed; @@ -126,11 +136,11 @@ impl Plonky3Prover { } // get the config - let config = get_config(); + let config = T::get_config(); // commit to the fixed columns let pcs = config.pcs(); - let domain = <_ as p3_commit::Pcs<_, Challenger>>::natural_domain_for_degree( + let domain = <_ as p3_commit::Pcs<_, Challenger>>::natural_domain_for_degree( pcs, self.analyzed.degree() as usize, ); @@ -141,7 +151,7 @@ impl Plonky3Prover { fixed .iter() .chain(publics.iter()) - .map(move |(_, values)| cast_to_goldilocks(values[i as usize])) + .map(move |(_, values)| values[i as usize].into_p3_field()) }) .collect(), self.fixed.len() + publics.len(), @@ -151,10 +161,10 @@ impl Plonky3Prover { // commit to the evaluations let (fixed_commit, fixed_data) = - <_ as p3_commit::Pcs<_, Challenger>>::commit(pcs, evaluations); + <_ as p3_commit::Pcs<_, Challenger>>::commit(pcs, evaluations); let proving_key = StarkProvingKey { - preprocessed_commit: fixed_commit, + preprocessed_commit: fixed_commit.clone(), preprocessed_data: fixed_data, }; let verifying_key = StarkVerifyingKey { @@ -170,9 +180,7 @@ impl Plonky3Prover { witness: &[(String, Vec)], witgen_callback: WitgenCallback, ) -> Result, String> { - assert_eq!(T::known_field(), Some(KnownField::GoldilocksField)); - - let circuit = PowdrCircuit::new(&self.analyzed) + let circuit: PowdrCircuit = PowdrCircuit::new(&self.analyzed) .with_witgen_callback(witgen_callback) .with_witness(witness); @@ -183,9 +191,9 @@ impl Plonky3Prover { let trace = circuit.generate_trace_rows(); - let config = get_config(); + let config = T::get_config(); - let mut challenger = get_challenger(); + let mut challenger = T::get_challenger(); let proving_key = self.proving_key.as_ref(); @@ -198,7 +206,7 @@ impl Plonky3Prover { &publics, ); - let mut challenger = get_challenger(); + let mut challenger = T::get_challenger(); let verifying_key = self.verifying_key.as_ref(); @@ -220,12 +228,12 @@ impl Plonky3Prover { let publics = instances .iter() .flatten() - .map(|v| cast_to_goldilocks(*v)) + .map(|v| v.into_p3_field()) .collect(); - let config = get_config(); + let config = T::get_config(); - let mut challenger = get_challenger(); + let mut challenger = T::get_challenger(); let verifying_key = self.verifying_key.as_ref(); @@ -246,7 +254,7 @@ mod tests { use std::sync::Arc; use powdr_executor::constant_evaluator::get_uniquely_sized_cloned; - use powdr_number::GoldilocksField; + use powdr_number::{BabyBearField, GoldilocksField}; use powdr_pipeline::Pipeline; use test_log::test; @@ -277,10 +285,47 @@ mod tests { } } + fn run_test_baby_bear(pil: &str) { + run_test_baby_bear_publics(pil, None) + } + + fn run_test_baby_bear_publics(pil: &str, malicious_publics: Option>) { + let mut pipeline = Pipeline::::default().from_pil_string(pil.to_string()); + + let pil = pipeline.compute_optimized_pil().unwrap(); + let witness_callback = pipeline.witgen_callback().unwrap(); + let witness = pipeline.compute_witness().unwrap(); + let fixed = pipeline.compute_fixed_cols().unwrap(); + let fixed = Arc::new(get_uniquely_sized_cloned(&fixed).unwrap()); + + let mut prover = Plonky3Prover::new(pil, fixed); + prover.setup(); + let proof = prover.prove(&witness, witness_callback); + + assert!(proof.is_ok()); + + if let Some(publics) = malicious_publics { + prover.verify(&proof.unwrap(), &[publics]).unwrap() + } + } + + #[test] + fn add_baby_bear() { + let content = r#" + namespace Add(8); + col witness x; + col witness y; + col witness z; + x + y = z; + "#; + run_test_baby_bear(content); + } + #[test] fn public_values() { let content = "namespace Global(8); pol witness x; x * (x - 1) = 0; public out = x(7);"; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -295,6 +340,7 @@ mod tests { y = 1 + :oldstate; "#; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -311,8 +357,11 @@ mod tests { public outz = z(7); "#; - let malicious_publics = Some(vec![GoldilocksField::from(0)]); - run_test_goldilocks_publics(content, malicious_publics); + let gl_malicious_publics = Some(vec![GoldilocksField::from(0)]); + run_test_goldilocks_publics(content, gl_malicious_publics); + + let bb_malicious_publics = Some(vec![BabyBearField::from(0)]); + run_test_baby_bear_publics(content, bb_malicious_publics); } #[test] @@ -320,6 +369,7 @@ mod tests { fn empty() { let content = "namespace Global(8);"; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -332,6 +382,7 @@ mod tests { x + y = z; "#; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -343,6 +394,7 @@ mod tests { x * y = y; "#; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -358,12 +410,14 @@ mod tests { x = y + beta; "#; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] fn polynomial_identity() { let content = "namespace Global(8); pol fixed z = [1, 2]*; pol witness a; a = z + 1;"; run_test_goldilocks(content); + run_test_baby_bear(content); } #[test] @@ -371,5 +425,6 @@ mod tests { fn lookup() { let content = "namespace Global(8); pol fixed z = [0, 1]*; pol witness a; [a] in [z];"; run_test_goldilocks(content); + run_test_baby_bear(content); } }