From 774756c297bb3a83ab620bceb5535dba1776a917 Mon Sep 17 00:00:00 2001 From: irfan Date: Thu, 7 Mar 2024 00:25:10 +0300 Subject: [PATCH] Delayed reduction for M31 summations - from Plonky3 (#810) * add delayed reduction * change monolith to use delayed reduction instead of naive .reduce to sum elements --- crypto/src/hash/monolith/utils.rs | 6 +----- math/src/field/fields/mersenne31/field.rs | 26 +++++++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/crypto/src/hash/monolith/utils.rs b/crypto/src/hash/monolith/utils.rs index 378200a40..a10968acd 100644 --- a/crypto/src/hash/monolith/utils.rs +++ b/crypto/src/hash/monolith/utils.rs @@ -24,11 +24,7 @@ fn random_field_element(shake: &mut Shake128Reader) -> u32 { } pub fn dot_product(u: &[u32], v: &[u32]) -> u32 { - u.iter() - .zip(v) - .map(|(x, y)| F::mul(x, y)) - .reduce(|a, b| F::add(&a, &b)) - .unwrap() + Mersenne31Field::sum(u.iter().zip(v).map(|(x, y)| F::mul(x, y))) } pub fn get_random_y_i( diff --git a/math/src/field/fields/mersenne31/field.rs b/math/src/field/fields/mersenne31/field.rs index 0d85ba146..00bb72203 100644 --- a/math/src/field/fields/mersenne31/field.rs +++ b/math/src/field/fields/mersenne31/field.rs @@ -34,6 +34,14 @@ impl Mersenne31Field { *n } } + + #[inline] + pub fn sum::BaseType>>( + iter: I, + ) -> ::BaseType { + // Delayed reduction + Self::from_u64(iter.map(|x| (x as u64)).sum::()) + } } pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1; @@ -63,12 +71,7 @@ impl IsField for Mersenne31Field { /// Returns the multiplication of `a` and `b`. // Note: for powers of 2 we can perform bit shifting this would involve overriding the trait implementation fn mul(a: &u32, b: &u32) -> u32 { - let prod = u64::from(*a) * u64::from(*b); - let prod_lo = (prod as u32) & ((1 << 31) - 1); - let prod_hi = (prod >> 31) as u32; - //assert prod_hi and prod_lo 31 bit clear - debug_assert!((prod_lo >> 31) == 0 && (prod_hi >> 31) == 0); - Self::add(&prod_lo, &prod_hi) + Self::from_u64(u64::from(*a) * u64::from(*b)) } fn sub(a: &u32, b: &u32) -> u32 { @@ -201,7 +204,6 @@ impl Display for FieldElement { #[cfg(test)] mod tests { use super::*; - type F = Mersenne31Field; #[test] @@ -209,6 +211,16 @@ mod tests { assert_eq!(F::from_hex("B").unwrap(), 11); } + #[test] + fn sum_delayed_reduction() { + let up_to = u32::pow(2, 23); + let pow = u64::pow(2, 60); + + let iter = (0..up_to).map(F::weak_reduce).map(|e| F::pow(&e, pow)); + + assert_eq!(F::from_u64(1314320703), F::sum(iter)); + } + #[test] fn from_hex_for_0x1_a_is_26() { assert_eq!(F::from_hex("0x1a").unwrap(), 26);