From 03d37e8638545c1573045bc80db7567ec786d685 Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Mon, 25 Nov 2024 15:08:43 +0100 Subject: [PATCH 1/4] refactor: moved constraint parsing of groth16 zkey to rayon --- co-circom/circom-types/src/groth16/zkey.rs | 112 ++++++++++++--------- 1 file changed, 64 insertions(+), 48 deletions(-) diff --git a/co-circom/circom-types/src/groth16/zkey.rs b/co-circom/circom-types/src/groth16/zkey.rs index bfbbb11d7..d3978a5f7 100644 --- a/co-circom/circom-types/src/groth16/zkey.rs +++ b/co-circom/circom-types/src/groth16/zkey.rs @@ -33,6 +33,8 @@ use ark_serialize::CanonicalDeserialize; use std::io::Read; +use rayon::prelude::*; + use crate::{ binfile::{BinFile, ZKeyParserError, ZKeyParserResult}, traits::{CheckElement, CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}, @@ -123,11 +125,12 @@ where let header = HeaderGroth::
<P>
::read(&mut binfile.take_section(2), check)?; let n_vars = header.n_vars; let n_public = header.n_public; - let domain_size = header.domain_size; + let domain_size = usize::try_from(header.domain_size).expect("fits into usize"); // parse proving key let ic_section = binfile.take_section(3); + let matrices_section = binfile.take_section(4); let a_section = binfile.take_section(5); let b_g1_section = binfile.take_section(6); let b_g2_section = binfile.take_section(7); @@ -140,6 +143,7 @@ where let mut b_g2_query = None; let mut l_query = None; let mut h_query = None; + let mut matrices = None; tracing::debug!("parsing zkey sections with rayon..."); rayon::scope(|s| { @@ -149,55 +153,17 @@ where s.spawn(|_| b_g2_query = Some(Self::b_g2_query(n_vars, b_g2_section, check))); s.spawn(|_| l_query = Some(Self::l_query(n_vars - n_public - 1, l_section, check))); s.spawn(|_| h_query = Some(Self::h_query(domain_size as usize, h_section, check))); + s.spawn(|_| { + matrices = Some(Self::constraint_matrices( + domain_size, + n_public, + n_vars, + matrices_section, + )) + }); }); tracing::debug!("we are done with parsing sections!"); - // parse matrices - - tracing::debug!("reading matrices..."); - let mut matrices_section = binfile.take_section(4); - - // this function (an all following uses) assumes that values are encoded in little-endian - let num_coeffs = u32::deserialize_uncompressed(&mut matrices_section)?; - - // instantiate AB - let mut matrices = vec![vec![vec![]; domain_size as usize]; 2]; - let mut max_constraint_index = 0; - for _ in 0..num_coeffs { - let matrix = u32::deserialize_uncompressed(&mut matrices_section)?; - let constraint = u32::deserialize_uncompressed(&mut matrices_section)?; - let signal = u32::deserialize_uncompressed(&mut matrices_section)?; - - let value = P::ScalarField::from_reader_for_groth16_zkey(&mut matrices_section)?; - max_constraint_index = std::cmp::max(max_constraint_index, constraint); - matrices[matrix as usize][constraint as usize].push((value, signal as usize)); - } - - let num_constraints = max_constraint_index as usize - n_public; - // Remove the public input constraints, Arkworks adds them later - matrices.iter_mut().for_each(|m| { - m.truncate(num_constraints); - }); - - // This is taken from Arkworks' to_matrices() function - let a = matrices[0].clone(); - let b = matrices[1].clone(); - let a_num_non_zero: usize = a.iter().map(|lc| lc.len()).sum(); - let b_num_non_zero: usize = b.iter().map(|lc| lc.len()).sum(); - - let matrices = ConstraintMatrices { - num_instance_variables: n_public + 1, - num_witness_variables: n_vars - n_public, - num_constraints, - - a_num_non_zero, - b_num_non_zero, - c_num_non_zero: 0, - - a, - b, - c: vec![], - }; // this thread automatically joins on the rayon scope, therefore we can // only be here if the scope finished. let vk = VerifyingKey { @@ -220,7 +186,7 @@ where b_g2_query: b_g2_query.unwrap()?, h_query: h_query.unwrap()?, l_query: l_query.unwrap()?, - matrices, + matrices: matrices.unwrap()?, vk, }) } @@ -273,6 +239,56 @@ where ) -> ZKeyParserResult> { Ok(P::g1_vec_from_reader(reader, n_vars, check)?) 
} + + fn constraint_matrices( + domain_size: usize, + n_public: usize, + n_vars: usize, + mut matrices_section: R, + ) -> ZKeyParserResult> { + // this function (an all following uses) assumes that values are encoded in little-endian + let num_coeffs = u32::deserialize_uncompressed(&mut matrices_section)?; + + // instantiate AB + let mut matrices = vec![vec![vec![]; domain_size]; 2]; + let mut max_constraint_index = 0; + for _ in 0..num_coeffs { + let matrix = u32::deserialize_uncompressed(&mut matrices_section)?; + let constraint = u32::deserialize_uncompressed(&mut matrices_section)?; + let signal = u32::deserialize_uncompressed(&mut matrices_section)?; + + let value = P::ScalarField::from_reader_for_groth16_zkey(&mut matrices_section)?; + max_constraint_index = std::cmp::max(max_constraint_index, constraint); + matrices[matrix as usize][constraint as usize].push((value, signal as usize)); + } + + let num_constraints = max_constraint_index as usize - n_public; + // Remove the public input constraints, Arkworks adds them later + matrices.iter_mut().for_each(|m| { + m.truncate(num_constraints); + }); + + // This is taken from Arkworks' to_matrices() function + let a = matrices[0].clone(); + let b = matrices[1].clone(); + let a_num_non_zero: usize = a.par_iter().map(|lc| lc.len()).sum(); + let b_num_non_zero: usize = b.par_iter().map(|lc| lc.len()).sum(); + + let matrices = ConstraintMatrices { + num_instance_variables: n_public + 1, + num_witness_variables: n_vars - n_public, + num_constraints, + + a_num_non_zero, + b_num_non_zero, + c_num_non_zero: 0, + + a, + b, + c: vec![], + }; + Ok(matrices) + } } impl HeaderGroth
<P>
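This patch moves the constraint-matrix parsing into the same rayon::scope that already deserializes the other proving-key sections, so every section is parsed in parallel and the scope joins before the results are unwrapped. A minimal, self-contained sketch of that pattern follows; the parser names and element types are placeholders, not the actual zkey API:

// Illustrative sketch of the pattern this patch extends (placeholder parsers,
// not the real zkey section readers): each section is parsed on the rayon
// thread pool, the results land in Options, and rayon::scope joins all
// spawned tasks before the Options are unwrapped.
fn parse_sections_in_parallel() -> Result<(Vec<u64>, Vec<u64>), String> {
    let mut a_query = None;
    let mut constraint_matrices = None;
    rayon::scope(|s| {
        s.spawn(|_| a_query = Some(parse_a_query()));
        s.spawn(|_| constraint_matrices = Some(parse_constraint_matrices()));
    });
    // the scope only returns once every spawned task finished, so both Options are Some here
    Ok((a_query.unwrap()?, constraint_matrices.unwrap()?))
}

// hypothetical stand-ins for the real section parsers
fn parse_a_query() -> Result<Vec<u64>, String> {
    Ok(vec![1, 2, 3])
}
fn parse_constraint_matrices() -> Result<Vec<u64>, String> {
    Ok(vec![4, 5, 6])
}
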
From 7ad975f0f78d079e5722b139698903234a281f66 Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Mon, 25 Nov 2024 15:16:03 +0100 Subject: [PATCH 2/4] refactor!: Removed unnecessary parts of the zkey --- co-circom/circom-types/src/groth16/zkey.rs | 77 +++++++--------------- co-circom/co-groth16/src/groth16.rs | 6 +- 2 files changed, 27 insertions(+), 56 deletions(-) diff --git a/co-circom/circom-types/src/groth16/zkey.rs b/co-circom/circom-types/src/groth16/zkey.rs index d3978a5f7..8454cee53 100644 --- a/co-circom/circom-types/src/groth16/zkey.rs +++ b/co-circom/circom-types/src/groth16/zkey.rs @@ -52,8 +52,6 @@ pub struct ZKey { pub n_public: usize, /// domain size pub pow: usize, - /// The underlying verification key. - pub vk: VerifyingKey
<P>
, /// beta pub beta_g1: P::G1Affine, /// delta @@ -68,24 +66,14 @@ pub struct ZKey { pub h_query: Vec, /// l_query pub l_query: Vec, - /// The constraint matrices A, B, and C - pub matrices: ConstraintMatrices, -} - -/// The verifying key encapsulated in the zkey. This is NOT the key used for verifying (although it has the same values). -/// You most likely are looking for the [`JsonVerificationKey`](crate::groth16::verification_key::JsonVerificationKey). -#[derive(Default, Clone, Debug)] -pub struct VerifyingKey { - /// alpha + /// alpha_g1 pub alpha_g1: P::G1Affine, - /// beta + /// beta_g1 pub beta_g2: P::G2Affine, - /// gamma - pub gamma_g2: P::G2Affine, - /// delta + /// delta_g1 pub delta_g2: P::G2Affine, - /// delta - pub gamma_abc_g1: Vec, + /// The constraint matrices A, B, and C + pub matrices: ConstraintMatrices, } #[derive(Clone, Debug)] @@ -97,7 +85,6 @@ struct HeaderGroth { alpha_g1: P::G1Affine, beta_g1: P::G1Affine, beta_g2: P::G2Affine, - gamma_g2: P::G2Affine, delta_g1: P::G1Affine, delta_g2: P::G2Affine, } @@ -129,7 +116,6 @@ where // parse proving key - let ic_section = binfile.take_section(3); let matrices_section = binfile.take_section(4); let a_section = binfile.take_section(5); let b_g1_section = binfile.take_section(6); @@ -137,7 +123,6 @@ where let l_section = binfile.take_section(8); let h_section = binfile.take_section(9); - let mut ic = None; let mut a_query = None; let mut b_g1_query = None; let mut b_g2_query = None; @@ -147,7 +132,6 @@ where tracing::debug!("parsing zkey sections with rayon..."); rayon::scope(|s| { - s.spawn(|_| ic = Some(Self::ic(n_public, ic_section, check))); s.spawn(|_| a_query = Some(Self::a_query(n_vars, a_section, check))); s.spawn(|_| b_g1_query = Some(Self::b_g1_query(n_vars, b_g1_section, check))); s.spawn(|_| b_g2_query = Some(Self::b_g2_query(n_vars, b_g2_section, check))); @@ -166,14 +150,14 @@ where // this thread automatically joins on the rayon scope, therefore we can // only be here if the scope finished. - let vk = VerifyingKey { - alpha_g1: header.alpha_g1, - beta_g2: header.beta_g2, - gamma_g2: header.gamma_g2, - delta_g2: header.delta_g2, - // unwrap is fine, because we are guaranteed to have a Some value (rayon scope) - gamma_abc_g1: ic.unwrap()?, - }; + //let vk = VerifyingKey { + // alpha_g1: header.alpha_g1, + // beta_g2: header.beta_g2, + // gamma_g2: header.gamma_g2, + // delta_g2: header.delta_g2, + // // unwrap is fine, because we are guaranteed to have a Some value (rayon scope) + // gamma_abc_g1: ic.unwrap()?, + //}; tracing::debug!("groth16 zkey parsing done!"); Ok(ZKey { n_public: header.n_public, @@ -186,20 +170,13 @@ where b_g2_query: b_g2_query.unwrap()?, h_query: h_query.unwrap()?, l_query: l_query.unwrap()?, + alpha_g1: header.alpha_g1, + beta_g2: header.beta_g2, + delta_g2: header.delta_g2, matrices: matrices.unwrap()?, - vk, }) } - fn ic( - n_public: usize, - reader: R, - check: CheckElement, - ) -> ZKeyParserResult> { - // the range is non-inclusive so we do +1 to get all inputs - Ok(P::g1_vec_from_reader(reader, n_public + 1, check)?) 
- } - fn a_query( n_vars: usize, reader: R, @@ -332,7 +309,8 @@ where let alpha_g1 = P::g1_from_reader(&mut reader, check)?; let beta_g1 = P::g1_from_reader(&mut reader, check)?; let beta_g2 = P::g2_from_reader(&mut reader, check)?; - let gamma_g2 = P::g2_from_reader(&mut reader, check)?; + // we don't need this element but we need to read it anyways + let _ = P::g2_from_reader(&mut reader, check)?; let delta_g1 = P::g1_from_reader(&mut reader, check)?; let delta_g2 = P::g2_from_reader(&mut reader, check)?; tracing::debug!("read header done!"); @@ -344,7 +322,6 @@ where alpha_g1, beta_g1, beta_g2, - gamma_g2, delta_g1, delta_g2, }) @@ -457,7 +434,6 @@ mod tests { assert_eq!(b_g2_query, *pk.b_g2_query); assert_eq!(h_query, pk.h_query); assert_eq!(l_query, pk.l_query); - let vk = pk.vk; let alpha_g1 = test_utils::to_g1_bls12_381!( "573513743870798705896078935465463988747193691665514373553428213826028808426481266659437596949247877550493216010640", "3195692015363680281472407569911592878057544540747596023043039898101401350267601241530895953964131482377769738361054" @@ -485,11 +461,9 @@ mod tests { "1374573688907712469603830822734104311026384172354584262904362700919219617284680686401889337872942140366529825919103" ), ]; - assert_eq!(alpha_g1, vk.alpha_g1); - assert_eq!(beta_g2, vk.beta_g2); - assert_eq!(gamma_g2, vk.gamma_g2); - assert_eq!(delta_g2, vk.delta_g2); - assert_eq!(gamma_abc_g1, vk.gamma_abc_g1); + assert_eq!(alpha_g1, pk.alpha_g1); + assert_eq!(beta_g2, pk.beta_g2); + assert_eq!(delta_g2, pk.delta_g2); } } @@ -577,7 +551,6 @@ mod tests { assert_eq!(b_g2_query, *pk.b_g2_query); assert_eq!(h_query, pk.h_query); assert_eq!(l_query, pk.l_query); - let vk = pk.vk; let alpha_g1 = test_utils::to_g1_bn254!( "16899422092493380665487369855810985762968608626455123789954325961085508316984", @@ -606,11 +579,9 @@ mod tests { "10737415594461993507153866894812637432840367562913937920244709428556226500845" ), ]; - assert_eq!(alpha_g1, vk.alpha_g1); - assert_eq!(beta_g2, vk.beta_g2); - assert_eq!(gamma_g2, vk.gamma_g2); - assert_eq!(delta_g2, vk.delta_g2); - assert_eq!(gamma_abc_g1, vk.gamma_abc_g1); + assert_eq!(alpha_g1, pk.alpha_g1); + assert_eq!(beta_g2, pk.beta_g2); + assert_eq!(delta_g2, pk.delta_g2); let a = vec![vec![( ark_bn254::Fr::from_str( diff --git a/co-circom/co-groth16/src/groth16.rs b/co-circom/co-groth16/src/groth16.rs index 33927c3f6..1ee82b2f5 100644 --- a/co-circom/co-groth16/src/groth16.rs +++ b/co-circom/co-groth16/src/groth16.rs @@ -351,10 +351,10 @@ where let aux_assignment2 = Arc::clone(&aux_assignment); let aux_assignment3 = Arc::clone(&aux_assignment); let aux_assignment4 = Arc::clone(&aux_assignment); - let alpha_g1 = zkey.vk.alpha_g1; + let alpha_g1 = zkey.alpha_g1; let beta_g1 = zkey.beta_g1; - let beta_g2 = zkey.vk.beta_g2; - let delta_g2 = zkey.vk.delta_g2.into_group(); + let beta_g2 = zkey.beta_g2; + let delta_g2 = zkey.delta_g2.into_group(); rayon::spawn(move || { let compute_a = From 667f0037aeb15a3f779a41365caa74bdb429b290 Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Tue, 26 Nov 2024 12:41:12 +0100 Subject: [PATCH 3/4] refactor!: Removed ark_relations deps. 
Also changed verify impls to not return bool but a common error BREAKING CHANGE: Now the verify impls from groth16/plonk circom return an error indicating whether it was a success or not --- Cargo.toml | 1 - co-circom/circom-types/Cargo.toml | 1 - co-circom/circom-types/src/groth16/mod.rs | 1 + co-circom/circom-types/src/groth16/zkey.rs | 101 +++++---------------- co-circom/co-circom-snarks/src/lib.rs | 39 ++++++++ co-circom/co-circom/src/bin/co-circom.rs | 26 +++--- co-circom/co-groth16/Cargo.toml | 1 - co-circom/co-groth16/src/groth16.rs | 31 ++----- co-circom/co-groth16/src/lib.rs | 36 ++------ co-circom/co-groth16/src/verifier.rs | 11 ++- co-circom/co-plonk/src/lib.rs | 6 +- co-circom/co-plonk/src/plonk.rs | 28 +++--- tests/tests/circom/e2e_tests/rep3.rs | 4 - tests/tests/circom/e2e_tests/shamir.rs | 4 - 14 files changed, 123 insertions(+), 167 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 00ee177c3..ee6b1766b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,6 @@ ark-bn254 = "0.4.0" ark-ec = { version = "0.4.2", default-features = false } ark-ff = "0.4.2" ark-poly = "0.4.2" -ark-relations = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4", features = ["derive", "std"] } ark-std = { version = "0.4.0", features = ["std"] } bincode = "1.3.3" diff --git a/co-circom/circom-types/Cargo.toml b/co-circom/circom-types/Cargo.toml index 02d50bb13..e10baf43d 100644 --- a/co-circom/circom-types/Cargo.toml +++ b/co-circom/circom-types/Cargo.toml @@ -17,7 +17,6 @@ ark-bn254 = { workspace = true } ark-ec = { workspace = true } ark-ff = { workspace = true } ark-poly = { workspace = true } -ark-relations = { workspace = true } ark-serialize = { workspace = true } ark-std = { workspace = true } byteorder = { workspace = true } diff --git a/co-circom/circom-types/src/groth16/mod.rs b/co-circom/circom-types/src/groth16/mod.rs index 98aad2eee..de8074f9c 100644 --- a/co-circom/circom-types/src/groth16/mod.rs +++ b/co-circom/circom-types/src/groth16/mod.rs @@ -7,6 +7,7 @@ mod zkey; pub use proof::Groth16Proof; pub use public_input::JsonPublicInput; pub use verification_key::JsonVerificationKey; +pub use zkey::ConstraintMatrix; pub use zkey::ZKey; #[cfg(test)] diff --git a/co-circom/circom-types/src/groth16/zkey.rs b/co-circom/circom-types/src/groth16/zkey.rs index 8454cee53..c6dc9d814 100644 --- a/co-circom/circom-types/src/groth16/zkey.rs +++ b/co-circom/circom-types/src/groth16/zkey.rs @@ -28,13 +28,10 @@ //! Inspired by use ark_ec::pairing::Pairing; use ark_ff::PrimeField; -use ark_relations::r1cs::ConstraintMatrices; use ark_serialize::CanonicalDeserialize; use std::io::Read; -use rayon::prelude::*; - use crate::{ binfile::{BinFile, ZKeyParserError, ZKeyParserResult}, traits::{CheckElement, CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}, @@ -52,6 +49,8 @@ pub struct ZKey { pub n_public: usize, /// domain size pub pow: usize, + /// the amount of constraints + pub num_constraints: usize, /// beta pub beta_g1: P::G1Affine, /// delta @@ -72,10 +71,16 @@ pub struct ZKey { pub beta_g2: P::G2Affine, /// delta_g1 pub delta_g2: P::G2Affine, - /// The constraint matrices A, B, and C - pub matrices: ConstraintMatrices, + /// The constraint matrices A + pub a_matrix: ConstraintMatrix, + /// The constraint matrices B + pub b_matrix: ConstraintMatrix, } +/// A constraint matrix used in Groth16. 
+pub type ConstraintMatrix = Vec>; +type ConstraintMatrixAB = (usize, ConstraintMatrix, ConstraintMatrix); + #[derive(Clone, Debug)] struct HeaderGroth { n_vars: usize, @@ -136,32 +141,24 @@ where s.spawn(|_| b_g1_query = Some(Self::b_g1_query(n_vars, b_g1_section, check))); s.spawn(|_| b_g2_query = Some(Self::b_g2_query(n_vars, b_g2_section, check))); s.spawn(|_| l_query = Some(Self::l_query(n_vars - n_public - 1, l_section, check))); - s.spawn(|_| h_query = Some(Self::h_query(domain_size as usize, h_section, check))); + s.spawn(|_| h_query = Some(Self::h_query(domain_size, h_section, check))); s.spawn(|_| { matrices = Some(Self::constraint_matrices( domain_size, n_public, - n_vars, matrices_section, )) }); }); - tracing::debug!("we are done with parsing sections!"); + let (num_constraints, a_matrix, b_matrix) = matrices.unwrap()?; // this thread automatically joins on the rayon scope, therefore we can // only be here if the scope finished. - //let vk = VerifyingKey { - // alpha_g1: header.alpha_g1, - // beta_g2: header.beta_g2, - // gamma_g2: header.gamma_g2, - // delta_g2: header.delta_g2, - // // unwrap is fine, because we are guaranteed to have a Some value (rayon scope) - // gamma_abc_g1: ic.unwrap()?, - //}; tracing::debug!("groth16 zkey parsing done!"); Ok(ZKey { n_public: header.n_public, pow: u32_to_usize!(header.pow), + num_constraints, beta_g1: header.beta_g1, delta_g1: header.delta_g1, // unwrap is fine, because we are guaranteed to have a Some value (rayon scope) @@ -173,7 +170,8 @@ where alpha_g1: header.alpha_g1, beta_g2: header.beta_g2, delta_g2: header.delta_g2, - matrices: matrices.unwrap()?, + a_matrix, + b_matrix, }) } @@ -220,14 +218,15 @@ where fn constraint_matrices( domain_size: usize, n_public: usize, - n_vars: usize, mut matrices_section: R, - ) -> ZKeyParserResult> { + ) -> ZKeyParserResult> { // this function (an all following uses) assumes that values are encoded in little-endian let num_coeffs = u32::deserialize_uncompressed(&mut matrices_section)?; // instantiate AB - let mut matrices = vec![vec![vec![]; domain_size]; 2]; + let a = vec![vec![]; domain_size]; + let b = vec![vec![]; domain_size]; + let mut matrices = [a, b]; let mut max_constraint_index = 0; for _ in 0..num_coeffs { let matrix = u32::deserialize_uncompressed(&mut matrices_section)?; @@ -245,26 +244,8 @@ where m.truncate(num_constraints); }); - // This is taken from Arkworks' to_matrices() function - let a = matrices[0].clone(); - let b = matrices[1].clone(); - let a_num_non_zero: usize = a.par_iter().map(|lc| lc.len()).sum(); - let b_num_non_zero: usize = b.par_iter().map(|lc| lc.len()).sum(); - - let matrices = ConstraintMatrices { - num_instance_variables: n_public + 1, - num_witness_variables: n_vars - n_public, - num_constraints, - - a_num_non_zero, - b_num_non_zero, - c_num_non_zero: 0, - - a, - b, - c: vec![], - }; - Ok(matrices) + let [a, b] = matrices; + Ok((num_constraints, a, b)) } } @@ -443,24 +424,10 @@ mod tests { { "1213509159032791114787919253810063723698125343911375817823407964507894154588429618034348468252648939670896208579873", "1573371412929811557753878280884507253544333246060733954030366147593600651713802914366664802456680232238300886611563"}, { "227372997676533734391726211114649274508389438640619116602997243907961458158899171192162581346407208971296972028627", "3173649281634920042594077931157174670855523098488107297282865037955359011267273317056899941445467620214571651786849"} ); - let gamma_g2 = test_utils::to_g2_bls12_381!( - { 
"352701069587466618187139116011060144890029952792775240219908644239793785735715026873347600343865175952761926303160", "3059144344244213709971259814753781636986470325476647558659373206291635324768958432433509563104347017837885763365758"}, - { "1985150602287291935568054521177171638300868978215655730859378665066344726373823718423869104263333984641494340347905", "927553665492332455747201965776037880757740193453592970025027978793976877002675564980949289727957565575433344219582"} - ); let delta_g2 = test_utils::to_g2_bls12_381!( { "1225439548733361287866553883695456824469134186836570397762131498241583159823035296217074111710636342557133382852358", "2605368487020759648403319793196297851010839805929073625099854787778388904778675959353258883417612421791844637077008"}, { "1154742119857928659368603772369477002539216605293799365584478673152507602473688973931247635774944414206241097299617", "3083613843092389681361977317882198510817133309742782178582263450336527557948727917944434768179612190551923309894740"} ); - let gamma_abc_g1 = vec![ - test_utils::to_g1_bls12_381!( - "1496325678302426440401133733502043551289869837205655668080008848699551523921245028359850882036392240986058622892606", - "1817947725837285375871533104780166089829860102882637736910105269739240593327578312097322455849119517519139026844600" - ), - test_utils::to_g1_bls12_381!( - "1718008724910268123339696488143341961797261917931626884153637247409759465219924679458496161324559634841879674394994", - "1374573688907712469603830822734104311026384172354584262904362700919219617284680686401889337872942140366529825919103" - ), - ]; assert_eq!(alpha_g1, pk.alpha_g1); assert_eq!(beta_g2, pk.beta_g2); assert_eq!(delta_g2, pk.delta_g2); @@ -561,24 +528,10 @@ mod tests { { "10507543441632391771444308193378912964353702039245296649929512844719350719061", "18201322790656668038537601329094316169506292175603805191741014817443184049262"}, { "5970405197328671009015216309153477729292937823545171027250144292199028398006", "207690659672174295265842461226025308763643182574816306177651013602294932409"} ); - let gamma_g2 = test_utils::to_g2_bn254!( - { "10857046999023057135944570762232829481370756359578518086990519993285655852781", "11559732032986387107991004021392285783925812861821192530917403151452391805634"}, - { "8495653923123431417604973247489272438418190587263600148770280649306958101930", "4082367875863433681332203403145435568316851327593401208105741076214120093531"} - ); let delta_g2 = test_utils::to_g2_bn254!( { "16155635570759079539128338844496116072647798864000233687303657902717776158999", "146722472349298011683444548694315820674090918095096001856936731325601586110"}, { "7220557679759413200896918190625936046017159618724594116959480938714251928850", "3740741795440491235944811815904112252316619638122978144672498770442910025884"} ); - let gamma_abc_g1 = vec![ - test_utils::to_g1_bn254!( - "17064056514210178269621297150176790945669784643731237949186503569701111845663", - "5160771857172547017310246971961987180872028348077571247747329170768684330052" - ), - test_utils::to_g1_bn254!( - "19547536507588365344778723326587455846790642159887261127893730469532513538882", - "10737415594461993507153866894812637432840367562913937920244709428556226500845" - ), - ]; assert_eq!(alpha_g1, pk.alpha_g1); assert_eq!(beta_g2, pk.beta_g2); assert_eq!(delta_g2, pk.delta_g2); @@ -591,15 +544,9 @@ mod tests { 2, )]]; let b = vec![vec![(ark_bn254::Fr::from_str("1").unwrap(), 3)]]; - assert_eq!(2, pk.matrices.num_instance_variables); - assert_eq!(3, pk.matrices.num_witness_variables); 
- assert_eq!(1, pk.matrices.num_constraints); - assert_eq!(1, pk.matrices.a_num_non_zero); - assert_eq!(1, pk.matrices.b_num_non_zero); - assert_eq!(0, pk.matrices.c_num_non_zero); - assert_eq!(a, pk.matrices.a); - assert_eq!(b, pk.matrices.b); - assert!(pk.matrices.c.is_empty()); + assert_eq!(1, pk.num_constraints); + assert_eq!(a, pk.a_matrix); + assert_eq!(b, pk.b_matrix); } } fn fq_from_str(s: &str) -> Fq { diff --git a/co-circom/co-circom-snarks/src/lib.rs b/co-circom/co-circom-snarks/src/lib.rs index e8b22f3f4..69400abbc 100644 --- a/co-circom/co-circom-snarks/src/lib.rs +++ b/co-circom/co-circom-snarks/src/lib.rs @@ -12,6 +12,7 @@ use mpc_core::protocols::{ use rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng, SeedableRng}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; +use std::error::Error; /// This type represents the serialized version of a Rep3 witness. Its share can be either additive or replicated, and in both cases also compressed. #[derive(Debug, Serialize, Deserialize)] @@ -452,6 +453,44 @@ impl SharedWitness> { } } +/// The error type for the verification of a Circom proof. +/// +/// If the verification failed because the proof is Invalid, the method +/// will return the [VerificationError::InvalidProof] variant. If the +/// underlying implementation encounters an error, the method +/// will wrap that error in the [VerificationError::Malformed] variant. +#[derive(Debug)] +pub enum VerificationError { + /// Indicates that the proof verification failed + InvalidProof, + /// Wraps an underlying error (e.g., malformed verification key) + Malformed(eyre::Report), +} + +impl From for VerificationError { + fn from(error: eyre::Report) -> Self { + VerificationError::Malformed(error) + } +} + +impl std::error::Error for VerificationError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + VerificationError::Malformed(source) => Some(source.as_ref()), + VerificationError::InvalidProof => None, + } + } +} + +impl std::fmt::Display for VerificationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VerificationError::InvalidProof => writeln!(f, "proof is invalid"), + VerificationError::Malformed(error) => writeln!(f, "cannot verify proof: {error}"), + } + } +} + /// Gathers utility methods for proving coSNARKs. pub mod utils { use ark_ff::{FftField, LegendreSymbol, PrimeField}; diff --git a/co-circom/co-circom/src/bin/co-circom.rs b/co-circom/co-circom/src/bin/co-circom.rs index bfb07fa05..e309c927d 100644 --- a/co-circom/co-circom/src/bin/co-circom.rs +++ b/co-circom/co-circom/src/bin/co-circom.rs @@ -31,13 +31,13 @@ use co_circom::VerifyCli; use co_circom::VerifyConfig; use co_circom::{file_utils, MPCCurve, MPCProtocol, ProofSystem, SeedRng}; use co_circom_snarks::{ - SerializeableSharedRep3Input, SerializeableSharedRep3Witness, SharedWitness, + SerializeableSharedRep3Input, SerializeableSharedRep3Witness, SharedWitness, VerificationError, }; use co_groth16::Groth16; use co_groth16::{Rep3CoGroth16, ShamirCoGroth16}; use co_plonk::Rep3CoPlonk; use co_plonk::{Plonk, ShamirCoPlonk}; -use color_eyre::eyre::{eyre, Context, ContextCompat}; +use color_eyre::eyre::{self, eyre, Context, ContextCompat}; use mpc_core::protocols::{ bridges::network::RepToShamirNetwork, rep3::network::Rep3MpcNet, @@ -659,8 +659,7 @@ where // The actual verifier let start = Instant::now(); - let res = Groth16::
<P>
::verify(&vk, &proof, &public_inputs) - .context("while verifying proof")?; + let res = Groth16::
<P>
::verify(&vk, &proof, &public_inputs); let duration_ms = start.elapsed().as_micros() as f64 / 1000.; tracing::info!("Proof verification took {} ms", duration_ms); res @@ -674,20 +673,23 @@ where // The actual verifier let start = Instant::now(); - let res = - Plonk::
<P>
::verify(&vk, &proof, &public_inputs).context("while verifying proof")?; + let res = Plonk::
<P>
::verify(&vk, &proof, &public_inputs); let duration_ms = start.elapsed().as_micros() as f64 / 1000.; tracing::info!("Proof verification took {} ms", duration_ms); res } }; - if res { - tracing::info!("Proof verified successfully"); - Ok(ExitCode::SUCCESS) - } else { - tracing::error!("Proof verification failed"); - Ok(ExitCode::FAILURE) + match res { + Ok(_) => { + tracing::info!("Proof verified successfully"); + Ok(ExitCode::SUCCESS) + } + Err(VerificationError::InvalidProof) => { + tracing::error!("Proof verification failed"); + Ok(ExitCode::FAILURE) + } + Err(VerificationError::Malformed(err)) => eyre::bail!(err), } } diff --git a/co-circom/co-groth16/Cargo.toml b/co-circom/co-groth16/Cargo.toml index bbcb4e734..d111be942 100644 --- a/co-circom/co-groth16/Cargo.toml +++ b/co-circom/co-groth16/Cargo.toml @@ -24,7 +24,6 @@ ark-groth16 = { version = "=0.4.0", default-features = false, features = [ "parallel", ], optional = true } ark-poly = { workspace = true } -ark-relations = { workspace = true } ark-serialize = { workspace = true } circom-types = { version = "0.6.0", path = "../circom-types" } co-circom-snarks = { version = "0.2.0", path = "../co-circom-snarks" } diff --git a/co-circom/co-groth16/src/groth16.rs b/co-circom/co-groth16/src/groth16.rs index 1ee82b2f5..7c9c24c5c 100644 --- a/co-circom/co-groth16/src/groth16.rs +++ b/co-circom/co-groth16/src/groth16.rs @@ -4,8 +4,7 @@ use ark_ec::scalar_mul::variable_base::VariableBaseMSM; use ark_ec::{AffineRepr, CurveGroup}; use ark_ff::{FftField, PrimeField}; use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; -use ark_relations::r1cs::{ConstraintMatrices, Matrix, SynthesisError}; -use circom_types::groth16::{Groth16Proof, ZKey}; +use circom_types::groth16::{ConstraintMatrix, Groth16Proof, ZKey}; use circom_types::traits::{CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}; use co_circom_snarks::SharedWitness; use eyre::Result; @@ -116,19 +115,9 @@ where let id = self.driver.get_party_id(); tracing::info!("Party {}: starting proof generation..", id); let start = Instant::now(); - let matrices = &zkey.matrices; - let num_inputs = matrices.num_instance_variables; - let num_constraints = matrices.num_constraints; let public_inputs = Arc::new(private_witness.public_inputs); let private_witness = Arc::new(private_witness.witness); - let h = self.witness_map_from_matrices( - zkey.pow, - matrices, - num_constraints, - num_inputs, - &public_inputs, - &private_witness, - )?; + let h = self.witness_map_from_matrices(&zkey, &public_inputs, &private_witness)?; let (r, s) = (self.driver.rand()?, self.driver.rand()?); let proof = self.create_proof_with_assignment( @@ -148,7 +137,7 @@ where fn evaluate_constraint( party_id: T::PartyID, domain_size: usize, - matrix: &Matrix, + matrix: &ConstraintMatrix, public_inputs: &[P::ScalarField], private_witness: &[T::ArithmeticShare], ) -> Vec { @@ -164,16 +153,16 @@ where #[instrument(level = "debug", name = "witness map from matrices", skip_all)] fn witness_map_from_matrices( &mut self, - power: usize, - matrices: &ConstraintMatrices, - num_constraints: usize, - num_inputs: usize, + zkey: &ZKey
<P>
, public_inputs: &[P::ScalarField], private_witness: &[T::ArithmeticShare], ) -> Result> { + let num_constraints = zkey.num_constraints; + let num_inputs = zkey.n_public + 1; + let power = zkey.pow; let mut domain = GeneralEvaluationDomain::::new(num_constraints + num_inputs) - .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + .ok_or(eyre::eyre!("Polynomial Degree too large"))?; let domain_size = domain.size(); let party_id = self.driver.get_party_id(); let eval_constraint_span = @@ -198,7 +187,7 @@ where let mut result = Self::evaluate_constraint( party_id, domain_size, - &matrices.a, + &zkey.a_matrix, public_inputs, private_witness, ); @@ -214,7 +203,7 @@ where let result = Self::evaluate_constraint( party_id, domain_size, - &matrices.b, + &zkey.b_matrix, public_inputs, private_witness, ); diff --git a/co-circom/co-groth16/src/lib.rs b/co-circom/co-groth16/src/lib.rs index 5a8e65221..e60dfdfc6 100644 --- a/co-circom/co-groth16/src/lib.rs +++ b/co-circom/co-groth16/src/lib.rs @@ -52,9 +52,7 @@ mod tests { Groth16::::plain_prove(zkey, witness).expect("proof generation works"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); - let verified = - Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); - assert!(verified); + Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); } } @@ -73,9 +71,8 @@ mod tests { let public_input = serde_json::from_str::>(public_string).unwrap(); let proof = serde_json::from_str::>(&proof_string).unwrap(); - let verified = - Groth16::::verify(&vk, &proof, &public_input.values).expect("can verify"); - assert!(verified) + + Groth16::::verify(&vk, &proof, &public_input.values).expect("can verify"); } #[test] @@ -100,9 +97,7 @@ mod tests { Groth16::::plain_prove(zkey, witness).expect("proof generation works"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); - let verified = - Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); - assert!(verified); + Groth16::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); } } @@ -121,8 +116,7 @@ mod tests { let public_input = serde_json::from_str::>(public_string).unwrap(); let proof = serde_json::from_str::>(&proof_string).unwrap(); - let verified = Groth16::verify(&vk, &proof, &public_input.values).expect("can verify"); - assert!(verified) + Groth16::verify(&vk, &proof, &public_input.values).expect("can verify"); } #[test] @@ -140,9 +134,7 @@ mod tests { let public_input = serde_json::from_str::>(public_string).unwrap(); let proof = serde_json::from_str::>(&proof_string).unwrap(); - let verified = - Groth16::::verify(&vk, &proof, &public_input.values).expect("can verify"); - assert!(verified) + Groth16::::verify(&vk, &proof, &public_input.values).expect("can verify"); } #[test] @@ -169,14 +161,10 @@ mod tests { let proof = Groth16::::plain_prove(zkey, witness).expect("proof generation works"); - let verified = - Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); - assert!(verified); + Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); - let verified = Groth16::::verify(&vk, &der_proof, &public_input[1..]) - .expect("can verify"); - assert!(verified) + Groth16::::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); } } @@ -201,14 +189,10 @@ 
mod tests { let proof = Groth16::::plain_prove(zkey, witness).expect("proof generation works"); - let verified = - Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); - assert!(verified); + Groth16::::verify(&vk, &proof, &public_input[1..]).expect("can verify"); let ser_proof = serde_json::to_string(&proof).unwrap(); let der_proof = serde_json::from_str::>(&ser_proof).unwrap(); - let verified = - Groth16::::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); - assert!(verified) + Groth16::::verify(&vk, &der_proof, &public_input[1..]).expect("can verify"); } } } diff --git a/co-circom/co-groth16/src/verifier.rs b/co-circom/co-groth16/src/verifier.rs index ed61a7ec3..6893e05fb 100644 --- a/co-circom/co-groth16/src/verifier.rs +++ b/co-circom/co-groth16/src/verifier.rs @@ -11,6 +11,7 @@ use circom_types::groth16::{Groth16Proof, JsonVerificationKey}; use circom_types::traits::{CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}; use ark_groth16::Groth16 as ArkworksGroth16; +use co_circom_snarks::VerificationError; impl Groth16
<P>
where @@ -24,7 +25,7 @@ where vk: &JsonVerificationKey
<P>
, proof: &Groth16Proof
<P>
, public_inputs: &[P::ScalarField], - ) -> Result { + ) -> Result<(), VerificationError> { let vk = VerifyingKey::
<P>
{ alpha_g1: vk.alpha_1, beta_g2: vk.beta_2, @@ -39,6 +40,12 @@ where }; let vk = ark_groth16::prepare_verifying_key(&vk); - ArkworksGroth16::
<P>
::verify_proof(&vk, &proof, public_inputs) + let proof_valid = ArkworksGroth16::
<P>
::verify_proof(&vk, &proof, public_inputs) + .map_err(eyre::Report::from)?; + if proof_valid { + Ok(()) + } else { + Err(VerificationError::InvalidProof) + } } } diff --git a/co-circom/co-plonk/src/lib.rs b/co-circom/co-plonk/src/lib.rs index 78d3a6bee..76ef8422d 100644 --- a/co-circom/co-plonk/src/lib.rs +++ b/co-circom/co-plonk/src/lib.rs @@ -299,8 +299,7 @@ pub mod tests { .unwrap(); let proof = Plonk::::plain_prove(zkey, witness).unwrap(); - let result = Plonk::::verify(&vk, &proof, &public_input.values).unwrap(); - assert!(result); + Plonk::::verify(&vk, &proof, &public_input.values).unwrap(); } Ok(()) } @@ -337,8 +336,7 @@ pub mod tests { let mut proof_bytes = vec![]; serde_json::to_writer(&mut proof_bytes, &proof).unwrap(); let proof = serde_json::from_reader(proof_bytes.as_slice()).unwrap(); - let result = Plonk::::verify(&vk, &proof, &public_inputs.values).unwrap(); - assert!(result) + Plonk::::verify(&vk, &proof, &public_inputs.values).unwrap(); } } } diff --git a/co-circom/co-plonk/src/plonk.rs b/co-circom/co-plonk/src/plonk.rs index 5bcfa360b..f8ae82f3c 100644 --- a/co-circom/co-plonk/src/plonk.rs +++ b/co-circom/co-plonk/src/plonk.rs @@ -13,7 +13,7 @@ use circom_types::{ plonk::{JsonVerificationKey, PlonkProof, ZKey}, traits::{CircomArkworksPairingBridge, CircomArkworksPrimeFieldBridge}, }; -use co_circom_snarks::SharedWitness; +use co_circom_snarks::{SharedWitness, VerificationError}; use num_traits::One; use num_traits::Zero; @@ -136,7 +136,7 @@ where vk: &JsonVerificationKey
<P>
, proof: &PlonkProof
<P>
, public_inputs: &[P::ScalarField], - ) -> Result + ) -> Result<(), VerificationError> where P: Pairing, P: CircomArkworksPairingBridge, @@ -144,11 +144,13 @@ where P::ScalarField: CircomArkworksPrimeFieldBridge, { if vk.n_public != public_inputs.len() { - return Err(eyre::eyre!("Invalid number of public inputs")); + return Err(VerificationError::Malformed(eyre::eyre!( + "Invalid number of public inputs" + ))); } let challenges = VerifierChallenges::
<P>
::new(vk, proof, public_inputs); - let domains = Domains::::new(1 << vk.power)?; + let domains = Domains::::new(1 << vk.power).map_err(eyre::Report::from)?; let (l, xin) = plonk_utils::calculate_lagrange_evaluations::
<P>
( vk.power, @@ -161,15 +163,13 @@ where let e = Plonk::
<P>
::calculate_e(proof, &challenges, r0); let f = Plonk::
<P>
::calculate_f(vk, proof, &challenges, d); + let valid = Plonk::
<P>
::valid_pairing(vk, proof, &challenges, e, f, &domains); - Ok(Plonk::
<P>
::valid_pairing( - vk, - proof, - &challenges, - e, - f, - &domains, - )) + if valid { + Ok(()) + } else { + Err(VerificationError::InvalidProof) + } } pub(crate) fn calculate_r0_d( @@ -386,7 +386,7 @@ pub mod tests { File::open("../../test_vectors/Plonk/bn254/multiplier2/public.json").unwrap(), ) .unwrap(); - assert!(Plonk::verify(&vk, &proof, &public_inputs.values).unwrap()); + Plonk::verify(&vk, &proof, &public_inputs.values).unwrap(); } #[test] @@ -403,6 +403,6 @@ pub mod tests { File::open("../../test_vectors/Plonk/bn254/poseidon/public.json").unwrap(), ) .unwrap(); - assert!(Plonk::verify(&vk, &proof, &public_inputs.values).unwrap()); + Plonk::verify(&vk, &proof, &public_inputs.values).unwrap(); } } diff --git a/tests/tests/circom/e2e_tests/rep3.rs b/tests/tests/circom/e2e_tests/rep3.rs index 2235f0545..9ce59623b 100644 --- a/tests/tests/circom/e2e_tests/rep3.rs +++ b/tests/tests/circom/e2e_tests/rep3.rs @@ -82,9 +82,7 @@ macro_rules! add_test_impl { ) .unwrap(); assert_eq!(der_proof, result2); - let verified = $proof_system::<$curve>::verify(&vk, &der_proof, &public_input).expect("can verify"); - assert!(verified); } #[test] @@ -101,9 +99,7 @@ macro_rules! add_test_impl { .unwrap(); let public_input: JsonPublicInput::<[< ark_ $curve:lower >]::Fr> = serde_json::from_reader(public_input_file).unwrap(); let snarkjs_proof: [< $proof_system Proof >]<$curve> = serde_json::from_reader(&snarkjs_proof_file).unwrap(); - let verified = $proof_system::<$curve>::verify(&vk, &snarkjs_proof, &public_input.values).expect("can verify"); - assert!(verified); } } }; diff --git a/tests/tests/circom/e2e_tests/shamir.rs b/tests/tests/circom/e2e_tests/shamir.rs index a008b0088..f150402c8 100644 --- a/tests/tests/circom/e2e_tests/shamir.rs +++ b/tests/tests/circom/e2e_tests/shamir.rs @@ -93,9 +93,7 @@ macro_rules! add_test_impl { ) .unwrap(); assert_eq!(der_proof, result2); - let verified = $proof_system::<$curve>::verify(&vk, &der_proof, &public_input).expect("can verify"); - assert!(verified); } #[test] @@ -112,9 +110,7 @@ macro_rules! add_test_impl { .unwrap(); let public_input: JsonPublicInput::<[< ark_ $curve:lower >]::Fr> = serde_json::from_reader(public_input_file).unwrap(); let snarkjs_proof: [< $proof_system Proof >]<$curve> = serde_json::from_reader(&snarkjs_proof_file).unwrap(); - let verified = $proof_system::<$curve>::verify(&vk, &snarkjs_proof, &public_input.values).expect("can verify"); - assert!(verified); } } }; From dcca49a0c102778826d0a2399c4d5c0410dd567f Mon Sep 17 00:00:00 2001 From: Franco Nieddu Date: Tue, 26 Nov 2024 14:03:47 +0100 Subject: [PATCH 4/4] fix: added a check during groth16 prover for public inputs --- co-circom/co-groth16/src/groth16.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/co-circom/co-groth16/src/groth16.rs b/co-circom/co-groth16/src/groth16.rs index 7c9c24c5c..a21abfdbc 100644 --- a/co-circom/co-groth16/src/groth16.rs +++ b/co-circom/co-groth16/src/groth16.rs @@ -116,6 +116,14 @@ where tracing::info!("Party {}: starting proof generation..", id); let start = Instant::now(); let public_inputs = Arc::new(private_witness.public_inputs); + if public_inputs.len() != zkey.n_public + 1 { + eyre::bail!( + "amount of public inputs do not match with provided zkey! Expected {}, but got {}", + zkey.n_public + 1, + public_inputs.len() + ) + } + let private_witness = Arc::new(private_witness.witness); let h = self.witness_map_from_matrices(&zkey, &public_inputs, &private_witness)?; let (r, s) = (self.driver.rand()?, self.driver.rand()?);
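
Patch 3 is a breaking change: the Groth16 and Plonk verify functions now return Result<(), VerificationError> instead of a bool. A minimal sketch of how a downstream caller handles the new return type, assuming a Bn254 proof and that the verification key, proof and public inputs are already deserialized (e.g. via serde_json as in the tests above); this is not code from the patches themselves:

use ark_bn254::Bn254;
use circom_types::groth16::{Groth16Proof, JsonVerificationKey};
use co_circom_snarks::VerificationError;
use co_groth16::Groth16;
use std::process::ExitCode;

// Sketch only: mirrors the match in co-circom.rs after this change.
fn check_proof(
    vk: &JsonVerificationKey<Bn254>,
    proof: &Groth16Proof<Bn254>,
    public_inputs: &[ark_bn254::Fr],
) -> ExitCode {
    match Groth16::<Bn254>::verify(vk, proof, public_inputs) {
        Ok(()) => {
            // success is now signalled by Ok(()) instead of Ok(true)
            ExitCode::SUCCESS
        }
        Err(VerificationError::InvalidProof) => {
            // the proof itself did not verify (previously Ok(false))
            ExitCode::FAILURE
        }
        Err(VerificationError::Malformed(err)) => {
            // some input was malformed, e.g. the wrong number of public inputs
            eprintln!("cannot verify proof: {err}");
            ExitCode::FAILURE
        }
    }
}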