From 7a427aacfd9241276c17f38ecbfd5e6dfae4e032 Mon Sep 17 00:00:00 2001 From: Nicole Graus Date: Tue, 1 Oct 2024 16:34:56 -0300 Subject: [PATCH] Optimize Mersenne31 Field (#921) * optimize add * save changes. Add, sub and mul checked * fix tests * add new inv * add mult by powers of two * replace inverse * test new inv * modify old algorithm for inv * fix tests extension * add mul for degree 4 extension * add fp4 isField and isSubField operations and benchmarks * new version for fp4 mul based on the paper * add mul of a fp2e by non-residue * change inv using mul_fp2_by_non_resiude * save work * wip fp2 test * add fp2 tests * add 2 * a^2 - 1 function * use karatsuba in fp4 mul version 1 * clean up * fix Fp as subfield of Fp2. Tests Fp plus Fp4 is now correct * fix inv * fix comments * fix comments * fixes * fix clippy * fix cargo check no-std * fix clippy * change zero function of isField to rust default * fix two_square_minus_one function and optimize inv function * fix clippy --------- Co-authored-by: Nicole Co-authored-by: Joaquin Carletti Co-authored-by: diegokingston Co-authored-by: Diego K <43053772+diegokingston@users.noreply.github.com> Co-authored-by: Joaquin Carletti <56092489+ColoCarletti@users.noreply.github.com> --- math/benches/criterion_field.rs | 4 +- math/benches/fields/mersenne31.rs | 64 +- math/src/field/fields/mersenne31/extension.rs | 304 --------- .../src/field/fields/mersenne31/extensions.rs | 627 ++++++++++++++++++ math/src/field/fields/mersenne31/field.rs | 272 ++++---- math/src/field/fields/mersenne31/mod.rs | 2 +- .../fields/p448_goldilocks_prime_field.rs | 2 +- math/src/field/traits.rs | 8 +- math/src/unsigned_integer/element.rs | 8 + 9 files changed, 863 insertions(+), 428 deletions(-) delete mode 100644 math/src/field/fields/mersenne31/extension.rs create mode 100644 math/src/field/fields/mersenne31/extensions.rs diff --git a/math/benches/criterion_field.rs b/math/benches/criterion_field.rs index 1c21822de..6e41cbb4b 100644 --- a/math/benches/criterion_field.rs +++ b/math/benches/criterion_field.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use pprof::criterion::{Output, PProfProfiler}; mod fields; -use fields::mersenne31::mersenne31_ops_benchmarks; +use fields::mersenne31::{mersenne31_extension_ops_benchmarks, mersenne31_ops_benchmarks}; use fields::mersenne31_montgomery::mersenne31_mont_ops_benchmarks; use fields::{ stark252::starkfield_ops_benchmarks, u64_goldilocks::u64_goldilocks_ops_benchmarks, @@ -12,6 +12,6 @@ use fields::{ criterion_group!( name = field_benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = starkfield_ops_benchmarks, mersenne31_ops_benchmarks, mersenne31_mont_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks + targets = mersenne31_ops_benchmarks, mersenne31_extension_ops_benchmarks, mersenne31_mont_ops_benchmarks, starkfield_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks ); criterion_main!(field_benches); diff --git a/math/benches/fields/mersenne31.rs b/math/benches/fields/mersenne31.rs index 99e3921a5..e8d99d1c2 100644 --- a/math/benches/fields/mersenne31.rs +++ b/math/benches/fields/mersenne31.rs @@ -1,10 +1,18 @@ use std::hint::black_box; use criterion::Criterion; -use lambdaworks_math::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use lambdaworks_math::field::{ + element::FieldElement, + fields::mersenne31::{ + extensions::{Degree2ExtensionField, Degree4ExtensionField}, + field::Mersenne31Field, + }, +}; use rand::random; pub type F = FieldElement; +pub type Fp2E = FieldElement; +pub type Fp4E = FieldElement; #[inline(never)] #[no_mangle] @@ -17,6 +25,60 @@ pub fn rand_field_elements(num: usize) -> Vec<(F, F)> { result } +//TODO: Check if this is the correct way to bench. +pub fn rand_fp4e(num: usize) -> Vec<(Fp4E, Fp4E)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push(( + Fp4E::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + Fp4E::new([ + Fp2E::new([F::new(random()), F::new(random())]), + Fp2E::new([F::new(random()), F::new(random())]), + ]), + )); + } + result +} + +pub fn mersenne31_extension_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1000000].into_iter().map(rand_fp4e).collect::>(); + + let mut group = c.benchmark_group("Mersenne31 Fp4 operations"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Mul of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Square of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Inv of Fp4 {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } +} + pub fn mersenne31_ops_benchmarks(c: &mut Criterion) { let input: Vec> = [1, 10, 100, 1000, 10000, 100000, 1000000] .into_iter() diff --git a/math/src/field/fields/mersenne31/extension.rs b/math/src/field/fields/mersenne31/extension.rs deleted file mode 100644 index 3c89a2147..000000000 --- a/math/src/field/fields/mersenne31/extension.rs +++ /dev/null @@ -1,304 +0,0 @@ -use crate::field::{ - element::FieldElement, - errors::FieldError, - extensions::{ - cubic::{CubicExtensionField, HasCubicNonResidue}, - quadratic::{HasQuadraticNonResidue, QuadraticExtensionField}, - }, - traits::IsField, -}; - -use super::field::Mersenne31Field; - -//Note: The inverse calculation in mersenne31/plonky3 differs from the default quadratic extension so I implemented the complex extension. -////////////////// -#[derive(Clone, Debug)] -pub struct Mersenne31Complex; - -impl IsField for Mersenne31Complex { - //Elements represents a[0] = real, a[1] = imaginary - type BaseType = [FieldElement; 2]; - - /// Returns the component wise addition of `a` and `b` - fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - [a[0] + b[0], a[1] + b[1]] - } - - //NOTE: THIS uses Gauss algorithm. Bench this against plonky 3 implementation to see what is faster. - /// Returns the multiplication of `a` and `b` using the following - /// equation: - /// (a0 + a1 * t) * (b0 + b1 * t) = a0 * b0 + a1 * b1 * Self::residue() + (a0 * b1 + a1 * b0) * t - /// where `t.pow(2)` equals `Q::residue()`. - fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - let a0b0 = a[0] * b[0]; - let a1b1 = a[1] * b[1]; - let z = (a[0] + a[1]) * (b[0] + b[1]); - [a0b0 - a1b1, z - a0b0 - a1b1] - } - - fn square(a: &Self::BaseType) -> Self::BaseType { - let [a0, a1] = a; - let v0 = a0 * a1; - let c0 = (a0 + a1) * (a0 - a1); - let c1 = v0 + v0; - [c0, c1] - } - /// Returns the component wise subtraction of `a` and `b` - fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - [a[0] - b[0], a[1] - b[1]] - } - - /// Returns the component wise negation of `a` - fn neg(a: &Self::BaseType) -> Self::BaseType { - [-a[0], -a[1]] - } - - /// Returns the multiplicative inverse of `a` - fn inv(a: &Self::BaseType) -> Result { - let inv_norm = (a[0].pow(2_u64) + a[1].pow(2_u64)).inv()?; - Ok([a[0] * inv_norm, -a[1] * inv_norm]) - } - - /// Returns the division of `a` and `b` - fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { - Self::mul(a, &Self::inv(b).unwrap()) - } - - /// Returns a boolean indicating whether `a` and `b` are equal component wise. - fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { - a[0] == b[0] && a[1] == b[1] - } - - /// Returns the additive neutral element of the field extension. - fn zero() -> Self::BaseType { - [FieldElement::zero(), FieldElement::zero()] - } - - /// Returns the multiplicative neutral element of the field extension. - fn one() -> Self::BaseType { - [FieldElement::one(), FieldElement::zero()] - } - - /// Returns the element `x * 1` where 1 is the multiplicative neutral element. - fn from_u64(x: u64) -> Self::BaseType { - [FieldElement::from(x), FieldElement::zero()] - } - - /// Takes as input an element of BaseType and returns the internal representation - /// of that element in the field. - /// Note: for this case this is simply the identity, because the components - /// already have correct representations. - fn from_base_type(x: Self::BaseType) -> Self::BaseType { - x - } -} - -pub type Mersenne31ComplexQuadraticExtensionField = - QuadraticExtensionField; - -//TODO: Check this should be for complex and not base field -impl HasQuadraticNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^2 - i - 2 - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::from(2), - FieldElement::::one(), - ])) - } -} - -pub type Mersenne31ComplexCubicExtensionField = - CubicExtensionField; - -impl HasCubicNonResidue for Mersenne31Complex { - // Verifiable in Sage with - // ```sage - // p = 2**31 - 1 # Mersenne31 - // F = GF(p) # The base field GF(p) - // R. = F[] # The polynomial ring over F - // K. = F.extension(x^2 + 1) # The complex extension field - // R2. = K[] - // f2 = y^3 - 5*i - // assert f2.is_irreducible() - // ``` - fn residue() -> FieldElement { - FieldElement::from(&Mersenne31Complex::from_base_type([ - FieldElement::::zero(), - FieldElement::::from(5), - ])) - } -} - -#[cfg(test)] -mod tests { - use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; - - use super::*; - - type Fi = Mersenne31Complex; - type F = FieldElement; - - //NOTE: from_u64 reflects from_real - //NOTE: for imag use from_base_type - - #[test] - fn add_real_one_plus_one_is_two() { - assert_eq!(Fi::add(&Fi::one(), &Fi::one()), Fi::from_u64(2)) - } - - #[test] - fn add_real_neg_one_plus_one_is_zero() { - assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::one()), Fi::zero()) - } - - #[test] - fn add_real_neg_one_plus_two_is_one() { - assert_eq!(Fi::add(&Fi::neg(&Fi::one()), &Fi::from_u64(2)), Fi::one()) - } - - #[test] - fn add_real_neg_one_plus_neg_one_is_order_sub_two() { - assert_eq!( - Fi::add(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), - Fi::from_u64((MERSENNE_31_PRIME_FIELD_ORDER - 2).into()) - ) - } - - #[test] - fn add_complex_one_plus_one_two() { - //Manually declare the complex part to one - let one = Fi::from_base_type([F::zero(), F::one()]); - let two = Fi::from_base_type([F::zero(), F::from(2)]); - assert_eq!(Fi::add(&one, &one), two) - } - - #[test] - fn add_complex_neg_one_plus_one_is_zero() { - //Manually declare the complex part to one - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::add(&neg_one, &one), Fi::zero()) - } - - #[test] - fn add_complex_neg_one_plus_two_is_one() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - let two = Fi::from_base_type([F::zero(), F::from(2)]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::add(&neg_one, &two), one) - } - - #[test] - fn add_complex_neg_one_plus_neg_one_imag_is_order_sub_two() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - assert_eq!( - Fi::add(&neg_one, &neg_one)[1], - F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2) - ) - } - - #[test] - fn add_order() { - let a = Fi::from_base_type([-F::one(), F::one()]); - let b = Fi::from_base_type([F::from(2), F::new(MERSENNE_31_PRIME_FIELD_ORDER - 2)]); - let c = Fi::from_base_type([F::one(), -F::one()]); - assert_eq!(Fi::add(&a, &b), c) - } - - #[test] - fn add_equal_zero() { - let a = Fi::from_base_type([-F::one(), -F::one()]); - let b = Fi::from_base_type([F::one(), F::one()]); - assert_eq!(Fi::add(&a, &b), Fi::zero()) - } - - #[test] - fn add_plus_one() { - let a = Fi::from_base_type([F::one(), F::from(2)]); - let b = Fi::from_base_type([F::one(), F::one()]); - let c = Fi::from_base_type([F::from(2), F::from(3)]); - assert_eq!(Fi::add(&a, &b), c) - } - - #[test] - fn sub_real_one_sub_one_is_zero() { - assert_eq!(Fi::sub(&Fi::one(), &Fi::one()), Fi::zero()) - } - - #[test] - fn sub_real_two_sub_two_is_zero() { - assert_eq!( - Fi::sub(&Fi::from_u64(2u64), &Fi::from_u64(2u64)), - Fi::zero() - ) - } - - #[test] - fn sub_real_neg_one_sub_neg_one_is_zero() { - assert_eq!( - Fi::sub(&Fi::neg(&Fi::one()), &Fi::neg(&Fi::one())), - Fi::zero() - ) - } - - #[test] - fn sub_real_two_sub_one_is_one() { - assert_eq!(Fi::sub(&Fi::from_u64(2), &Fi::one()), Fi::one()) - } - - #[test] - fn sub_real_neg_one_sub_zero_is_neg_one() { - assert_eq!( - Fi::sub(&Fi::neg(&Fi::one()), &Fi::zero()), - Fi::neg(&Fi::one()) - ) - } - - #[test] - fn sub_complex_one_sub_one_is_zero() { - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::sub(&one, &one), Fi::zero()) - } - - #[test] - fn sub_complex_two_sub_two_is_zero() { - let two = Fi::from_base_type([F::zero(), F::from(2)]); - assert_eq!(Fi::sub(&two, &two), Fi::zero()) - } - - #[test] - fn sub_complex_neg_one_sub_neg_one_is_zero() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - assert_eq!(Fi::sub(&neg_one, &neg_one), Fi::zero()) - } - - #[test] - fn sub_complex_two_sub_one_is_one() { - let two = Fi::from_base_type([F::zero(), F::from(2)]); - let one = Fi::from_base_type([F::zero(), F::one()]); - assert_eq!(Fi::sub(&two, &one), one) - } - - #[test] - fn sub_complex_neg_one_sub_zero_is_neg_one() { - let neg_one = Fi::from_base_type([F::zero(), -F::one()]); - assert_eq!(Fi::sub(&neg_one, &Fi::zero()), neg_one) - } - - #[test] - fn mul() { - let a = Fi::from_base_type([F::from(2), F::from(2)]); - let b = Fi::from_base_type([F::from(4), F::from(5)]); - let c = Fi::from_base_type([-F::from(2), F::from(18)]); - assert_eq!(Fi::mul(&a, &b), c) - } -} diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs new file mode 100644 index 000000000..27c2ab118 --- /dev/null +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -0,0 +1,627 @@ +use super::field::Mersenne31Field; +use crate::field::{ + element::FieldElement, + errors::FieldError, + traits::{IsField, IsSubFieldOf}, +}; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +type FpE = FieldElement; + +#[derive(Clone, Debug)] +pub struct Degree2ExtensionField; + +impl Degree2ExtensionField { + pub fn mul_fp2_by_nonresidue(a: &Fp2E) -> Fp2E { + Fp2E::new([ + a.value()[0].double() - a.value()[1], + a.value()[1].double() + a.value()[0], + ]) + } +} + +impl IsField for Degree2ExtensionField { + //Element representation: a[0] = real part, a[1] = imaginary part + type BaseType = [FpE; 2]; + + /// Returns the component wise addition of `a` and `b` + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [a[0] + b[0], a[1] + b[1]] + } + + /// Returns the multiplication of `a` and `b`. + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + let a0b0 = a[0] * b[0]; + let a1b1 = a[1] * b[1]; + let z = (a[0] + a[1]) * (b[0] + b[1]); + [a0b0 - a1b1, z - a0b0 - a1b1] + } + + fn square(a: &Self::BaseType) -> Self::BaseType { + let [a0, a1] = a; + let v0 = a0 * a1; + let c0 = (a0 + a1) * (a0 - a1); + let c1 = v0.double(); + [c0, c1] + } + /// Returns the component wise subtraction of `a` and `b` + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [a[0] - b[0], a[1] - b[1]] + } + + /// Returns the component wise negation of `a` + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-a[0], -a[1]] + } + + /// Returns the multiplicative inverse of `a` + fn inv(a: &Self::BaseType) -> Result { + let inv_norm = (a[0].square() + a[1].square()).inv()?; + Ok([a[0] * inv_norm, -a[1] * inv_norm]) + } + + /// Returns the division of `a` and `b` + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + ::mul(a, &Self::inv(b).unwrap()) + } + + /// Returns a boolean indicating whether `a` and `b` are equal component wise. + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a[0] == b[0] && a[1] == b[1] + } + + /// Returns the multiplicative neutral element of the field extension. + fn one() -> Self::BaseType { + [FpE::one(), FpE::zero()] + } + + /// Returns the element `x * 1` where 1 is the multiplicative neutral element. + fn from_u64(x: u64) -> Self::BaseType { + [FpE::from(x), FpE::zero()] + } + + /// Takes as input an element of BaseType and returns the internal representation + /// of that element in the field. + /// Note: for this case this is simply the identity, because the components + /// already have correct representations. + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } +} + +impl IsSubFieldOf for Mersenne31Field { + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + [FpE::from(a) + b[0], b[1]] + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + [FpE::from(a) - b[0], -b[1]] + } + + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + [FpE::from(a) * b[0], FpE::from(a) * b[1]] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree2ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [FieldElement::from_raw(a), FieldElement::zero()] + } + + #[cfg(feature = "alloc")] + fn to_subfield_vec( + b: ::BaseType, + ) -> alloc::vec::Vec { + b.into_iter().map(|x| x.to_raw()).collect() + } +} + +type Fp2E = FieldElement; + +#[derive(Clone, Debug)] +pub struct Degree4ExtensionField; + +impl IsField for Degree4ExtensionField { + type BaseType = [Fp2E; 2]; + + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] + &b[0], &a[1] + &b[1]] + } + + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + [&a[0] - &b[0], &a[1] - &b[1]] + } + + fn neg(a: &Self::BaseType) -> Self::BaseType { + [-&a[0], -&a[1]] + } + + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + // Algorithm from: https://github.com/ingonyama-zk/papers/blob/main/Mersenne31_polynomial_arithmetic.pdf (page 5): + let a0b0 = &a[0] * &b[0]; + let a1b1 = &a[1] * &b[1]; + [ + &a0b0 + Degree2ExtensionField::mul_fp2_by_nonresidue(&a1b1), + (&a[0] + &a[1]) * (&b[0] + &b[1]) - a0b0 - a1b1, + ] + } + + fn square(a: &Self::BaseType) -> Self::BaseType { + let a0_square = &a[0].square(); + let a1_square = &a[1].square(); + [ + a0_square + Degree2ExtensionField::mul_fp2_by_nonresidue(a1_square), + (&a[0] + &a[1]).square() - a0_square - a1_square, + ] + } + + fn inv(a: &Self::BaseType) -> Result { + let inv_norm = + (a[0].square() - Degree2ExtensionField::mul_fp2_by_nonresidue(&a[1].square())).inv()?; + Ok([&a[0] * &inv_norm, -&a[1] * &inv_norm]) + } + + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + ::mul(a, &Self::inv(b).unwrap()) + } + + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a[0] == b[0] && a[1] == b[1] + } + + fn zero() -> Self::BaseType { + [Fp2E::zero(), Fp2E::zero()] + } + + fn one() -> Self::BaseType { + [Fp2E::one(), Fp2E::zero()] + } + + fn from_u64(x: u64) -> Self::BaseType { + [Fp2E::from(x), Fp2E::zero()] + } + + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } +} + +impl IsSubFieldOf for Mersenne31Field { + fn add( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + [FpE::from(a) + &b[0], b[1].clone()] + } + + fn sub( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + [FpE::from(a) - &b[0], -&b[1]] + } + + fn mul( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let c0 = FpE::from(a) * &b[0]; + let c1 = FpE::from(a) * &b[1]; + [c0, c1] + } + + fn div( + a: &Self::BaseType, + b: &::BaseType, + ) -> ::BaseType { + let b_inv = Degree4ExtensionField::inv(b).unwrap(); + >::mul(a, &b_inv) + } + + fn embed(a: Self::BaseType) -> ::BaseType { + [ + Fp2E::from_raw(>::embed(a)), + Fp2E::zero(), + ] + } + + #[cfg(feature = "alloc")] + fn to_subfield_vec( + b: ::BaseType, + ) -> alloc::vec::Vec { + // TODO: Repace this for with a map similarly to this: + // b.into_iter().map(|x| x.to_raw()).collect() + let mut result = Vec::new(); + for fp2e in b { + result.push(fp2e.value()[0].to_raw()); + result.push(fp2e.value()[1].to_raw()); + } + result + } +} + +#[cfg(test)] +mod tests { + use core::ops::Neg; + + use crate::field::fields::mersenne31::field::MERSENNE_31_PRIME_FIELD_ORDER; + + use super::*; + + type FpE = FieldElement; + type Fp2E = FieldElement; + type Fp4E = FieldElement; + + #[test] + fn add_real_one_plus_one_is_two() { + assert_eq!(Fp2E::one() + Fp2E::one(), Fp2E::from(2)) + } + + #[test] + fn add_real_neg_one_plus_one_is_zero() { + assert_eq!(Fp2E::one() + Fp2E::one().neg(), Fp2E::zero()) + } + + #[test] + fn add_real_neg_one_plus_two_is_one() { + assert_eq!(Fp2E::one().neg() + Fp2E::from(2), Fp2E::one()) + } + + #[test] + fn add_real_neg_one_plus_neg_one_is_order_sub_two() { + assert_eq!( + Fp2E::one().neg() + Fp2E::one().neg(), + Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)), FpE::zero()]) + ) + } + + #[test] + fn add_complex_one_plus_one_two() { + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + let two_i = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(&one_i + &one_i, two_i) + } + + #[test] + fn add_complex_neg_one_plus_one_is_zero() { + //Manually declare the complex part to one + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(neg_one_i + one_i, Fp2E::zero()) + } + + #[test] + fn add_complex_neg_one_plus_two_is_one() { + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); + let two_i = Fp2E::new([FpE::zero(), FpE::from(2)]); + let one_i = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&neg_one_i + &two_i, one_i) + } + + #[test] + fn add_complex_neg_one_plus_neg_one_imag_is_order_sub_two() { + let neg_one_i = Fp2E::new([FpE::zero(), -FpE::one()]); + assert_eq!( + (&neg_one_i + &neg_one_i).value()[1], + FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)) + ) + } + + #[test] + fn add_order() { + let a = Fp2E::new([-FpE::one(), FpE::one()]); + let b = Fp2E::new([ + FpE::from(2), + FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 2)), + ]); + let c = Fp2E::new([FpE::one(), -FpE::one()]); + assert_eq!(&a + &b, c) + } + + #[test] + fn add_equal_zero() { + let a = Fp2E::new([-FpE::one(), -FpE::one()]); + let b = Fp2E::new([FpE::one(), FpE::one()]); + assert_eq!(&a + &b, Fp2E::zero()) + } + + #[test] + fn add_plus_one() { + let a = Fp2E::new([FpE::one(), FpE::from(2)]); + let b = Fp2E::new([FpE::one(), FpE::one()]); + let c = Fp2E::new([FpE::from(2), FpE::from(3)]); + assert_eq!(&a + &b, c) + } + + #[test] + fn sub_real_one_sub_one_is_zero() { + assert_eq!(&Fp2E::one() - &Fp2E::one(), Fp2E::zero()) + } + + #[test] + fn sub_real_two_sub_two_is_zero() { + assert_eq!(&Fp2E::from(2) - &Fp2E::from(2), Fp2E::zero()) + } + + #[test] + fn sub_real_neg_one_sub_neg_one_is_zero() { + assert_eq!(Fp2E::one().neg() - Fp2E::one().neg(), Fp2E::zero()) + } + + #[test] + fn sub_real_two_sub_one_is_one() { + assert_eq!(Fp2E::from(2) - Fp2E::one(), Fp2E::one()) + } + + #[test] + fn sub_real_neg_one_sub_zero_is_neg_one() { + assert_eq!(Fp2E::one().neg() - Fp2E::zero(), Fp2E::one().neg()) + } + + #[test] + fn sub_complex_one_sub_one_is_zero() { + let one = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&one - &one, Fp2E::zero()) + } + + #[test] + fn sub_complex_two_sub_two_is_zero() { + let two = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(&two - &two, Fp2E::zero()) + } + + #[test] + fn sub_complex_neg_one_sub_neg_one_is_zero() { + let neg_one = Fp2E::new([FpE::zero(), -FpE::one()]); + assert_eq!(&neg_one - &neg_one, Fp2E::zero()) + } + + #[test] + fn sub_complex_two_sub_one_is_one() { + let two = Fp2E::new([FpE::zero(), FpE::from(2)]); + let one = Fp2E::new([FpE::zero(), FpE::one()]); + assert_eq!(&two - &one, one) + } + + #[test] + fn sub_complex_neg_one_sub_zero_is_neg_one() { + let neg_one = Fp2E::new([FpE::zero(), -FpE::one()]); + assert_eq!(&neg_one - &Fp2E::zero(), neg_one) + } + + #[test] + fn mul_fp2_is_correct() { + let a = Fp2E::new([FpE::from(2), FpE::from(2)]); + let b = Fp2E::new([FpE::from(4), FpE::from(5)]); + let c = Fp2E::new([-FpE::from(2), FpE::from(18)]); + assert_eq!(&a * &b, c) + } + + #[test] + fn square_equals_mul_by_itself() { + let a = Fp2E::new([FpE::from(2), FpE::from(3)]); + assert_eq!(a.square(), &a * &a) + } + + #[test] + fn test_fp2_add() { + let a = Fp2E::new([FpE::from(0), FpE::from(3)]); + let b = Fp2E::new([-FpE::from(2), FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(0) - FpE::from(2), FpE::from(3) + FpE::from(8)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_fp2_add_2() { + let a = Fp2E::new([FpE::from(2), FpE::from(4)]); + let b = Fp2E::new([-FpE::from(2), -FpE::from(4)]); + let expected_result = Fp2E::new([FpE::from(2) - FpE::from(2), FpE::from(4) - FpE::from(4)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_fp2_add_3() { + let a = Fp2E::new([FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER), FpE::from(1)]); + let b = Fp2E::new([FpE::from(1), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::from(1), FpE::from(1)]); + assert_eq!(a + b, expected_result); + } + + #[test] + fn test_fp2_sub() { + let a = Fp2E::new([FpE::from(0), FpE::from(3)]); + let b = Fp2E::new([-FpE::from(2), FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(0) + FpE::from(2), FpE::from(3) - FpE::from(8)]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_fp2_sub_2() { + let a = Fp2E::new([FpE::zero(), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let b = Fp2E::new([FpE::one(), -FpE::one()]); + let expected_result = + Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)), FpE::one()]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_fp2_sub_3() { + let a = Fp2E::new([FpE::from(5), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let b = Fp2E::new([FpE::from(5), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a - b, expected_result); + } + + #[test] + fn test_fp2_mul() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = Fp2E::new([-FpE::from(4), FpE::from(2)]); + let expected_result = Fp2E::new([-FpE::from(58), FpE::new(4)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_2() { + let a = Fp2E::new([FpE::one(), FpE::zero()]); + let b = Fp2E::new([FpE::from(12), -FpE::from(8)]); + let expected_result = Fp2E::new([FpE::from(12), -FpE::new(8)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_3() { + let a = Fp2E::new([FpE::zero(), FpE::zero()]); + let b = Fp2E::new([FpE::from(2), FpE::from(7)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_4() { + let a = Fp2E::new([FpE::from(2), FpE::from(7)]); + let b = Fp2E::new([FpE::zero(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_mul_5() { + let a = Fp2E::new([FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER), FpE::one()]); + let b = Fp2E::new([FpE::from(2), FpE::from(&MERSENNE_31_PRIME_FIELD_ORDER)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::from(2)]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_inv() { + let a = Fp2E::new([FpE::one(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::one(), FpE::zero()]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_fp2_inv_2() { + let a = Fp2E::new([FpE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)), FpE::one()]); + let expected_result = Fp2E::new([FpE::from(1073741823), FpE::from(1073741823)]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_fp2_inv_3() { + let a = Fp2E::new([FpE::from(2063384121), FpE::from(1232183486)]); + let expected_result = Fp2E::new([FpE::from(1244288232), FpE::from(1321511038)]); + assert_eq!(a.inv().unwrap(), expected_result); + } + + #[test] + fn test_fp2_mul_inv() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = a.inv().unwrap(); + let expected_result = Fp2E::new([FpE::one(), FpE::zero()]); + assert_eq!(a * b, expected_result); + } + + #[test] + fn test_fp2_div() { + let a = Fp2E::new([FpE::from(12), FpE::from(5)]); + let b = Fp2E::new([FpE::from(4), FpE::from(2)]); + let expected_result = Fp2E::new([FpE::from(644245097), FpE::from(1288490188)]); + assert_eq!(a / b, expected_result); + } + + #[test] + fn test_fp2_div_2() { + let a = Fp2E::new([FpE::from(4), FpE::from(7)]); + let b = Fp2E::new([FpE::one(), FpE::zero()]); + let expected_result = Fp2E::new([FpE::from(4), FpE::from(7)]); + assert_eq!(a / b, expected_result); + } + + #[test] + fn test_fp2_div_3() { + let a = Fp2E::new([FpE::zero(), FpE::zero()]); + let b = Fp2E::new([FpE::from(3), FpE::from(12)]); + let expected_result = Fp2E::new([FpE::zero(), FpE::zero()]); + assert_eq!(a / b, expected_result); + } + + #[test] + fn mul_fp4_by_zero_is_zero() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + assert_eq!(Fp4E::zero(), a * Fp4E::zero()) + } + + #[test] + fn mul_fp4_by_one_is_identity() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2), FpE::from(3)]), + Fp2E::new([FpE::from(4), FpE::from(5)]), + ]); + assert_eq!(a, a.clone() * Fp4E::one()) + } + + #[test] + fn square_fp4_equals_mul_two_times() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(3), FpE::from(4)]), + Fp2E::new([FpE::from(5), FpE::from(6)]), + ]); + + assert_eq!(a.square(), &a * &a) + } + + #[test] + fn fp4_mul_by_inv_is_one() { + let a = Fp4E::new([ + Fp2E::new([FpE::from(2147483647), FpE::from(2147483648)]), + Fp2E::new([FpE::from(2147483649), FpE::from(2147483650)]), + ]); + + assert_eq!(&a * a.inv().unwrap(), Fp4E::one()) + } + + #[test] + fn embed_fp_with_fp4() { + let a = FpE::from(3); + let a_extension = Fp4E::from(3); + assert_eq!(a.to_extension::(), a_extension); + } + + #[test] + fn add_fp_and_fp4() { + let a = FpE::from(3); + let a_extension = Fp4E::from(3); + let b = Fp4E::from(2); + assert_eq!(a + &b, a_extension + b); + } + + #[test] + fn mul_fp_by_fp4() { + let a = FpE::from(30000000000); + let a_extension = a.to_extension::(); + let b = Fp4E::new([ + Fp2E::new([FpE::from(1), FpE::from(2)]), + Fp2E::new([FpE::from(3), FpE::from(4)]), + ]); + assert_eq!(a * &b, a_extension * b); + } +} diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index e4abfab0f..1c8b2dc58 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -42,6 +42,29 @@ impl Mersenne31Field { // Delayed reduction Self::from_u64(iter.map(|x| (x as u64)).sum::()) } + + /// Computes a * 2^k, with 0 < k < 31 + pub fn mul_power_two(a: u32, k: u32) -> u32 { + let msb = (a & (u32::MAX << (31 - k))) >> (31 - k); // The k + 1 msf shifted right . + let lsb = (a & (u32::MAX >> (k + 1))) << k; // The 31 - k lsb shifted left. + Self::weak_reduce(msb + lsb) + } + + pub fn pow_2(a: &u32, order: u32) -> u32 { + let mut res = *a; + (0..order).for_each(|_| res = Self::square(&res)); + res + } + + /// TODO: See if we can optimize this function. + /// Computes 2a^2 - 1 + pub fn two_square_minus_one(a: &u32) -> u32 { + if *a == 0 { + MERSENNE_31_PRIME_FIELD_ORDER - 1 + } else { + Self::from_u64(((u64::from(*a) * u64::from(*a)) << 1) - 1) + } + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -54,18 +77,9 @@ impl IsField for Mersenne31Field { /// Returns the sum of `a` and `b`. fn add(a: &u32, b: &u32) -> u32 { - // Avoids conditional https://github.com/Plonky3/Plonky3/blob/6049a30c3b1f5351c3eb0f7c994dc97e8f68d10d/mersenne-31/src/lib.rs#L249 - // Working with i32 means we get a flag which informs us if overflow happens - let (sum_i32, over) = (*a as i32).overflowing_add(*b as i32); - let sum_u32 = sum_i32 as u32; - let sum_corr = sum_u32.wrapping_sub(MERSENNE_31_PRIME_FIELD_ORDER); - - //assert 31 bit clear - // If self + rhs did not overflow, return it. - // If self + rhs overflowed, sum_corr = self + rhs - (2**31 - 1). - let sum = if over { sum_corr } else { sum_u32 }; - debug_assert!((sum >> 31) == 0); - Self::as_representative(&sum) + // We are using that if a and b are field elements of Mersenne31, then + // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p. + Self::weak_reduce(a + b) } /// Returns the multiplication of `a` and `b`. @@ -75,13 +89,7 @@ impl IsField for Mersenne31Field { } fn sub(a: &u32, b: &u32) -> u32 { - let (mut sub, over) = a.overflowing_sub(*b); - - // If we didn't overflow we have the correct value. - // Otherwise we have added 2**32 = 2**31 + 1 mod 2**31 - 1. - // Hence we need to remove the most significant bit and subtract 1. - sub -= over as u32; - sub & MERSENNE_31_PRIME_FIELD_ORDER + Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b) } /// Returns the additive inverse of `a`. @@ -91,20 +99,20 @@ impl IsField for Mersenne31Field { } /// Returns the multiplicative inverse of `a`. - fn inv(a: &u32) -> Result { - if *a == Self::zero() || *a == MERSENNE_31_PRIME_FIELD_ORDER { + fn inv(x: &u32) -> Result { + if *x == Self::zero() || *x == MERSENNE_31_PRIME_FIELD_ORDER { return Err(FieldError::InvZeroError); } - let p101 = Self::mul(&Self::pow(a, 4u32), a); + let p101 = Self::mul(&Self::pow_2(x, 2), x); let p1111 = Self::mul(&Self::square(&p101), &p101); - let p11111111 = Self::mul(&Self::pow(&p1111, 16u32), &p1111); - let p111111110000 = Self::pow(&p11111111, 16u32); + let p11111111 = Self::mul(&Self::pow_2(&p1111, 4u32), &p1111); + let p111111110000 = Self::pow_2(&p11111111, 4u32); let p111111111111 = Self::mul(&p111111110000, &p1111); - let p1111111111111111 = Self::mul(&Self::pow(&p111111110000, 16u32), &p11111111); + let p1111111111111111 = Self::mul(&Self::pow_2(&p111111110000, 4u32), &p11111111); let p1111111111111111111111111111 = - Self::mul(&Self::pow(&p1111111111111111, 4096u32), &p111111111111); + Self::mul(&Self::pow_2(&p1111111111111111, 12u32), &p111111111111); let p1111111111111111111111111111101 = - Self::mul(&Self::pow(&p1111111111111111111111111111, 8u32), &p101); + Self::mul(&Self::pow_2(&p1111111111111111111111111111, 3u32), &p101); Ok(p1111111111111111111111111111101) } @@ -120,7 +128,7 @@ impl IsField for Mersenne31Field { } /// Returns the additive neutral element. - fn zero() -> Self::BaseType { + fn zero() -> u32 { 0u32 } @@ -131,16 +139,7 @@ impl IsField for Mersenne31Field { /// Returns the element `x * 1` where 1 is the multiplicative neutral element. fn from_u64(x: u64) -> u32 { - let (lo, hi) = (x as u32 as u64, x >> 32); - // 2^32 = 2 (mod Mersenne 31 bit prime) - // t <= (2^32 - 1) + 2 * (2^32 - 1) = 3 * 2^32 - 3 = 6 * 2^31 - 3 - let t = lo + 2 * hi; - - const MASK: u64 = (1 << 31) - 1; - let (lo, hi) = ((t & MASK) as u32, (t >> 31) as u32); - // 2^31 = 1 mod Mersenne31 - // lo < 2^31, hi < 6, so lo + hi < 2^32. - Self::weak_reduce(lo + hi) + (((((x >> 31) + x + 1) >> 31) + x) & (MERSENNE_31_PRIME_FIELD_ORDER as u64)) as u32 } /// Takes as input an element of BaseType and returns the internal representation @@ -148,6 +147,9 @@ impl IsField for Mersenne31Field { fn from_base_type(x: u32) -> u32 { Self::weak_reduce(x) } + fn double(a: &u32) -> u32 { + Self::weak_reduce(a << 1) + } } impl IsPrimeField for Mersenne31Field { @@ -205,12 +207,45 @@ impl Display for FieldElement { mod tests { use super::*; type F = Mersenne31Field; + type FE = FieldElement; + + #[test] + fn mul_power_two_is_correct() { + let a = 3u32; + let k = 2; + let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); + let result = F::mul_power_two(a, k); + assert_eq!(FE::from(&result), expected_result) + } + + #[test] + fn mul_power_two_is_correct_2() { + let a = 229287u32; + let k = 4; + let expected_result = FE::from(&a) * FE::from(2).pow(k as u16); + let result = F::mul_power_two(a, k); + assert_eq!(FE::from(&result), expected_result) + } + + #[test] + fn pow_2_is_correct() { + let a = 3u32; + let order = 12; + let result = F::pow_2(&a, order); + let expected_result = FE::pow(&FE::from(&a), 4096u32); + assert_eq!(FE::from(&result), expected_result) + } #[test] fn from_hex_for_b_is_11() { assert_eq!(F::from_hex("B").unwrap(), 11); } + #[test] + fn from_hex_for_b_is_11_v2() { + assert_eq!(FE::from_hex("B").unwrap(), FE::from(11)); + } + #[test] fn sum_delayed_reduction() { let up_to = u32::pow(2, 16); @@ -236,190 +271,195 @@ mod tests { #[test] fn one_plus_1_is_2() { - let a = F::one(); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, 2u32); + assert_eq!(FE::one() + FE::one(), FE::from(&2u32)); } #[test] fn neg_1_plus_1_is_0() { - let a = F::neg(&F::one()); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(-FE::one() + FE::one(), FE::zero()); } #[test] fn neg_1_plus_2_is_1() { - let a = F::neg(&F::one()); - let b = F::from_base_type(2u32); - let c = F::add(&a, &b); - assert_eq!(c, F::one()); + assert_eq!(-FE::one() + FE::from(&2u32), FE::one()); } #[test] fn max_order_plus_1_is_0() { - let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let b = F::one(); - let c = F::add(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!( + FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)) + FE::from(1), + FE::from(0) + ); } #[test] fn comparing_13_and_13_are_equal() { - let a = F::from_base_type(13); - let b = F::from_base_type(13); - assert_eq!(a, b); + assert_eq!(FE::from(&13u32), FE::from(13)); } #[test] fn comparing_13_and_8_they_are_not_equal() { - let a = F::from_base_type(13); - let b = F::from_base_type(8); - assert_ne!(a, b); + assert_ne!(FE::from(&13u32), FE::from(8)); } #[test] fn one_sub_1_is_0() { - let a = F::one(); - let b = F::one(); - let c = F::sub(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(FE::one() - FE::one(), FE::zero()); } #[test] fn zero_sub_1_is_order_minus_1() { - let a = F::zero(); - let b = F::one(); - let c = F::sub(&a, &b); - assert_eq!(c, MERSENNE_31_PRIME_FIELD_ORDER - 1); + assert_eq!( + FE::zero() - FE::one(), + FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1)) + ); } #[test] fn neg_1_sub_neg_1_is_0() { - let a = F::neg(&F::one()); - let b = F::neg(&F::one()); - let c = F::sub(&a, &b); - assert_eq!(c, F::zero()); + assert_eq!(-FE::one() - (-FE::one()), FE::zero()); } #[test] - fn neg_1_sub_1_is_neg_1() { - let a = F::neg(&F::one()); - let b = F::zero(); - let c = F::sub(&a, &b); - assert_eq!(c, F::neg(&F::one())); + fn neg_1_sub_0_is_neg_1() { + assert_eq!(-FE::one() - FE::zero(), -FE::one()); } #[test] fn mul_neutral_element() { - let a = F::from_base_type(1); - let b = F::from_base_type(2); - let c = F::mul(&a, &b); - assert_eq!(c, F::from_base_type(2)); + assert_eq!(FE::one() * FE::from(&2u32), FE::from(&2u32)); } #[test] fn mul_2_3_is_6() { - let a = F::from_base_type(2); - let b = F::from_base_type(3); - assert_eq!(a * b, F::from_base_type(6)); + assert_eq!(FE::from(&2u32) * FE::from(&3u32), FE::from(&6u32)); } #[test] fn mul_order_neg_1() { - let a = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let b = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER - 1); - let c = F::mul(&a, &b); - assert_eq!(c, F::from_base_type(1)); + assert_eq!( + FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1) + * FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1), + FE::one() + ); } #[test] fn pow_p_neg_1() { assert_eq!( - F::pow(&F::from_base_type(2), MERSENNE_31_PRIME_FIELD_ORDER - 1), - F::one() + FE::pow(&FE::from(&2u32), MERSENNE_31_PRIME_FIELD_ORDER - 1), + FE::one() ) } #[test] fn inv_0_error() { - let result = F::inv(&F::zero()); + let result = FE::inv(&FE::zero()); assert!(matches!(result, Err(FieldError::InvZeroError))); } #[test] fn inv_2() { - let result = F::inv(&F::from_base_type(2u32)).unwrap(); + let result = FE::inv(&FE::from(&2u32)).unwrap(); // sage: 1 / F(2) = 1073741824 - assert_eq!(result, 1073741824); + assert_eq!(result, FE::from(1073741824)); } #[test] fn pow_2_3() { - assert_eq!(F::pow(&F::from_base_type(2), 3_u64), 8) + assert_eq!(FE::pow(&FE::from(&2u32), 3u64), FE::from(8)); } #[test] fn div_1() { - assert_eq!(F::div(&F::from_base_type(2), &F::from_base_type(1)), 2) + assert_eq!(FE::from(&2u32) / FE::from(&1u32), FE::from(&2u32)); } #[test] fn div_4_2() { - assert_eq!(F::div(&F::from_base_type(4), &F::from_base_type(2)), 2) + assert_eq!(FE::from(&4u32) / FE::from(&2u32), FE::from(&2u32)); } - // 1431655766 #[test] fn div_4_3() { // sage: F(4) / F(3) = 1431655766 - assert_eq!( - F::div(&F::from_base_type(4), &F::from_base_type(3)), - 1431655766 - ) + assert_eq!(FE::from(&4u32) / FE::from(&3u32), FE::from(1431655766)); } #[test] fn two_plus_its_additive_inv_is_0() { - let two = F::from_base_type(2); - - assert_eq!(F::add(&two, &F::neg(&two)), F::zero()) + assert_eq!(FE::from(&2u32) + (-FE::from(&2u32)), FE::zero()); } #[test] fn from_u64_test() { - let num = F::from_u64(1u64); - assert_eq!(num, F::one()); + assert_eq!(FE::from(1u64), FE::one()); } #[test] fn creating_a_field_element_from_its_representative_returns_the_same_element_1() { - let change = 1; - let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); - let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + let change: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 1; + let f1 = FE::from(&change); + let f2 = FE::from(&FE::representative(&f1)); assert_eq!(f1, f2); } #[test] fn creating_a_field_element_from_its_representative_returns_the_same_element_2() { - let change = 8; - let f1 = F::from_base_type(MERSENNE_31_PRIME_FIELD_ORDER + change); - let f2 = F::from_base_type(Mersenne31Field::representative(&f1)); + let change: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 8; + let f1 = FE::from(&change); + let f2 = FE::from(&FE::representative(&f1)); assert_eq!(f1, f2); } #[test] fn from_base_type_test() { - let b = F::from_base_type(1u32); - assert_eq!(b, F::one()); + assert_eq!(FE::from(&1u32), FE::one()); } #[cfg(feature = "std")] #[test] fn to_hex_test() { - let num = F::from_hex("B").unwrap(); - assert_eq!(F::to_hex(&num), "B"); + let num = FE::from_hex("B").unwrap(); + assert_eq!(FE::to_hex(&num), "B"); + } + + #[test] + fn double_equals_add_itself() { + let a = FE::from(1234); + assert_eq!(a + a, a.double()) + } + + #[test] + fn two_square_minus_one_is_correct() { + let a = FE::from(2147483650); + assert_eq!( + FE::from(&F::two_square_minus_one(a.value())), + a.square().double() - FE::one() + ) + } + + #[test] + fn two_square_zero_minus_one_is_minus_one() { + let a = FE::from(0); + assert_eq!( + FE::from(&F::two_square_minus_one(a.value())), + a.square().double() - FE::one() + ) + } + + #[test] + fn two_square_p_minus_one_is_minus_one() { + let a = FE::from(&MERSENNE_31_PRIME_FIELD_ORDER); + assert_eq!( + FE::from(&F::two_square_minus_one(a.value())), + a.square().double() - FE::one() + ) + } + + #[test] + fn mul_by_inv() { + let x = 3476715743_u32; + assert_eq!(FE::from(&x).inv().unwrap() * FE::from(&x), FE::one()); } } diff --git a/math/src/field/fields/mersenne31/mod.rs b/math/src/field/fields/mersenne31/mod.rs index 4bfd3daf1..2272e7d5e 100644 --- a/math/src/field/fields/mersenne31/mod.rs +++ b/math/src/field/fields/mersenne31/mod.rs @@ -1,2 +1,2 @@ -pub mod extension; +pub mod extensions; pub mod field; diff --git a/math/src/field/fields/p448_goldilocks_prime_field.rs b/math/src/field/fields/p448_goldilocks_prime_field.rs index 9f67fd1d2..eadd61541 100644 --- a/math/src/field/fields/p448_goldilocks_prime_field.rs +++ b/math/src/field/fields/p448_goldilocks_prime_field.rs @@ -16,7 +16,7 @@ pub const P448_GOLDILOCKS_PRIME_FIELD_ORDER: U448 = /// 448-bit unsigned integer represented as /// a size 8 `u64` array `limbs` of 56-bit words. /// The least significant word is in the left most position. -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] pub struct U56x8 { limbs: [u64; 8], } diff --git a/math/src/field/traits.rs b/math/src/field/traits.rs index a9a39d2dd..320531f42 100644 --- a/math/src/field/traits.rs +++ b/math/src/field/traits.rs @@ -99,9 +99,9 @@ pub trait IsField: Debug + Clone { /// The underlying base type for representing elements from the field. // TODO: Relax Unpin for non cuda usage #[cfg(feature = "lambdaworks-serde-binary")] - type BaseType: Clone + Debug + Unpin + ByteConversion; + type BaseType: Clone + Debug + Unpin + ByteConversion + Default; #[cfg(not(feature = "lambdaworks-serde-binary"))] - type BaseType: Clone + Debug + Unpin; + type BaseType: Clone + Debug + Unpin + Default; /// Returns the sum of `a` and `b`. fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType; @@ -173,7 +173,9 @@ pub trait IsField: Debug + Clone { fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool; /// Returns the additive neutral element. - fn zero() -> Self::BaseType; + fn zero() -> Self::BaseType { + Self::BaseType::default() + } /// Returns the multiplicative neutral element. fn one() -> Self::BaseType; diff --git a/math/src/unsigned_integer/element.rs b/math/src/unsigned_integer/element.rs index 021afd9a3..1613e192b 100644 --- a/math/src/unsigned_integer/element.rs +++ b/math/src/unsigned_integer/element.rs @@ -36,6 +36,14 @@ pub struct UnsignedInteger { pub limbs: [u64; NUM_LIMBS], } +impl Default for UnsignedInteger { + fn default() -> Self { + Self { + limbs: [0; NUM_LIMBS], + } + } +} + // NOTE: manually implementing `PartialOrd` may seem unorthodox, but the // derived implementation had terrible performance. #[allow(clippy::non_canonical_partial_ord_impl)]