Skip to content

Commit

Permalink
perf: remove some major bottlenecks in deletion
Browse files Browse the repository at this point in the history
Our deleting code had some major bottlenecks, like calling calculate
hashes twice (one for the old roots and one for updated ones) and
filtering the nodes that are in the proof, which was causing a linear
search on proof_positions **for each element in nodes**.

This commit fixes that by:
  - Creating a calculate_nodes_delete function that computes in parallel
    both updated and current roots, so we don't need to loop twice. In
    the future we could even use vectorized computation of sha512_256
    to further speed things up.
  - Don't filter proof positions from the nodes vector in
    calculate_nodes_delete. This will waste a little memory as nodes in
    the proof are returned, but won't cause a new allocation as we are
    returning nodes as is.
  - calculate_hashes now won't delete, so it gets simplified to only
    work when checking or updating proofs.

After running the new code for two days with floresta, the speedup is
clear, with calculate_nodes_delete taking less than 2% of the CPU time
for block validation. Before this patch it was using >40%.
  • Loading branch information
Davidson-Souza committed Jul 29, 2024
1 parent eb22a17 commit ccf72e9
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 45 deletions.
204 changes: 184 additions & 20 deletions src/accumulator/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ pub struct Proof {
/// ```
pub hashes: Vec<NodeHash>,
}
/// We often need to return the targets paired with hashes, and the proof position.
/// Even not using full qualifications, it gets long and complex, and clippy doesn't like
/// it. This type alias helps with that.
pub(crate) type EnumeratedTargetsAndHashPosition = (Vec<(u64, NodeHash)>, Vec<NodeHash>);

// We often need to return the targets paired with hashes, and the proof position.
// Even not using full qualifications, it gets long and complex, and clippy doesn't like
// it. These type aliases help with that.

/// This alias is used when we need to return the nodes and roots for a proof
/// if we are not concerned with deleting those elements.
pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec<NodeHash>);
/// This is used when we need to return the nodes and roots for a proof
/// if we are concerned with deleting those elements. The difference is that
/// we need to return the old and updated roots in the accumulator.
pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>);

impl Proof {
/// Creates a proof from a vector of target and hashes.
Expand Down Expand Up @@ -340,17 +348,125 @@ impl Proof {
self.targets.len()
}

/// This function computes a set of roots from the proof
///
/// It will compute all roots that contain elements in the proof, by hashing the nodes
/// in the path to the root. This function returns the calculated roots and the hashes
/// that were calculated in the process.
/// This function is used for updating the accumulator **and** verifying proofs. It returns
/// the roots computed from the proof (that should be equal to some roots in the present
/// accumulator) and the hashes for an accumulator where the proof elements are removed.
/// If at least one returned element doesn't exist in the accumulator, the proof is invalid.
///
/// Each entry in `del_hashes` is a `(current_hash, replacement_hash)` pair for one
/// target; callers pass `NodeHash::empty()` as the replacement to delete a leaf.
pub(crate) fn calculate_hashes_delete(
    &self,
    del_hashes: &[(NodeHash, NodeHash)],
    num_leaves: u64,
) -> Result<NodesAndRootsOldNew, String> {
    // How many rows the forest for this many leaves has.
    let total_rows = util::tree_rows(num_leaves);

    // Where all the (old, new) root hashes that we've calculated will go to.
    let mut calculated_root_hashes =
        Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves));

    // the positions that should be passed as a proof
    let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows);

    // Start from the targets, each paired with its (current, replacement) hashes.
    // As we calculate nodes upwards, parents accumulate in `computed` below.
    let mut nodes: Vec<_> = self
        .targets
        .iter()
        .copied()
        .zip(del_hashes.to_owned())
        .collect();

    // Add the proof positions to the nodes. Proof nodes are not being replaced,
    // so their old and new hashes are the same.
    nodes.extend(
        proof_positions
            .iter()
            .copied()
            .zip(self.hashes.iter().copied().map(|hash| (hash, hash))),
    );

    // Nodes must be sorted by position for finding siblings during hashing
    nodes.sort();
    // Parents produced while walking up; `get_next` merges this with `nodes`,
    // always yielding the pending entry with the smallest position.
    let mut computed = Vec::with_capacity(nodes.len() * 2);
    let mut computed_index = 0;
    let mut provided_index = 0;
    loop {
        // `next_vec` is true when the next smallest node comes from `computed`,
        // false when it comes from the provided `nodes`; None when both are drained.
        let Some(next_vec) = Self::get_next(&computed, &nodes, computed_index, provided_index)
        else {
            break;
        };

        let (next_pos, (next_hash_old, next_hash_new)) = match next_vec {
            true => {
                computed_index += 1;
                computed[computed_index - 1]
            }
            false => {
                provided_index += 1;
                nodes[provided_index - 1]
            }
        };

        // Roots have no sibling; record the (old, new) pair and keep going.
        if util::is_root_position(next_pos, num_leaves, total_rows) {
            calculated_root_hashes.push((next_hash_old, next_hash_new));
            continue;
        }

        // Since nodes are processed in position order, `next_pos` is the left
        // child here, so its sibling is the position with the low bit set.
        let sibling = next_pos | 1;
        let sibling_vec = Self::get_next(&computed, &nodes, computed_index, provided_index)
            .ok_or(format!("Missing sibling for {}", next_pos))?;
        let (sibling_pos, (sibling_hash_old, sibling_hash_new)) = match sibling_vec {
            true => {
                computed_index += 1;
                computed[computed_index - 1]
            }
            false => {
                provided_index += 1;
                nodes[provided_index - 1]
            }
        };

        if sibling_pos != sibling {
            return Err(format!("Missing sibling for {}", next_pos));
        }

        // New (post-deletion) parent: an empty child means that subtree was
        // deleted, so the parent collapses to the surviving child (or to empty
        // when both children were deleted).
        let parent_hash = match (next_hash_new.is_empty(), sibling_hash_new.is_empty()) {
            (true, true) => NodeHash::empty(),
            (true, false) => sibling_hash_new,
            (false, true) => next_hash_new,
            (false, false) => NodeHash::parent_hash(&next_hash_new, &sibling_hash_new),
        };

        // Old parent: nothing is deleted in the current accumulator, so we
        // always hash both children.
        let parent = util::parent(next_pos, total_rows);
        let old_parent_hash = NodeHash::parent_hash(&next_hash_old, &sibling_hash_old);
        computed.push((parent, (old_parent_hash, parent_hash)));
    }

    // Note: unlike `calculate_hashes`, proof positions are deliberately NOT
    // filtered out of the result. Returning them wastes a little memory but
    // avoids a per-node search and a fresh allocation (see commit rationale).
    nodes.extend(computed);
    let nodes = nodes
        .into_iter()
        .map(|(pos, (_, new_hash))| (pos, new_hash))
        .collect();
    Ok((nodes, calculated_root_hashes))
}

/// This function computes a set of roots from a proof.
/// If some target's hashes are null, then it computes the roots after
/// those targets are deleted. In this context null means [NodeHash::default].
///
/// It's the caller's responsibility to null out the targets if desired by
/// passing a `NodeHash::empty()` instead of the actual hash.
/// Using the proof, we should be able to calculate a subset of the roots, by hashing the
/// nodes in the path to the root. This function returns the calculated roots and the
/// hashes that were calculated in the process.
/// This differs from `calculate_hashes_delete` as this one is only used for verifying
/// proofs, it doesn't compute the roots after the deletion, only the roots that are
/// needed for verification (i.e. the current accumulator).
pub(crate) fn calculate_hashes(
&self,
del_hashes: &[NodeHash],
num_leaves: u64,
) -> Result<EnumeratedTargetsAndHashPosition, String> {
) -> Result<NodesAndRootsCurrent, String> {
// Where all the root hashes that we've calculated will go to.
let total_rows = util::tree_rows(num_leaves);

Expand Down Expand Up @@ -407,6 +523,7 @@ impl Proof {
let sibling = next_pos | 1;
let sibling_vec = Self::get_next(&computed, &nodes, computed_index, provided_index)
.ok_or(format!("Missing sibling for {}", next_pos))?;

let (sibling_pos, sibling_hash) = match sibling_vec {
true => {
computed_index += 1;
Expand All @@ -422,26 +539,20 @@ impl Proof {
return Err(format!("Missing sibling for {}", next_pos));
}

let parent_hash = match (next_hash.is_empty(), sibling_hash.is_empty()) {
(true, true) => NodeHash::empty(),
(true, false) => sibling_hash,
(false, true) => next_hash,
(false, false) => NodeHash::parent_hash(&next_hash, &sibling_hash),
};

let parent_hash = NodeHash::parent_hash(&next_hash, &sibling_hash);
let parent = util::parent(next_pos, total_rows);
computed.push((parent, parent_hash));
}

// we shouldn't return the hashes in the proof
nodes.extend(computed);
nodes.retain(|(pos, _)| !proof_positions.contains(pos));
nodes.retain(|(pos, _)| proof_positions.binary_search(pos).is_err());
Ok((nodes, calculated_root_hashes))
}

fn get_next(
computed: &[(u64, NodeHash)],
provided: &[(u64, NodeHash)],
fn get_next<T>(
computed: &[(u64, T)],
provided: &[(u64, T)],
computed_pos: usize,
provided_pos: usize,
) -> Option<bool> {
Expand Down Expand Up @@ -1123,6 +1234,59 @@ mod tests {
}
}
}

#[test]
fn test_calculate_hashes_delete() {
    // Accumulator with 8 leaves; we delete leaf 0 (replace its hash with empty).
    let hashes = (0u8..8).map(hash_from_u8).collect::<Vec<_>>();
    let del_hashes = vec![(hashes[0], NodeHash::empty())];

    let proof_hashes = [
        "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
        "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
        "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
    ]
    .iter()
    .map(|hash| NodeHash::from_str(hash).unwrap())
    .collect();
    let p = Proof::new(vec![0], proof_hashes);

    let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap();

    // The root before and after the deletion.
    let expected_root_old =
        NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42")
            .unwrap();
    let expected_root_new =
        NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed")
            .unwrap();
    assert_eq!(roots, vec![(expected_root_old, expected_root_new)]);

    // Every (position, new_hash) pair the calculation should produce,
    // including the proof positions (which are no longer filtered out).
    let expected_computed = [
        (0_u64, "0000000000000000000000000000000000000000000000000000000000000000"),
        (1, "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a"),
        (9, "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b"),
        (13, "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7"),
        (8, "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a"),
        (12, "2b77298feac78ab51bc5079099a074c6d789bd350442f5079fcba2b3402694e5"),
        (14, "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed"),
    ]
    .iter()
    .map(|(pos, hash)| (*pos, NodeHash::from_str(hash).unwrap()))
    .collect::<Vec<_>>();
    assert_eq!(computed, expected_computed);
}

#[test]
fn test_serialize_rtt() {
// Tests if the serialized proof can be deserialized again
Expand Down
60 changes: 35 additions & 25 deletions src/accumulator/stump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use serde::Deserialize;
use serde::Serialize;

use super::node_hash::NodeHash;
use super::proof::EnumeratedTargetsAndHashPosition;
use super::proof::NodesAndRootsOldNew;
use super::proof::Proof;
use super::util;

Expand Down Expand Up @@ -103,31 +103,27 @@ impl Stump {
del_hashes: &[NodeHash],
proof: &Proof,
) -> Result<(Stump, UpdateData), String> {
let mut root_candidates = proof
.calculate_hashes(del_hashes, self.leaves)?
.1
.into_iter()
.rev()
.peekable();

let (intermediate, computed_roots) = self.remove(del_hashes, proof)?;
let mut computed_roots = computed_roots.into_iter().rev();

let (intermediate, mut computed_roots) = self.remove(del_hashes, proof)?;
let mut new_roots = vec![];

for root in self.roots.iter() {
if let Some(root_candidate) = root_candidates.peek() {
if *root_candidate == *root {
if let Some(new_root) = computed_roots.next() {
new_roots.push(new_root);
root_candidates.next();
continue;
}
if let Some(pos) = computed_roots.iter().position(|(old, _new)| old == root) {
let (old_root, new_root) = computed_roots.remove(pos);
if old_root == *root {
new_roots.push(new_root);
continue;
}
}

new_roots.push(*root);
}

// If there are still roots to be added, it means that the proof is invalid
// as we should have consumed all the roots.
if !computed_roots.is_empty() {
return Err("Invalid proof".to_string());
}

let (roots, updated, destroyed) = Stump::add(new_roots, utxos, self.leaves);

let new_stump = Stump {
Expand Down Expand Up @@ -237,14 +233,21 @@ impl Stump {
&self,
del_hashes: &[NodeHash],
proof: &Proof,
) -> Result<EnumeratedTargetsAndHashPosition, String> {
) -> Result<NodesAndRootsOldNew, String> {
if del_hashes.is_empty() {
return Ok((vec![], self.roots.clone()));
return Ok((
vec![],
self.roots.iter().map(|root| (*root, *root)).collect(),
));
}

let del_hashes = vec![NodeHash::empty(); proof.targets()];
proof.calculate_hashes(&del_hashes, self.leaves)
let del_hashes = del_hashes
.iter()
.map(|hash| (*hash, NodeHash::empty()))
.collect::<Vec<_>>();
proof.calculate_hashes_delete(&del_hashes, self.leaves)
}

/// Adds new leaves into the root
fn add(
mut roots: Vec<NodeHash>,
Expand Down Expand Up @@ -413,7 +416,9 @@ mod test {
assert_eq!(updated.prev_num_leaves, data.leaves);
assert_eq!(updated.to_destroy, data.to_destroy);
assert_eq!(updated.new_add, new_add);
assert_eq!(updated.new_del, new_del);
for del in new_del.iter() {
assert!(updated.new_del.contains(del));
}
}
}
#[test]
Expand Down Expand Up @@ -460,6 +465,7 @@ mod test {
fn run_case_with_deletion(case: TestCase) {
let leaf_hashes = case
.leaf_preimages
.clone()
.into_iter()
.map(hash_from_u8)
.collect::<Vec<_>>();
Expand All @@ -474,23 +480,27 @@ mod test {

let proof_hashes = case
.proofhashes
.clone()
.unwrap_or_default()
.into_iter()
.map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid"))
.collect::<Vec<_>>();

let proof = Proof::new(case.target_values.unwrap(), proof_hashes);
let proof = Proof::new(case.target_values.clone().unwrap(), proof_hashes);

let roots = case
.expected_roots
.clone()
.into_iter()
.map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid"))
.collect::<Vec<NodeHash>>();

let (stump, _) = Stump::new()
.modify(&leaf_hashes, &[], &Proof::default())
.expect("This stump is valid");
let (stump, _) = stump.modify(&[], &target_hashes, &proof).unwrap();
let (stump, _) = stump
.modify(&[], &target_hashes, &proof)
.expect(&format!("case {:?}", case));

assert_eq!(stump.roots, roots);
}
Expand Down

0 comments on commit ccf72e9

Please sign in to comment.