diff --git a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs index e8cff6196..26c34e173 100644 --- a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs +++ b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/field_extension.rs @@ -6,7 +6,7 @@ use crate::field::{ quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, }, fields::montgomery_backed_prime_fields::{IsModulus, MontgomeryBackendPrimeField}, - traits::IsField, + traits::{IsField, IsSubFieldOf}, }; use crate::traits::ByteConversion; use crate::unsigned_integer::element::U384; @@ -71,7 +71,7 @@ impl IsField for Degree2ExtensionField { /// Returns the division of `a` and `b` fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - Self::mul(a, &Self::inv(b).unwrap()) + ::mul(a, &Self::inv(b).unwrap()) } /// Returns a boolean indicating whether `a` and `b` are equal component wise. @@ -103,6 +103,47 @@ impl IsField for Degree2ExtensionField { } } +impl IsSubFieldOf for BLS12381PrimeField { + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::mul(a, b[0].value())); + let c1 = FieldElement::from_raw(::mul(a, b[1].value())); + [c0, c1] + } + + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::add(a, b[0].value())); + let c1 = FieldElement::from_raw(*b[1].value()); + [c0, c1] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree2ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FieldElement::from_raw(::sub(a, b[0].value())); + let c1 = FieldElement::from_raw(::neg(b[1].value())); + [c0, c1] + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [FieldElement::from_raw(a), FieldElement::zero()] + } +} + impl ByteConversion for FieldElement { #[cfg(feature = "std")] fn to_bytes_be(&self) -> Vec { @@ -328,4 +369,43 @@ mod tests { assert_eq!(g_to_fp12_x, expectedx); assert_eq!(g_to_fp12_y, expectedy); } + + #[test] + fn add_base_field_with_degree_2_extension() { + let a = FieldElement::::from(3); + let a_extension = FieldElement::::from(3); + let b = FieldElement::::from(2); + assert_eq!(a + &b, a_extension + b); + } + + #[test] + fn mul_base_field_with_degree_2_extension() { + let a = FieldElement::::from(3); + let a_extension = FieldElement::::from(3); + let b = FieldElement::::from(2); + assert_eq!(a * &b, a_extension * b); + } + + #[test] + fn sub_base_field_with_degree_2_extension() { + let a = FieldElement::::from(3); + let a_extension = FieldElement::::from(3); + let b = FieldElement::::from(2); + assert_eq!(a - &b, a_extension - b); + } + + #[test] + fn div_base_field_with_degree_2_extension() { + let a = FieldElement::::from(3); + let a_extension = FieldElement::::from(3); + let b = FieldElement::::from(2); + assert_eq!(a / &b, a_extension / b); + } + + #[test] + fn embed_base_field_with_degree_2_extension() { + let a = FieldElement::::from(3); + let a_extension = FieldElement::::from(3); + assert_eq!(a.to_extension::(), a_extension); + } } diff --git a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/pairing.rs b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/pairing.rs index 25bfb579e..0e509dee8 100644 --- a/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/pairing.rs +++ b/math/src/elliptic_curve/short_weierstrass/curves/bls12_381/pairing.rs @@ -1,6 +1,6 @@ use super::{ curve::{BLS12381Curve, MILLER_LOOP_CONSTANT}, - field_extension::{Degree12ExtensionField, Degree2ExtensionField}, + field_extension::{BLS12381PrimeField, Degree12ExtensionField, Degree2ExtensionField}, twist::BLS12381TwistCurve, }; use crate::{ @@ -56,22 +56,23 @@ fn double_accumulate_line( let [px, py, _] = p.coordinates(); let residue = LevelTwoResidue::residue(); let two_inv = FieldElement::::new_base("d0088f51cbff34d258dd3db21a5d66bb23ba5c279c2895fb39869507b587b120f55ffff58a9ffffdcff7fffffffd556"); + let three = FieldElement::::from(3); let a = &two_inv * x1 * y1; let b = y1.square(); let c = z1.square(); - let d = FieldElement::from(3) * &c; + let d = &three * &c; let e = BLS12381TwistCurve::b() * d; - let f = FieldElement::from(3) * &e; + let f = &three * &e; let g = two_inv * (&b + &f); let h = (y1 + z1).square() - (&b + &c); let x3 = &a * (&b - &f); - let y3 = g.square() - (FieldElement::from(3) * e.square()); + let y3 = g.square() - (&three * e.square()); let z3 = &b * &h; let [h0, h1] = h.value(); - let x1_sq_3 = FieldElement::from(3) * x1.square(); + let x1_sq_3 = three * x1.square(); let [x1_sq_30, x1_sq_31] = x1_sq_3.value(); t.0.value = [x3, y3, z3]; @@ -120,7 +121,7 @@ fn add_accumulate_line( let e = &lambda * &d; let f = z1 * c; let g = x1 * d; - let h = &e + f - FieldElement::from(2) * &g; + let h = &e + f - FieldElement::::from(2) * &g; let i = y1 * &e; let x3 = &lambda * &h; @@ -195,7 +196,7 @@ fn frobenius_square( let f0 = FieldElement::new([a0.clone(), a1 * &omega_3, a2 * &omega_3_squared]); let f1 = FieldElement::new([b0.clone(), b1 * omega_3, b2 * omega_3_squared]); - FieldElement::new([f0, f1 * w_raised_to_p_squared_minus_one]) + FieldElement::new([f0, w_raised_to_p_squared_minus_one * f1]) } // To understand more about how to reduce the final exponentiation diff --git a/math/src/fft/gpu/metal/ops.rs b/math/src/fft/gpu/metal/ops.rs index c9d7131f2..c9a347dfc 100644 --- a/math/src/fft/gpu/metal/ops.rs +++ b/math/src/fft/gpu/metal/ops.rs @@ -55,7 +55,7 @@ pub fn fft( let result = MetalState::retrieve_contents(&input_buffer); let result = bitrev_permutation::(&result, state)?; - Ok(result.iter().map(FieldElement::from_raw).collect()) + Ok(result.into_iter().map(FieldElement::from_raw).collect()) } /// Generates 2^{`order-1`} twiddle factors in parallel, with a certain `config`, in Metal. @@ -89,7 +89,7 @@ pub fn gen_twiddles( let (command_buffer, command_encoder) = state.setup_command(&pipeline, Some(&[(0, &result_buffer)])); - let root = F::get_primitive_root_of_unity::(order).unwrap(); + let root = F::get_primitive_root_of_unity(order).unwrap(); command_encoder.set_bytes(1, mem::size_of::() as u64, void_ptr(&root)); let grid_size = MTLSize::new(len as u64, 1, 1); @@ -103,7 +103,7 @@ pub fn gen_twiddles( }); let result = MetalState::retrieve_contents(&result_buffer); - Ok(result.iter().map(FieldElement::from_raw).collect()) + Ok(result.into_iter().map(FieldElement::from_raw).collect()) } /// Executes a parallel bit-reverse permutation with the elements of `input`, in Metal. diff --git a/math/src/field/element.rs b/math/src/field/element.rs index 3c1f86484..67707b791 100644 --- a/math/src/field/element.rs +++ b/math/src/field/element.rs @@ -32,7 +32,7 @@ use serde::ser::{Serialize, SerializeStruct, Serializer}; use serde::Deserialize; use super::fields::montgomery_backed_prime_fields::{IsModulus, MontgomeryBackendPrimeField}; -use super::traits::{IsPrimeField, LegendreSymbol}; +use super::traits::{IsPrimeField, IsSubFieldOf, LegendreSymbol}; /// A field element with operations algorithms defined in `F` #[allow(clippy::derived_hash_with_manual_eq)] @@ -95,10 +95,8 @@ where F::BaseType: Clone, F: IsField, { - pub fn from_raw(value: &F::BaseType) -> Self { - Self { - value: value.clone(), - } + pub fn from_raw(value: F::BaseType) -> Self { + Self { value } } pub const fn const_from_raw(value: F::BaseType) -> Self { @@ -119,59 +117,64 @@ where impl Eq for FieldElement where F: IsField {} /// Addition operator overloading for field elements -impl Add<&FieldElement> for &FieldElement +impl Add<&FieldElement> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn add(self, rhs: &FieldElement) -> Self::Output { + fn add(self, rhs: &FieldElement) -> Self::Output { Self::Output { - value: F::add(&self.value, &rhs.value), + value: >::add(&self.value, &rhs.value), } } } -impl Add> for FieldElement +impl Add> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn add(self, rhs: FieldElement) -> Self::Output { + fn add(self, rhs: FieldElement) -> Self::Output { &self + &rhs } } -impl Add<&FieldElement> for FieldElement +impl Add<&FieldElement> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn add(self, rhs: &FieldElement) -> Self::Output { + fn add(self, rhs: &FieldElement) -> Self::Output { &self + rhs } } -impl Add> for &FieldElement +impl Add> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn add(self, rhs: FieldElement) -> Self::Output { + fn add(self, rhs: FieldElement) -> Self::Output { self + &rhs } } /// AddAssign operator overloading for field elements -impl AddAssign> for FieldElement +impl AddAssign> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { fn add_assign(&mut self, rhs: FieldElement) { - self.value = F::add(&self.value, &rhs.value); + self.value = >::add(&rhs.value, &self.value); } } @@ -186,142 +189,154 @@ where } /// Subtraction operator overloading for field elements*/ -impl Sub<&FieldElement> for &FieldElement +impl Sub<&FieldElement> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn sub(self, rhs: &FieldElement) -> Self::Output { + fn sub(self, rhs: &FieldElement) -> Self::Output { Self::Output { - value: F::sub(&self.value, &rhs.value), + value: >::sub(&self.value, &rhs.value), } } } -impl Sub> for FieldElement +impl Sub> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn sub(self, rhs: FieldElement) -> Self::Output { + fn sub(self, rhs: FieldElement) -> Self::Output { &self - &rhs } } -impl Sub<&FieldElement> for FieldElement +impl Sub<&FieldElement> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn sub(self, rhs: &FieldElement) -> Self::Output { + fn sub(self, rhs: &FieldElement) -> Self::Output { &self - rhs } } -impl Sub> for &FieldElement +impl Sub> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn sub(self, rhs: FieldElement) -> Self::Output { + fn sub(self, rhs: FieldElement) -> Self::Output { self - &rhs } } /// Multiplication operator overloading for field elements*/ -impl Mul<&FieldElement> for &FieldElement +impl Mul<&FieldElement> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn mul(self, rhs: &FieldElement) -> Self::Output { + fn mul(self, rhs: &FieldElement) -> Self::Output { Self::Output { - value: F::mul(&self.value, &rhs.value), + value: >::mul(&self.value, &rhs.value), } } } -impl Mul> for FieldElement +impl Mul> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn mul(self, rhs: FieldElement) -> Self::Output { + fn mul(self, rhs: FieldElement) -> Self::Output { &self * &rhs } } -impl Mul<&FieldElement> for FieldElement +impl Mul<&FieldElement> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn mul(self, rhs: &FieldElement) -> Self::Output { + fn mul(self, rhs: &FieldElement) -> Self::Output { &self * rhs } } -impl Mul> for &FieldElement +impl Mul> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn mul(self, rhs: FieldElement) -> Self::Output { + fn mul(self, rhs: FieldElement) -> Self::Output { self * &rhs } } /// Division operator overloading for field elements*/ -impl Div<&FieldElement> for &FieldElement +impl Div<&FieldElement> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn div(self, rhs: &FieldElement) -> Self::Output { + fn div(self, rhs: &FieldElement) -> Self::Output { Self::Output { - value: F::div(&self.value, &rhs.value), + value: >::div(&self.value, &rhs.value), } } } -impl Div> for FieldElement +impl Div> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn div(self, rhs: FieldElement) -> Self::Output { + fn div(self, rhs: FieldElement) -> Self::Output { &self / &rhs } } -impl Div<&FieldElement> for FieldElement +impl Div<&FieldElement> for FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn div(self, rhs: &FieldElement) -> Self::Output { + fn div(self, rhs: &FieldElement) -> Self::Output { &self / rhs } } -impl Div> for &FieldElement +impl Div> for &FieldElement where - F: IsField, + F: IsSubFieldOf, + L: IsField, { - type Output = FieldElement; + type Output = FieldElement; - fn div(self, rhs: FieldElement) -> Self::Output { + fn div(self, rhs: FieldElement) -> Self::Output { self / &rhs } } @@ -418,6 +433,16 @@ where pub fn zero() -> Self { Self { value: F::zero() } } + + #[inline(always)] + pub fn to_extension(self) -> FieldElement + where + F: IsSubFieldOf, + { + FieldElement { + value: >::embed(self.value), + } + } } impl FieldElement { @@ -450,7 +475,11 @@ impl FieldElement { } #[cfg(feature = "lambdaworks-serde-binary")] -impl Serialize for FieldElement { +impl Serialize for FieldElement +where + F: IsField, + F::BaseType: ByteConversion, +{ fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -478,7 +507,11 @@ impl Serialize for FieldElement { } #[cfg(feature = "lambdaworks-serde-binary")] -impl<'de, F: IsPrimeField> Deserialize<'de> for FieldElement { +impl<'de, F> Deserialize<'de> for FieldElement +where + F: IsField, + F::BaseType: ByteConversion, +{ fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, @@ -491,7 +524,7 @@ impl<'de, F: IsPrimeField> Deserialize<'de> for FieldElement { struct FieldElementVisitor(PhantomData F>); - impl<'de, F: IsPrimeField> Visitor<'de> for FieldElementVisitor { + impl<'de, F: IsField> Visitor<'de> for FieldElementVisitor { type Value = FieldElement; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { @@ -515,7 +548,7 @@ impl<'de, F: IsPrimeField> Deserialize<'de> for FieldElement { } let value = value.ok_or_else(|| de::Error::missing_field("value"))?; let val = F::BaseType::from_bytes_be(&value).unwrap(); - Ok(FieldElement::from_raw(&val)) + Ok(FieldElement::from_raw(val)) } fn visit_seq(self, mut seq: S) -> Result, S::Error> @@ -531,7 +564,7 @@ impl<'de, F: IsPrimeField> Deserialize<'de> for FieldElement { } let value = value.ok_or_else(|| de::Error::missing_field("value"))?; let val = F::BaseType::from_bytes_be(&value).unwrap(); - Ok(FieldElement::from_raw(&val)) + Ok(FieldElement::from_raw(val)) } } diff --git a/math/src/field/extensions/cubic.rs b/math/src/field/extensions/cubic.rs index ac57d7dfe..447c790f3 100644 --- a/math/src/field/extensions/cubic.rs +++ b/math/src/field/extensions/cubic.rs @@ -106,7 +106,7 @@ where fn inv( a: &[FieldElement; 3], ) -> Result<[FieldElement; 3], FieldError> { - let three = FieldElement::from(3_u64); + let three = FieldElement::::from(3_u64); let d = a[0].pow(3_u64) + a[1].pow(3_u64) * Q::residue() diff --git a/math/src/field/traits.rs b/math/src/field/traits.rs index e30bc66a7..8441b7082 100644 --- a/math/src/field/traits.rs +++ b/math/src/field/traits.rs @@ -14,6 +14,44 @@ pub enum RootsConfig { BitReverseInversed, // same as above but exponents are negated. } +pub trait IsSubFieldOf: IsField { + fn mul(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; + fn add(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; + fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; + fn sub(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType; + fn embed(a: Self::BaseType) -> F::BaseType; +} + +impl IsSubFieldOf for F +where + F: IsField, +{ + #[inline(always)] + fn mul(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType { + F::mul(a, b) + } + + #[inline(always)] + fn add(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType { + F::add(a, b) + } + + #[inline(always)] + fn sub(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType { + F::sub(a, b) + } + + #[inline(always)] + fn div(a: &Self::BaseType, b: &F::BaseType) -> F::BaseType { + F::div(a, b) + } + + #[inline(always)] + fn embed(a: Self::BaseType) -> F::BaseType { + a + } +} + /// Trait to define necessary parameters for FFT-friendly Fields. /// Two-Adic fields are ones whose order is of the form $2^n k + 1$. /// Here $n$ is usually called the *two-adicity* of the field. The @@ -22,7 +60,7 @@ pub enum RootsConfig { /// A two-adic primitive root of unity is a number w that satisfies w^(2^n) = 1 /// and w^(j) != 1 for every j below 2^n. With this primitive root we can generate /// any other root of unity we need to perform FFT. -pub trait IsFFTField: IsPrimeField { +pub trait IsFFTField: IsField { const TWO_ADICITY: u64; const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType; @@ -33,18 +71,16 @@ pub trait IsFFTField: IsPrimeField { } /// Returns a primitive root of unity of order $2^{order}$. - fn get_primitive_root_of_unity( - order: u64, - ) -> Result, FieldError> { + fn get_primitive_root_of_unity(order: u64) -> Result, FieldError> { let two_adic_primitive_root_of_unity = - FieldElement::new(F::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY); + FieldElement::new(Self::TWO_ADIC_PRIMITVE_ROOT_OF_UNITY); if order == 0 { return Ok(FieldElement::one()); } - if order > F::TWO_ADICITY { + if order > Self::TWO_ADICITY { return Err(FieldError::RootOfUnityError(order)); } - let log_power = F::TWO_ADICITY - order; + let log_power = Self::TWO_ADICITY - order; let root = (0..log_power).fold(two_adic_primitive_root_of_unity, |acc, _| acc.square()); Ok(root) } diff --git a/provers/stark/src/debug.rs b/provers/stark/src/debug.rs index d1abd51d2..1b02007a2 100644 --- a/provers/stark/src/debug.rs +++ b/provers/stark/src/debug.rs @@ -51,7 +51,7 @@ pub fn validate_trace>( if &boundary_value != trace_value { ret = false; - error!("Boundary constraint inconsistency - Expected value {} in step {} and column {}, found: {}", boundary_value.representative(), step, col, trace_value.representative()); + error!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value); } }); @@ -84,10 +84,8 @@ pub fn validate_trace>( if step < exemption_steps[i] && eval != &FieldElement::::zero() { ret = false; error!( - "Inconsistent evaluation of transition {} in step {} - expected 0, got {}", - i, - step, - eval.representative() + "Inconsistent evaluation of transition {} in step {} - expected 0, got {:?}", + i, step, eval ); } }) diff --git a/provers/stark/src/fri/fri_decommit.rs b/provers/stark/src/fri/fri_decommit.rs index 63da9ec5f..8bc378ac8 100644 --- a/provers/stark/src/fri/fri_decommit.rs +++ b/provers/stark/src/fri/fri_decommit.rs @@ -2,12 +2,12 @@ pub use lambdaworks_crypto::fiat_shamir::transcript::Transcript; use lambdaworks_crypto::merkle_tree::proof::Proof; use lambdaworks_math::field::element::FieldElement; -use lambdaworks_math::field::traits::IsPrimeField; +use lambdaworks_math::field::traits::IsField; use crate::config::Commitment; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct FriDecommitment { +pub struct FriDecommitment { pub layers_auth_paths: Vec>, pub layers_evaluations_sym: Vec>, } diff --git a/provers/stark/src/prover.rs b/provers/stark/src/prover.rs index 2b21455f5..bc8cd9449 100644 --- a/provers/stark/src/prover.rs +++ b/provers/stark/src/prover.rs @@ -964,7 +964,7 @@ mod tests { for i in 0..(trace_length * blowup_factor) { assert_eq!( domain.lde_roots_of_unity_coset[i], - FieldElement::from(coset_offset) * primitive_root.pow(i) + primitive_root.pow(i) * FieldElement::from(coset_offset) ); } } diff --git a/provers/stark/src/transcript.rs b/provers/stark/src/transcript.rs index b6952b70b..4349b3092 100644 --- a/provers/stark/src/transcript.rs +++ b/provers/stark/src/transcript.rs @@ -138,7 +138,7 @@ impl IsStarkTranscript for StoneProverTranscript { while result >= Self::MODULUS_MAX_MULTIPLE { result = self.sample_big_int(); } - FieldElement::new(result) * FieldElement::new(Self::R_INV) + FieldElement::::new(result) * FieldElement::new(Self::R_INV) } fn sample_u64(&mut self, upper_bound: u64) -> u64 {