diff --git a/fastcrypto-tbls/src/nodes.rs b/fastcrypto-tbls/src/nodes.rs
index 149f02eebb..71e01caec3 100644
--- a/fastcrypto-tbls/src/nodes.rs
+++ b/fastcrypto-tbls/src/nodes.rs
@@ -8,6 +8,7 @@ use fastcrypto::error::{FastCryptoError, FastCryptoResult};
 use fastcrypto::groups::GroupElement;
 use fastcrypto::hash::{Blake2b256, Digest, HashFunction};
 use serde::{Deserialize, Serialize};
+use tracing::debug;
 
 pub type PartyId = u16;
 
@@ -145,14 +146,28 @@ impl Nodes {
     /// - The precision loss, counted as the sum of the remainders of the division by d, is at most
     ///   the allowed delta
     /// In practice, allowed delta will be the extra liveness we would assume above 2f+1.
-    pub fn reduce(&self, t: u16, allowed_delta: u16) -> (Self, u16) {
+    /// total_weight_lower_bound allows limiting the level of reduction (e.g., in benchmarks). To get the best results,
+    /// set it to 1.
+    pub fn reduce(&self, t: u16, allowed_delta: u16, total_weight_lower_bound: u32) -> (Self, u16) {
+        assert!(total_weight_lower_bound <= self.total_weight && total_weight_lower_bound > 0);
         let mut max_d = 1;
         for d in 2..=40 {
-            let sum = self.nodes.iter().map(|n| n.weight % d).sum::<u16>();
-            if sum <= allowed_delta {
+            // Break if we reached the lower bound.
+            let new_total_weight = self.nodes.iter().map(|n| n.weight / d).sum::<u16>();
+            if new_total_weight < total_weight_lower_bound as u16 {
+                break;
+            }
+            // Compute the precision loss.
+            let delta = self.nodes.iter().map(|n| n.weight % d).sum::<u16>();
+            if delta <= allowed_delta {
                 max_d = d;
             }
         }
+        debug!(
+            "Nodes::reduce reducing from {} with max_d {}, allowed_delta {}, total_weight_lower_bound {}",
+            self.total_weight, max_d, allowed_delta, total_weight_lower_bound
+        );
+
         let nodes = self
             .nodes
             .iter()
diff --git a/fastcrypto-tbls/src/tests/nodes_tests.rs b/fastcrypto-tbls/src/tests/nodes_tests.rs
index 11b2c97d0b..90958633dd 100644
--- a/fastcrypto-tbls/src/tests/nodes_tests.rs
+++ b/fastcrypto-tbls/src/tests/nodes_tests.rs
@@ -188,12 +188,12 @@ fn test_reduce() {
         let t = (nodes.total_weight() / 3) as u16;
 
         // No extra gap, should return the inputs
-        let (new_nodes, new_t) = nodes.reduce(t, 1);
+        let (new_nodes, new_t) = nodes.reduce(t, 1, 1);
         assert_eq!(nodes, new_nodes);
         assert_eq!(t, new_t);
 
         // 10% gap
-        let (new_nodes, _new_t) = nodes.reduce(t, (nodes.total_weight() / 10) as u16);
+        let (new_nodes, _new_t) = nodes.reduce(t, (nodes.total_weight() / 10) as u16, 1);
         // Estimate the real factor d
         let d = nodes.iter().last().unwrap().weight / new_nodes.iter().last().unwrap().weight;
         // The loss per node is on average (d - 1) / 2
@@ -201,3 +201,27 @@ fn test_reduce() {
         assert!((d - 1) / 2 * number_of_nodes < ((nodes.total_weight() / 9) as u16));
     }
 }
+
+#[test]
+fn test_reduce_with_lower_bounds() {
+    let number_of_nodes = 100;
+    let node_vec = get_nodes::<G2Element>(number_of_nodes);
+    let nodes = Nodes::new(node_vec).unwrap();
+    let t = (nodes.total_weight() / 3) as u16;
+
+    // No extra gap, should return the inputs
+    let (new_nodes, new_t) = nodes.reduce(t, 1, 1);
+    assert_eq!(nodes, new_nodes);
+    assert_eq!(t, new_t);
+
+    // 10% gap
+    let (new_nodes1, _new_t1) = nodes.reduce(t, (nodes.total_weight() / 20) as u16, 1);
+    let (new_nodes2, _new_t2) = nodes.reduce(
+        t,
+        (nodes.total_weight() / 20) as u16,
+        nodes.total_weight() / 3,
+    );
+    assert!(new_nodes1.total_weight() < new_nodes2.total_weight());
+    assert!(new_nodes2.total_weight() >= nodes.total_weight() / 3);
+    assert!(new_nodes2.total_weight() < nodes.total_weight());
+}