From cedb7e453ad8a3e9ec6699664ca6b6356dd38496 Mon Sep 17 00:00:00 2001 From: PatStiles Date: Fri, 8 Mar 2024 04:00:36 +0000 Subject: [PATCH] rewrite with generic msm fail on G's type --- crypto/src/commitments/kzg.rs | 17 +-- math/src/gpu/icicle.rs | 251 ------------------------------- math/src/gpu/icicle/bls12_381.rs | 6 +- math/src/gpu/icicle/mod.rs | 180 +++++++++++++++++++++- math/src/msm/naive.rs | 3 + math/src/msm/pippenger.rs | 31 +++- provers/groth16/src/prover.rs | 10 +- provers/groth16/src/verifier.rs | 5 +- 8 files changed, 215 insertions(+), 288 deletions(-) delete mode 100644 math/src/gpu/icicle.rs diff --git a/crypto/src/commitments/kzg.rs b/crypto/src/commitments/kzg.rs index 2ecbb483c..9d8587605 100644 --- a/crypto/src/commitments/kzg.rs +++ b/crypto/src/commitments/kzg.rs @@ -5,7 +5,7 @@ use lambdaworks_math::{ cyclic_group::IsGroup, elliptic_curve::traits::IsPairing, errors::DeserializationError, - field::{element::FieldElement, traits::IsPrimeField}, + field::{element::FieldElement, traits::{IsPrimeField, IsField}}, msm::pippenger::msm, polynomial::Polynomial, traits::{AsBytes, Deserializable}, @@ -136,12 +136,12 @@ where } #[derive(Clone)] -pub struct KateZaveruchaGoldberg { +pub struct KateZaveruchaGoldberg { srs: StructuredReferenceString, phantom: PhantomData, } -impl KateZaveruchaGoldberg { +impl KateZaveruchaGoldberg { pub fn new(srs: StructuredReferenceString) -> Self { Self { srs, @@ -150,20 +150,15 @@ impl KateZaveruchaGoldberg { } } -impl>, P: IsPairing> +impl> + IsPrimeField>, P: IsPairing> IsCommitmentScheme for KateZaveruchaGoldberg { type Commitment = P::G1Point; fn commit(&self, p: &Polynomial>) -> Self::Commitment { - let coefficients: Vec<_> = p - .coefficients - .iter() - .map(|coefficient| coefficient.representative()) - .collect(); msm( - &coefficients, - &self.srs.powers_main_group[..coefficients.len()], + &p.coefficients, + &self.srs.powers_main_group[..p.coefficients.len()], ) .expect("`points` is sliced by `cs`'s length") } diff --git a/math/src/gpu/icicle.rs b/math/src/gpu/icicle.rs deleted file mode 100644 index fa5072472..000000000 --- a/math/src/gpu/icicle.rs +++ /dev/null @@ -1,251 +0,0 @@ -use icicle_bls12_377::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg}; -use icicle_bls12_381::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg}; -use icicle_bn254::{CurveCfg, G1Projective, G2CurveCfg, G2Projective, ScalarCfg}; -use icicle_core::{ - field::Field, - msm, - traits::{FieldConfig, FieldImpl, GenerateRandom}, - Curve::{Affine, Curve, Projective}, - Field::{Field, FieldImpl, MontgomeryConvertibleField}, -}; -use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream}; - -use crate::{ - elliptic_curve::{ - short_weierstrass::{ - curves::{ - bls12_377::{ - curve::{BLS12377Curve, BLS12377FieldElement}, - field_extension::BLS12377PrimeField, - }, - bls12_381::{ - curve::{BLS12381Curve, BLS12381FieldElement, BLS12381TwistCurveFieldElement}, - twist::BLS12381TwistCurve, - }, - bn_254::{ - curve::{BN254Curve, BN254FieldElement, BN254TwistCurveFieldElement}, - twist::BN254TwistCurve, - }, - }, - point::ShortWeierstrassProjectivePoint, - }, - traits::IsEllipticCurve, - }, - errors::ByteConversionError, - field::{element::FieldElement, traits::IsField}, - traits::ByteConversion, -}; - -use core::fmt::Debug; - -/// Notes: -/// Lambdaworks supplies rust bindings generic over there internal Field and Coordinate types. -/// The best solution is to upstream a `LambdaConvertible` trait implementation that handles this conversion for us. -/// In the meantime conversions are for specific curves and field implemented as the Icicle's Field type is not abstracted -/// from the field configuration or number of underlying limbs used in its representation - -/// trait for Conversions of lambdaworks type -> Icicle type -/// NOTE: This may be removed with eliminating `LambdaConvertible` -pub trait ToIcicle: Clone + Debug { - type IcicleType; - - fn to_icicle(&self) -> Self::IcicleType; - fn from_icicle(icicle: Self::IcicleType) -> Result; -} - -impl ToIcicle for BLS12377FieldElement { - type IcicleType = icicle_bls12_377::curve::BaseField; - - fn to_icicle(&self) -> Self::IcicleType { - IcicleType::from_bytes_le(self.to_representative().to_bytes_le()) - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Self::from_bytes_le(icicle.to_repr().to_bytes_le()) - } -} - -impl ToIcicle for BLS12381FieldElement { - type IcicleType = icicle_bls12_381::curve::BaseField; - - fn to_icicle(&self) -> Self::IcicleType { - IcicleType::from_bytes_le(self.to_representative().to_bytes_le()) - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Self::from_bytes_le(icicle.to_repr().to_bytes_le()) - } -} - -impl ToIcicle for BLS12381TwistCurveFieldElement { - type IcicleType = icicle_bls12_381::curve::BaseField; - - fn to_icicle(&self) -> Self::IcicleType { - IcicleType::from_bytes_le(self.to_representative().to_bytes_le()) - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Self::from_bytes_le(icicle.to_repr().to_bytes_le()) - } -} - -impl ToIcicle for BN254FieldElement { - type IcicleType = icicle_bn254::curve::BaseField; - - fn to_icicle(&self) -> Self::IcicleType { - IcicleType::from_bytes_le(self.to_representative().to_bytes_le()) - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Self::from_bytes_le(icicle.to_repr().to_bytes_le()) - } -} - -impl ToIcicle for BN254TwistCurveFieldElement { - type IcicleType = icicle_bn254::curve::BaseField; - - fn to_icicle(&self) -> Self::IcicleType { - IcicleType::from_bytes_le(self.to_representative().to_bytes_le()) - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Self::from_bytes_le(&icicle.to_bytes_le()) - } -} - -impl ToIcicle for ShortWeierstrassProjectivePoint { - type IcicleType = icicle_bls12_377::curve::G1Projective; - - fn to_icicle(&self) -> Self::IcicleType { - Self::IcicleType { - x: self.x().to_icicle(), - y: self.y().to_icicle(), - z: self.z().to_icicle(), - } - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Ok(Self::new([ - FieldElement::::from_icicle(icicle.x).unwrap(), - FieldElement::::from_icicle(icicle.y).unwrap(), - FieldElement::::from_icicle(icicle.z).unwrap(), - ])) - } -} - -impl ToIcicle for ShortWeierstrassProjectivePoint { - type IcicleType = icicle_bls12_3811::curve::G1Projective; - - fn to_icicle(&self) -> Self::IcicleType { - Self::IcicleType { - x: self.x().to_icicle(), - y: self.y().to_icicle(), - z: self.z().to_icicle(), - } - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Ok(Self::new([ - FieldElement::::from_icicle(icicle.x).unwrap(), - FieldElement::::from_icicle(icicle.y).unwrap(), - FieldElement::::from_icicle(icicle.z).unwrap(), - ])) - } -} - -impl ToIcicle for ShortWeierstrassProjectivePoint { - type IcicleType = icicle_bls12_381::curve::G2Projective; - - fn to_icicle(&self) -> Self::IcicleType { - Self::IcicleType { - x: self.x().to_icicle(), - y: self.y().to_icicle(), - z: self.z().to_icicle(), - } - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Ok(Self::new([ - FieldElement::::from_icicle(icicle.x).unwrap(), - FieldElement::::from_icicle(icicle.y).unwrap(), - FieldElement::::from_icicle(icicle.z).unwrap(), - ])) - } -} - -impl ToIcicle for ShortWeierstrassProjectivePoint { - type IcicleType = icicle_bn254::curve::G1Projective; - - fn to_icicle(&self) -> Self::IcicleType { - Self::IcicleType { - x: self.x().to_icicle(), - y: self.y().to_icicle(), - z: self.z().to_icicle(), - } - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Ok(Self::new([ - FieldElement::::from_icicle(icicle.x).unwrap(), - FieldElement::::from_icicle(icicle.y).unwrap(), - FieldElement::::from_icicle(icicle.z).unwrap(), - ])) - } -} - -impl ToIcicle for ShortWeierstrassProjectivePoint { - type IcicleType = icicle_bn254::curve::G2Projective; - - fn to_icicle(&self) -> Self::IcicleType { - Self::IcicleType { - x: self.x().to_icicle(), - y: self.y().to_icicle(), - z: self.z().to_icicle(), - } - } - - fn from_icicle(icicle: Self::IcicleType) -> Result { - Ok(Self::new([ - FieldElement::::from_icicle(icicle.x).unwrap(), - FieldElement::::from_icicle(icicle.y).unwrap(), - FieldElement::::from_icicle(icicle.z).unwrap(), - ])) - } -} - -/// Performs msm using Icicle GPU, intitiates, allocates, and configures all gpu operations -/// TODO: determining where this setup should occur is an open question -fn msm( - scalars: &[impl ToIcicle], - points: &[impl ToIcicle], -) -> ShortWeierstrassProjectivePoint { - let scalars = HostOrDeviceSlice::Host(&scalars.iter().map(to_icicle()).collect::>()); - let point = HostOrDeviceSlice::Host(&points.iter().map(to_icicle()).collect::>()); - let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap(); - let stream = CudaStream::create().unwrap(); - let mut cfg = msm::get_default_msm_config(); - cfg.ctx.stream = &stream; - cfg.is_async = true; - msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap(); - let mut msm_host_result = Vec::new(); - stream.synchronize().unwrap(); - msm_results.copy_to_host(&mut msm_host_result[..]).unwrap(); - stream.destroy().unwrap(); -} - -/// Performs ntt using Icicle GPU, intitiates, allocates, and configures all gpu operations -fn ntt(scalars: &[impl ToIcicle], points: &[impl ToIcicle]) -> FieldElement { - let point = HostOrDeviceSlice::Host(&points.iter().map(to_icicle()).collect::>()); - let mut ntt_results = HostOrDeviceSlice::cuda_malloc(1).unwrap(); - let stream = CudaStream::create().unwrap(); - let mut cfg = msm::get_default_msm_config(); - cfg.ctx.stream = &stream; - cfg.is_async = true; - msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap(); - let mut ntt_host_result = Vec::new(); - stream.synchronize().unwrap(); - ntt_results.copy_to_host(&mut msm_host_result[..]).unwrap(); - stream.destroy().unwrap(); - - let ctx = get_default_device_context(); -} diff --git a/math/src/gpu/icicle/bls12_381.rs b/math/src/gpu/icicle/bls12_381.rs index bf68f2304..ce1f123b8 100644 --- a/math/src/gpu/icicle/bls12_381.rs +++ b/math/src/gpu/icicle/bls12_381.rs @@ -142,11 +142,7 @@ mod test { let lambda_scalars = vec![eight; LEN]; let lambda_points = (0..LEN).map(|_| point_times_5()).collect::>(); let expected = msm( - &lambda_scalars - .clone() - .into_iter() - .map(|x| x.representative()) - .collect::>(), + &lambda_scalars, &lambda_points, ) .unwrap(); diff --git a/math/src/gpu/icicle/mod.rs b/math/src/gpu/icicle/mod.rs index 3a0698b9f..5a85dec91 100644 --- a/math/src/gpu/icicle/mod.rs +++ b/math/src/gpu/icicle/mod.rs @@ -1,3 +1,177 @@ -pub mod bls12_377; -pub mod bls12_381; -pub mod bn254; +//pub mod bls12_377; +//pub mod bls12_381; +//pub mod bn254; + +use icicle_bls12_381::curve::CurveCfg as IcicleBLS12381Curve; +use icicle_bls12_377::curve::CurveCfg as IcicleBLS12377Curve; +use icicle_bn254::curve::CurveCfg as IcicleBN254Curve; +use icicle_cuda_runtime::{memory::HostOrDeviceSlice, stream::CudaStream}; +use icicle_core::{error::IcicleError, msm, curve::{Curve, Affine, Projective}, traits::FieldImpl}; +use crate::{ + elliptic_curve::{short_weierstrass::{ + curves::{ + bls12_381::curve::BLS12381Curve, + bls12_377::curve::BLS12377Curve, + bn_254::curve::BN254Curve + }, + traits::IsShortWeierstrass, point::ShortWeierstrassProjectivePoint}, traits::IsEllipticCurve, + }, + field::{element::FieldElement, traits::IsField}, + unsigned_integer::element::UnsignedInteger, + cyclic_group::IsGroup, + errors::ByteConversionError, + traits::ByteConversion +}; + +use std::fmt::Debug; + +impl Icicle for BLS12381Curve {} +impl Icicle for BLS12377Curve {} +impl Icicle for BN254Curve {} + +pub trait Icicle +where + FieldElement: ByteConversion +{ + /// Used for searching this field's implementation in other languages, e.g in MSL + /// for executing parallel operations with the Metal API. + fn field_name() -> &'static str { + "" + } + + fn to_icicle_field(element: &FieldElement) -> I::BaseField { + I::BaseField::from_bytes_le(&element.to_bytes_le()) + } + + fn to_icicle_scalar(element: &FieldElement) -> I::ScalarField { + I::ScalarField::from_bytes_le(&element.to_bytes_le()) + } + + fn from_icicle_field(icicle: &I::BaseField) -> Result, ByteConversionError> { + FieldElement::::from_bytes_le(&icicle.to_bytes_le()) + } + + fn to_icicle_affine(point: &ShortWeierstrassProjectivePoint) -> Affine { + let s = ShortWeierstrassProjectivePoint::::to_affine(point); + Affine:: { + x: Self::to_icicle_field(s.x()), + y: Self::to_icicle_field(s.y()), + } + } + + fn from_icicle_projective(icicle: &Projective) -> Result, ByteConversionError> { + Ok(ShortWeierstrassProjectivePoint::::new([ + Self::from_icicle_field(&icicle.x).unwrap(), + Self::from_icicle_field(&icicle.y).unwrap(), + Self::from_icicle_field(&icicle.z).unwrap(), + ])) + } + +} + +pub fn icicle_msm>( + scalars: &[FieldElement], + points: &[ShortWeierstrassProjectivePoint] + ) -> Result, IcicleError> +where + C: Icicle, + FieldElement<::BaseField>: ByteConversion +{ + let mut cfg = msm::MSMConfig::default(); + let scalars = HostOrDeviceSlice::Host( + scalars + .iter() + .map(|scalar| C::to_icicle_scalar(&scalar)) + .collect::>(), + ); + let points = HostOrDeviceSlice::Host( + points + .iter() + .map(|point| C::to_icicle_affine(&point)) + .collect::>(), + ); + let mut msm_results = HostOrDeviceSlice::cuda_malloc(1).unwrap(); + let stream = CudaStream::create().unwrap(); + cfg.ctx.stream = &stream; + cfg.is_async = true; + msm::msm(&scalars, &points, &cfg, &mut msm_results).unwrap(); + let mut msm_host_result = vec![Projective::::zero(); 1]; + stream.synchronize().unwrap(); + msm_results.copy_to_host(&mut msm_host_result[..]).unwrap(); + stream.destroy().unwrap(); + let res = + C::from_icicle_projective(&msm_host_result[0]).unwrap(); + Ok(res) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + elliptic_curve::{ + short_weierstrass::curves::bls12_381::curve::BLS12381FieldElement, + traits::IsEllipticCurve, + }, + field::element::FieldElement, + msm::pippenger::msm, + }; + + impl ShortWeierstrassProjectivePoint { + fn from_icicle_affine( + icicle: &curve::G1Affine, + ) -> Result, ByteConversionError> { + Ok(Self::new([ + FieldElement::::from_icicle(&icicle.x).unwrap(), + FieldElement::::from_icicle(&icicle.y).unwrap(), + FieldElement::one(), + ])) + } + } + + fn point_times_5() -> ShortWeierstrassProjectivePoint { + let x = BLS12381FieldElement::from_hex_unchecked( + "32bcce7e71eb50384918e0c9809f73bde357027c6bf15092dd849aa0eac274d43af4c68a65fb2cda381734af5eecd5c", + ); + let y = BLS12381FieldElement::from_hex_unchecked( + "11e48467b19458aabe7c8a42dc4b67d7390fdf1e150534caadddc7e6f729d8890b68a5ea6885a21b555186452b954d88", + ); + BLS12381Curve::create_point_from_affine(x, y).unwrap() + } + + #[test] + fn to_from_icicle() { + // convert value of 5 to icicle and back again and that icicle 5 matches + let point = point_times_5(); + let icicle_point = point.to_icicle(); + let res = + ShortWeierstrassProjectivePoint::::from_icicle_affine(&icicle_point) + .unwrap(); + assert_eq!(point, res) + } + + #[test] + fn to_from_icicle_generator() { + // Convert generator and see that it matches + let point = BLS12381Curve::generator(); + let icicle_point = point.to_icicle(); + let res = + ShortWeierstrassProjectivePoint::::from_icicle_affine(&icicle_point) + .unwrap(); + assert_eq!(point, res) + } + + #[test] + fn icicle_g1_msm() { + const LEN: usize = 20; + let eight: BLS12381FieldElement = FieldElement::from(8); + let lambda_scalars = vec![eight; LEN]; + let lambda_points = (0..LEN).map(|_| point_times_5()).collect::>(); + let expected = msm( + &lambda_scalars, + &lambda_points, + ) + .unwrap(); + let res = bls12_381_g1_msm(&lambda_scalars, &lambda_points, None).unwrap(); + assert_eq!(res, expected); + } +} diff --git a/math/src/msm/naive.rs b/math/src/msm/naive.rs index fec773073..985d08079 100644 --- a/math/src/msm/naive.rs +++ b/math/src/msm/naive.rs @@ -2,10 +2,13 @@ use core::fmt::Display; use crate::cyclic_group::IsGroup; use crate::unsigned_integer::traits::IsUnsignedInteger; +#[cfg(feature = "icicle")] +use icicle_core::error::IcicleError; #[derive(Debug)] pub enum MSMError { LengthMismatch(usize, usize), + Icicle(IcicleError) } impl Display for MSMError { diff --git a/math/src/msm/pippenger.rs b/math/src/msm/pippenger.rs index 1ccfce9c3..0f813e9a7 100644 --- a/math/src/msm/pippenger.rs +++ b/math/src/msm/pippenger.rs @@ -1,4 +1,4 @@ -use crate::{cyclic_group::IsGroup, unsigned_integer::element::UnsignedInteger}; +use crate::{cyclic_group::IsGroup, unsigned_integer::element::UnsignedInteger, field::{traits::IsField, element::FieldElement}, gpu::icicle::icicle_msm}; use super::naive::MSMError; @@ -15,8 +15,8 @@ use alloc::vec; /// If `points` and `cs` are empty, then `msm` returns the zero element of the group. /// /// Panics if `cs` and `points` have different lengths. -pub fn msm( - cs: &[UnsignedInteger], +pub fn msm>>( + cs: &[FieldElement], points: &[G], ) -> Result where @@ -26,9 +26,30 @@ where return Err(MSMError::LengthMismatch(cs.len(), points.len())); } - let window_size = optimum_window_size(cs.len()); + #[cfg(feature = "icicle")] + { + icicle_msm(cs, points).map_err(|e| MSMError::Icicle(e)) + /* + if !F::field_name().is_empty() { + icicle_msm(cs, points) + } else { + println!( + "Icicle msm failed for field {}. Program will fallback to CPU.", + core::any::type_name::() + ); + let window_size = optimum_window_size(cs.len()); + let cs = cs.into_iter().map(|cs| *cs.calue()).collect::>(); + Ok(msm_with(cs, points, window_size)) + } + */ + } - Ok(msm_with(cs, points, window_size)) + #[cfg(not(feature = "icicle"))] + { + let window_size = optimum_window_size(cs.len()); + let cs = cs.into_iter().map(|cs| *cs.value()).collect::>(); + Ok(msm_with(cs, points, window_size)) + } } fn optimum_window_size(data_length: usize) -> usize { diff --git a/provers/groth16/src/prover.rs b/provers/groth16/src/prover.rs index b0dbeb579..a007b3a67 100644 --- a/provers/groth16/src/prover.rs +++ b/provers/groth16/src/prover.rs @@ -67,15 +67,7 @@ pub struct Prover; impl Prover { pub fn prove(w: &[FrElement], qap: &QuadraticArithmeticProgram, pk: &ProvingKey) -> Proof { let h_coefficients = qap - .calculate_h_coefficients(w) - .iter() - .map(|elem| elem.representative()) - .collect::>(); - - let w = w - .iter() - .map(|elem| elem.representative()) - .collect::>(); + .calculate_h_coefficients(w); // Sample randomness for hiding let r = sample_fr_elem(); diff --git a/provers/groth16/src/verifier.rs b/provers/groth16/src/verifier.rs index b83c4b215..f3343c45f 100644 --- a/provers/groth16/src/verifier.rs +++ b/provers/groth16/src/verifier.rs @@ -7,10 +7,7 @@ use crate::setup::VerifyingKey; pub fn verify(vk: &VerifyingKey, proof: &Proof, pub_inputs: &[FrElement]) -> bool { // [γ^{-1} * (β*l(τ) + α*r(τ) + o(τ))]_1 let k_tau_assigned_verifier_g1 = msm( - &pub_inputs - .iter() - .map(|elem| elem.representative()) - .collect::>(), + &pub_inputs, &vk.verifier_k_tau_g1, ) .unwrap();