diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index 93e61b22ef..95c5563ebe 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -6,7 +6,7 @@ // use crate::types::{IndexedValue, ShareIndex}; -use fastcrypto::error::FastCryptoError; +use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar}; use fastcrypto::traits::AllowedRng; use serde::{Deserialize, Serialize}; @@ -82,13 +82,13 @@ impl Poly { } } - /// Given at least `t` polynomial evaluations, it will recover the polynomial's - /// constant term - pub fn recover_c0(t: u32, shares: &[Eval]) -> Result { + fn get_lagrange_coefficients( + t: u32, + shares: &[Eval], + ) -> FastCryptoResult> { if shares.len() < t.try_into().unwrap() { return Err(FastCryptoError::InvalidInput); } - // Check for duplicates. let mut ids_set = HashSet::new(); shares.iter().map(|s| &s.index).for_each(|id| { @@ -98,33 +98,36 @@ impl Poly { return Err(FastCryptoError::InvalidInput); } - // Iterate over all indices and for each multiply the lagrange basis - // with the value of the share. - let mut acc = C::zero(); - for IndexedValue { - index: i, - value: share_i, - } in shares - { - let mut num = C::ScalarType::generator(); - let mut den = C::ScalarType::generator(); - - for IndexedValue { index: j, value: _ } in shares { - if i == j { - continue; - }; - // j - 0 - num = num * C::ScalarType::from(j.get() as u64); - // 1 / (j - i) - den = den - * (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64)); - } - // Next line is safe since i != j. - let inv = (C::ScalarType::generator() / den).unwrap(); - acc += *share_i * num * inv; + let indices = shares + .iter() + .map(|s| C::ScalarType::from(s.index.get() as u64)) + .collect::>(); + + let full_numerator = indices + .iter() + .fold(C::ScalarType::generator(), |acc, i| acc * i); + let mut coeffs = Vec::new(); + for i in &indices { + let denominator = indices + .iter() + .filter(|j| *j != i) + .fold(*i, |acc, j| acc * (*j - i)); + let coeff = full_numerator / denominator; + coeffs.push(coeff.expect("safe since i != j")); } + Ok(coeffs) + } - Ok(acc) + /// Given at least `t` polynomial evaluations, it will recover the polynomial's + /// constant term + pub fn recover_c0(t: u32, shares: &[Eval]) -> Result { + let coeffs = Self::get_lagrange_coefficients(t, shares)?; + let plain_shares = shares.iter().map(|s| s.value).collect::>(); + let res = coeffs + .iter() + .zip(plain_shares.iter()) + .fold(C::zero(), |acc, (c, s)| acc + (*s * *c)); + Ok(res) } /// Checks if a given share is valid. @@ -177,48 +180,8 @@ impl Poly { /// Given at least `t` polynomial evaluations, it will recover the polynomial's /// constant term pub fn recover_c0_msm(t: u32, shares: &[Eval]) -> Result { - if shares.len() < t.try_into().unwrap() { - return Err(FastCryptoError::InvalidInput); - } - - // Check for duplicates. - let mut ids_set = HashSet::new(); - shares.iter().map(|s| &s.index).for_each(|id| { - ids_set.insert(id); - }); - if ids_set.len() != t as usize { - return Err(FastCryptoError::InvalidInput); - } - - // Iterate over all indices and for each multiply the lagrange basis - // with the value of the share. - let mut coeffs = Vec::new(); - let mut plain_shares = Vec::new(); - for IndexedValue { - index: i, - value: share_i, - } in shares - { - let mut num = C::ScalarType::generator(); - let mut den = C::ScalarType::generator(); - - for IndexedValue { index: j, value: _ } in shares { - if i == j { - continue; - }; - // j - 0 - num = num * C::ScalarType::from(j.get() as u64); //opt - - // 1 / (j - i) - den = den - * (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64)); - //opt - } - // Next line is safe since i != j. - let inv = (C::ScalarType::generator() / den).unwrap(); - coeffs.push(num * inv); - plain_shares.push(*share_i); - } + let coeffs = Self::get_lagrange_coefficients(t, shares)?; + let plain_shares = shares.iter().map(|s| s.value).collect::>(); let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match"); Ok(res) } diff --git a/fastcrypto-tbls/src/tests/polynomial_tests.rs b/fastcrypto-tbls/src/tests/polynomial_tests.rs index f056466fc7..bd8fd6c858 100644 --- a/fastcrypto-tbls/src/tests/polynomial_tests.rs +++ b/fastcrypto-tbls/src/tests/polynomial_tests.rs @@ -79,3 +79,5 @@ fn interpolation_insufficient_shares() { Poly::::recover_c0(threshold, &shares).unwrap_err(); } + +// TODO: test recover_msm