Skip to content

Commit

Permalink
Delayed reduction for M31 summations - from Plonky3 (#810)
Browse files Browse the repository at this point in the history
* add delayed reduction

* change monolith to use delayed reduction instead of naive .reduce to sum elements
  • Loading branch information
irfanbozkurt authored Mar 6, 2024
1 parent 601ab67 commit 774756c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
6 changes: 1 addition & 5 deletions crypto/src/hash/monolith/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ fn random_field_element(shake: &mut Shake128Reader) -> u32 {
}

pub fn dot_product(u: &[u32], v: &[u32]) -> u32 {
u.iter()
.zip(v)
.map(|(x, y)| F::mul(x, y))
.reduce(|a, b| F::add(&a, &b))
.unwrap()
Mersenne31Field::sum(u.iter().zip(v).map(|(x, y)| F::mul(x, y)))
}

pub fn get_random_y_i(
Expand Down
26 changes: 19 additions & 7 deletions math/src/field/fields/mersenne31/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ impl Mersenne31Field {
*n
}
}

#[inline]
pub fn sum<I: Iterator<Item = <Self as IsField>::BaseType>>(
iter: I,
) -> <Self as IsField>::BaseType {
// Delayed reduction
Self::from_u64(iter.map(|x| (x as u64)).sum::<u64>())
}
}

pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = (1 << 31) - 1;
Expand Down Expand Up @@ -63,12 +71,7 @@ impl IsField for Mersenne31Field {
/// Returns the multiplication of `a` and `b`.
// Note: for powers of 2 we can perform bit shifting this would involve overriding the trait implementation
fn mul(a: &u32, b: &u32) -> u32 {
let prod = u64::from(*a) * u64::from(*b);
let prod_lo = (prod as u32) & ((1 << 31) - 1);
let prod_hi = (prod >> 31) as u32;
//assert prod_hi and prod_lo 31 bit clear
debug_assert!((prod_lo >> 31) == 0 && (prod_hi >> 31) == 0);
Self::add(&prod_lo, &prod_hi)
Self::from_u64(u64::from(*a) * u64::from(*b))
}

fn sub(a: &u32, b: &u32) -> u32 {
Expand Down Expand Up @@ -201,14 +204,23 @@ impl Display for FieldElement<Mersenne31Field> {
#[cfg(test)]
mod tests {
use super::*;

type F = Mersenne31Field;

#[test]
fn from_hex_for_b_is_11() {
assert_eq!(F::from_hex("B").unwrap(), 11);
}

#[test]
fn sum_delayed_reduction() {
let up_to = u32::pow(2, 23);
let pow = u64::pow(2, 60);

let iter = (0..up_to).map(F::weak_reduce).map(|e| F::pow(&e, pow));

assert_eq!(F::from_u64(1314320703), F::sum(iter));
}

#[test]
fn from_hex_for_0x1_a_is_26() {
assert_eq!(F::from_hex("0x1a").unwrap(), 26);
Expand Down

0 comments on commit 774756c

Please sign in to comment.