From 67d806ec97534f1ab01e1f5adc2e9b2a09431140 Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Mon, 29 Jul 2024 13:23:05 -0300 Subject: [PATCH] perf: remove some major bottlenecks in deletion Our deleting code had some major bottlenecks, like calling calculate hashes twice (one for the old roots and one for updated ones) and filtering the nodes that are in the proof, which was causing a linear search on proof_positins **for each element in nodes**. This commit fixes that by: - Creating a calculate_nodes_delete function that computes in parallel both updated and current roots, so we don't need to loop twice. In the future we could even use vectorized computation of sha512_256 two further speed things up. - Don't filter proof positions from the nodes vector in calculate_nodes_delete. This will waste a little memory as nodes in the proof are returned, but won't cause a new allocation as we are returning nodes as is. - calculate_hashes now won't delete, so it gets simplified to only work when checking or updating proofs. After running the new code for two days with floresta, the speedup is clear, with calculate_nodes_delete taking less than 2% of the CPU time for block validation. Before this path it was using >40%. --- src/accumulator/proof.rs | 178 +++++++++++++++++++++++++++++++++++---- src/accumulator/stump.rs | 56 ++++++------ 2 files changed, 191 insertions(+), 43 deletions(-) diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index 29607bd..86771c3 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -94,10 +94,18 @@ pub struct Proof { /// ``` pub hashes: Vec, } -/// We often need to return the targets paired with hashes, and the proof position. -/// Even not using full qualifications, it gets long and complex, and clippy doesn't like -/// it. This type alias helps with that. -pub(crate) type EnumeratedTargetsAndHashPosition = (Vec<(u64, NodeHash)>, Vec); + +// We often need to return the targets paired with hashes, and the proof position. +// Even not using full qualifications, it gets long and complex, and clippy doesn't like +// it. These type alias helps with that. + +/// This alias is used when we need to return the nodes and roots for a proof +/// if we are not concerned with deleting those elements. +pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec); +/// This is used when we need to return the nodes and roots for a proof +/// if we are concerned with deleting those elements. The difference is that +/// we need to retun the old and updatated roots in the accumulator. +pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>); impl Proof { /// Creates a proof from a vector of target and hashes. @@ -340,17 +348,106 @@ impl Proof { self.targets.len() } + /// This function computes a set of roots from the proof + /// + /// It will compute all roots that contains elements in the proof, by hasing the nodes + /// in the path to the root. This function returns the calculated roots and the hashes + /// that were calculated in the process. + /// This function is used for updating the accumulator **and** verifying proofs. It returns + /// the roots computed from the proof (that should be equal to some roots in the present + /// accumulator) and the hashes for a accumulator where the proof elements are removed. + /// If at least one returned element doesn't exist in the accumulator, the proof is invalid. + pub(crate) fn calculate_hashes_delete( + &self, + del_hashes: &[(NodeHash, NodeHash)], + num_leaves: u64, + ) -> Result { + // Where all the root hashes that we've calculated will go to. + let total_rows = util::tree_rows(num_leaves); + + // Where all the parent hashes we've calculated in a given row will go to. + let mut calculated_root_hashes = + Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves)); + + // the positions that should be passed as a proof + let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); + + // As we calculate nodes upwards, it accumulates here + let mut nodes: Vec<_> = self + .targets + .iter() + .copied() + .zip(del_hashes.to_owned()) + .collect(); + + // add the proof positions to the nodes + nodes.extend( + proof_positions + .iter() + .copied() + .zip(self.hashes.iter().copied().map(|hash| (hash, hash))), + ); + + // Nodes must be sorted for finding siblings during hashing + nodes.sort(); + let mut computed = Vec::with_capacity(nodes.len() * 2); + let mut computed_index = 0; + let mut provided_index = 0; + loop { + let Some((next_pos, (next_hash_old, next_hash_new))) = + Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index) + else { + break; + }; + + if util::is_root_position(next_pos, num_leaves, total_rows) { + calculated_root_hashes.push((next_hash_old, next_hash_new)); + continue; + } + + let sibling = next_pos | 1; + let (sibling_pos, (sibling_hash_old, sibling_hash_new)) = + Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index) + .ok_or(format!("Missing sibling for {}", next_pos))?; + + if sibling_pos != sibling { + return Err(format!("Missing sibling for {}", next_pos)); + } + + let parent_hash = match (next_hash_new.is_empty(), sibling_hash_new.is_empty()) { + (true, true) => NodeHash::empty(), + (true, false) => sibling_hash_new, + (false, true) => next_hash_new, + (false, false) => NodeHash::parent_hash(&next_hash_new, &sibling_hash_new), + }; + + let parent = util::parent(next_pos, total_rows); + let old_parent_hash = NodeHash::parent_hash(&next_hash_old, &sibling_hash_old); + computed.push((parent, (old_parent_hash, parent_hash))); + } + + // we shouldn't return the hashes in the proof + nodes.extend(computed); + let nodes = nodes + .into_iter() + .map(|(pos, (_, new_hash))| (pos, new_hash)) + .collect(); + Ok((nodes, calculated_root_hashes)) + } + /// This function computes a set of roots from a proof. - /// If some target's hashes are null, then it computes the roots after - /// those targets are deleted. In this context null means [NodeHash::default]. /// - /// It's the caller's responsibility to null out the targets if desired by - /// passing a `NodeHash::empty()` instead of the actual hash. + /// Using the proof, we should be able to calculate a subset of the roots, by hashing the + /// nodes in the path to the root. This function returns the calculated roots and the + /// hashes that were calculated in the process. + /// This differs from `calculate_hashes_delelte` as this one is only used for verifying + /// proofs, it doesn't compute the roots after the deletion, only the roots that are + /// needed for verification (i.e. the current accumulator). pub(crate) fn calculate_hashes( &self, del_hashes: &[NodeHash], num_leaves: u64, - ) -> Result { + ) -> Result { // Where all the root hashes that we've calculated will go to. let total_rows = util::tree_rows(num_leaves); @@ -403,20 +500,14 @@ impl Proof { return Err(format!("Missing sibling for {}", next_pos)); } - let parent_hash = match (next_hash.is_empty(), sibling_hash.is_empty()) { - (true, true) => NodeHash::empty(), - (true, false) => sibling_hash, - (false, true) => next_hash, - (false, false) => NodeHash::parent_hash(&next_hash, &sibling_hash), - }; - + let parent_hash = NodeHash::parent_hash(&next_hash, &sibling_hash); let parent = util::parent(next_pos, total_rows); computed.push((parent, parent_hash)); } // we shouldn't return the hashes in the proof nodes.extend(computed); - nodes.retain(|(pos, _)| !proof_positions.contains(pos)); + nodes.retain(|(pos, _)| proof_positions.binary_search(pos).is_err()); Ok((nodes, calculated_root_hashes)) } @@ -1129,6 +1220,59 @@ mod tests { } } } + + #[test] + fn test_calculate_hashes_delete() { + let preimages = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes = preimages.into_iter().map(hash_from_u8).collect::>(); + + let del_hashes = vec![hashes[0]]; + let proof = vec![ + "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", + "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", + ]; + + let proof_hashes = proof + .into_iter() + .map(|hash| NodeHash::from_str(hash).unwrap()) + .collect(); + + let p = Proof::new(vec![0], proof_hashes); + let del_hashes = del_hashes + .into_iter() + .map(|hash| (hash, NodeHash::empty())) + .collect::>(); + + let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap(); + let expected_root_old = + NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") + .unwrap(); + let expected_root_new = + NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed") + .unwrap(); + + let computed_positions = [0_u64, 1, 9, 13, 8, 12, 14].to_vec(); + let computed_hashes = [ + "0000000000000000000000000000000000000000000000000000000000000000", + "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", + "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", + "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + "2b77298feac78ab51bc5079099a074c6d789bd350442f5079fcba2b3402694e5", + "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed", + ] + .iter() + .map(|hash| NodeHash::from_str(hash).unwrap()) + .collect::>(); + let expected_computed: Vec<_> = computed_positions + .into_iter() + .zip(computed_hashes) + .collect(); + assert_eq!(roots, vec![(expected_root_old, expected_root_new)]); + assert_eq!(computed, expected_computed); + } + #[test] fn test_serialize_rtt() { // Tests if the serialized proof can be deserialized again diff --git a/src/accumulator/stump.rs b/src/accumulator/stump.rs index b3dbab7..46613cf 100644 --- a/src/accumulator/stump.rs +++ b/src/accumulator/stump.rs @@ -36,7 +36,7 @@ use serde::Deserialize; use serde::Serialize; use super::node_hash::NodeHash; -use super::proof::EnumeratedTargetsAndHashPosition; +use super::proof::NodesAndRootsOldNew; use super::proof::Proof; use super::util; @@ -103,31 +103,27 @@ impl Stump { del_hashes: &[NodeHash], proof: &Proof, ) -> Result<(Stump, UpdateData), String> { - let mut root_candidates = proof - .calculate_hashes(del_hashes, self.leaves)? - .1 - .into_iter() - .rev() - .peekable(); - - let (intermediate, computed_roots) = self.remove(del_hashes, proof)?; - let mut computed_roots = computed_roots.into_iter().rev(); - + let (intermediate, mut computed_roots) = self.remove(del_hashes, proof)?; let mut new_roots = vec![]; for root in self.roots.iter() { - if let Some(root_candidate) = root_candidates.peek() { - if *root_candidate == *root { - if let Some(new_root) = computed_roots.next() { - new_roots.push(new_root); - root_candidates.next(); - continue; - } + if let Some(pos) = computed_roots.iter().position(|(old, _new)| old == root) { + let (old_root, new_root) = computed_roots.remove(pos); + if old_root == *root { + new_roots.push(new_root); + continue; } } new_roots.push(*root); } + + // If there are still roots to be added, it means that the proof is invalid + // as we should have consumed all the roots. + if !computed_roots.is_empty() { + return Err("Invalid proof".to_string()); + } + let (roots, updated, destroyed) = Stump::add(new_roots, utxos, self.leaves); let new_stump = Stump { @@ -237,14 +233,21 @@ impl Stump { &self, del_hashes: &[NodeHash], proof: &Proof, - ) -> Result { + ) -> Result { if del_hashes.is_empty() { - return Ok((vec![], self.roots.clone())); + return Ok(( + vec![], + self.roots.iter().map(|root| (*root, *root)).collect(), + )); } - let del_hashes = vec![NodeHash::empty(); proof.targets()]; - proof.calculate_hashes(&del_hashes, self.leaves) + let del_hashes = del_hashes + .iter() + .map(|hash| (*hash, NodeHash::empty())) + .collect::>(); + proof.calculate_hashes_delete(&del_hashes, self.leaves) } + /// Adds new leaves into the root fn add( mut roots: Vec, @@ -413,7 +416,9 @@ mod test { assert_eq!(updated.prev_num_leaves, data.leaves); assert_eq!(updated.to_destroy, data.to_destroy); assert_eq!(updated.new_add, new_add); - assert_eq!(updated.new_del, new_del); + for del in new_del.iter() { + assert!(updated.new_del.contains(del)); + } } } #[test] @@ -466,10 +471,10 @@ mod test { let target_hashes = case .target_values - .clone() + .as_ref() .unwrap() .into_iter() - .map(|target| hash_from_u8(target as u8)) + .map(|target| hash_from_u8(*target as u8)) .collect::>(); let proof_hashes = case @@ -491,7 +496,6 @@ mod test { .modify(&leaf_hashes, &[], &Proof::default()) .expect("This stump is valid"); let (stump, _) = stump.modify(&[], &target_hashes, &proof).unwrap(); - assert_eq!(stump.roots, roots); }