Skip to content

Commit

Permalink
cleaner
Browse files Browse the repository at this point in the history
  • Loading branch information
benr-ml committed Sep 27, 2023
1 parent ed4d320 commit 6920913
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 72 deletions.
107 changes: 35 additions & 72 deletions fastcrypto-tbls/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//

use crate::types::{IndexedValue, ShareIndex};
use fastcrypto::error::FastCryptoError;
use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
use fastcrypto::traits::AllowedRng;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -82,13 +82,13 @@ impl<C: GroupElement> Poly<C> {
}
}

/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
fn get_lagrange_coefficients(
t: u32,
shares: &[Eval<C>],
) -> FastCryptoResult<Vec<C::ScalarType>> {
if shares.len() < t.try_into().unwrap() {
return Err(FastCryptoError::InvalidInput);
}

// Check for duplicates.
let mut ids_set = HashSet::new();
shares.iter().map(|s| &s.index).for_each(|id| {
Expand All @@ -98,33 +98,36 @@ impl<C: GroupElement> Poly<C> {
return Err(FastCryptoError::InvalidInput);
}

// Iterate over all indices and for each multiply the lagrange basis
// with the value of the share.
let mut acc = C::zero();
for IndexedValue {
index: i,
value: share_i,
} in shares
{
let mut num = C::ScalarType::generator();
let mut den = C::ScalarType::generator();

for IndexedValue { index: j, value: _ } in shares {
if i == j {
continue;
};
// j - 0
num = num * C::ScalarType::from(j.get() as u64);
// 1 / (j - i)
den = den
* (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64));
}
// Next line is safe since i != j.
let inv = (C::ScalarType::generator() / den).unwrap();
acc += *share_i * num * inv;
let indices = shares
.iter()
.map(|s| C::ScalarType::from(s.index.get() as u64))
.collect::<Vec<_>>();

let full_numerator = indices
.iter()
.fold(C::ScalarType::generator(), |acc, i| acc * i);
let mut coeffs = Vec::new();
for i in &indices {
let denominator = indices
.iter()
.filter(|j| *j != i)
.fold(*i, |acc, j| acc * (*j - i));
let coeff = full_numerator / denominator;
coeffs.push(coeff.expect("safe since i != j"));
}
Ok(coeffs)
}

Ok(acc)
/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
let res = coeffs
.iter()
.zip(plain_shares.iter())
.fold(C::zero(), |acc, (c, s)| acc + (*s * *c));
Ok(res)
}

/// Checks if a given share is valid.
Expand Down Expand Up @@ -177,48 +180,8 @@ impl<C: GroupElement + MultiScalarMul> Poly<C> {
/// Given at least `t` polynomial evaluations, it will recover the polynomial's
/// constant term
pub fn recover_c0_msm(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
if shares.len() < t.try_into().unwrap() {
return Err(FastCryptoError::InvalidInput);
}

// Check for duplicates.
let mut ids_set = HashSet::new();
shares.iter().map(|s| &s.index).for_each(|id| {
ids_set.insert(id);
});
if ids_set.len() != t as usize {
return Err(FastCryptoError::InvalidInput);
}

// Iterate over all indices and for each multiply the lagrange basis
// with the value of the share.
let mut coeffs = Vec::new();
let mut plain_shares = Vec::new();
for IndexedValue {
index: i,
value: share_i,
} in shares
{
let mut num = C::ScalarType::generator();
let mut den = C::ScalarType::generator();

for IndexedValue { index: j, value: _ } in shares {
if i == j {
continue;
};
// j - 0
num = num * C::ScalarType::from(j.get() as u64); //opt

// 1 / (j - i)
den = den
* (C::ScalarType::from(j.get() as u64) - C::ScalarType::from(i.get() as u64));
//opt
}
// Next line is safe since i != j.
let inv = (C::ScalarType::generator() / den).unwrap();
coeffs.push(num * inv);
plain_shares.push(*share_i);
}
let coeffs = Self::get_lagrange_coefficients(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match");
Ok(res)
}
Expand Down
2 changes: 2 additions & 0 deletions fastcrypto-tbls/src/tests/polynomial_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,5 @@ fn interpolation_insufficient_shares() {

Poly::<RistrettoScalar>::recover_c0(threshold, &shares).unwrap_err();
}

// TODO: test recover_msm

0 comments on commit 6920913

Please sign in to comment.