diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs
index c0d8836..86771c3 100644
--- a/src/accumulator/proof.rs
+++ b/src/accumulator/proof.rs
@@ -94,10 +94,18 @@ pub struct Proof {
     /// ```
     pub hashes: Vec<NodeHash>,
 }
-/// We often need to return the targets paired with hashes, and the proof position.
-/// Even not using full qualifications, it gets long and complex, and clippy doesn't like
-/// it. This type alias helps with that.
-pub(crate) type EnumeratedTargetsAndHashPosition = (Vec<(u64, NodeHash)>, Vec<NodeHash>);
+
+// We often need to return the targets paired with hashes, and the proof position.
+// Even not using full qualifications, it gets long and complex, and clippy doesn't like
+// it. These type aliases help with that.
+
+/// This alias is used when we need to return the nodes and roots for a proof
+/// if we are not concerned with deleting those elements.
+pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec<NodeHash>);
+/// This is used when we need to return the nodes and roots for a proof
+/// if we are concerned with deleting those elements. The difference is that
+/// we need to return the old and updated roots in the accumulator.
+pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>);
 
 impl Proof {
     /// Creates a proof from a vector of target and hashes.
@@ -340,17 +348,106 @@ impl Proof {
         self.targets.len()
     }
 
+    /// This function computes a set of roots from the proof.
+    ///
+    /// It will compute all roots that contain elements in the proof by hashing the nodes
+    /// in the path to the root. This function returns the calculated roots and the hashes
+    /// that were calculated in the process.
+    /// This function is used for updating the accumulator **and** verifying proofs. It returns
+    /// the roots computed from the proof (that should be equal to some roots in the present
+    /// accumulator) and the hashes for an accumulator where the proof elements are removed.
+    /// If at least one returned element doesn't exist in the accumulator, the proof is invalid.
+    pub(crate) fn calculate_hashes_delete(
+        &self,
+        del_hashes: &[(NodeHash, NodeHash)],
+        num_leaves: u64,
+    ) -> Result<NodesAndRootsOldNew, String> {
+        // Where all the root hashes that we've calculated will go to.
+        let total_rows = util::tree_rows(num_leaves);
+
+        // Where all the parent hashes we've calculated in a given row will go to.
+        let mut calculated_root_hashes =
+            Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves));
+
+        // the positions that should be passed as a proof
+        let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows);
+
+        // As we calculate nodes upwards, it accumulates here
+        let mut nodes: Vec<_> = self
+            .targets
+            .iter()
+            .copied()
+            .zip(del_hashes.to_owned())
+            .collect();
+
+        // add the proof positions to the nodes
+        nodes.extend(
+            proof_positions
+                .iter()
+                .copied()
+                .zip(self.hashes.iter().copied().map(|hash| (hash, hash))),
+        );
+
+        // Nodes must be sorted for finding siblings during hashing
+        nodes.sort();
+        let mut computed = Vec::with_capacity(nodes.len() * 2);
+        let mut computed_index = 0;
+        let mut provided_index = 0;
+        loop {
+            let Some((next_pos, (next_hash_old, next_hash_new))) =
+                Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
+            else {
+                break;
+            };
+
+            if util::is_root_position(next_pos, num_leaves, total_rows) {
+                calculated_root_hashes.push((next_hash_old, next_hash_new));
+                continue;
+            }
+
+            let sibling = next_pos | 1;
+            let (sibling_pos, (sibling_hash_old, sibling_hash_new)) =
+                Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
+                    .ok_or(format!("Missing sibling for {}", next_pos))?;
+
+            if sibling_pos != sibling {
+                return Err(format!("Missing sibling for {}", next_pos));
+            }
+
+            let parent_hash = match (next_hash_new.is_empty(), sibling_hash_new.is_empty()) {
+                (true, true) => NodeHash::empty(),
+                (true, false) => sibling_hash_new,
+                (false, true) => next_hash_new,
+                (false, false) => NodeHash::parent_hash(&next_hash_new, &sibling_hash_new),
+            };
+
+            let parent = util::parent(next_pos, total_rows);
+            let old_parent_hash = NodeHash::parent_hash(&next_hash_old, &sibling_hash_old);
+            computed.push((parent, (old_parent_hash, parent_hash)));
+        }
+
+        // we shouldn't return the hashes in the proof
+        nodes.extend(computed);
+        let nodes = nodes
+            .into_iter()
+            .map(|(pos, (_, new_hash))| (pos, new_hash))
+            .collect();
+        Ok((nodes, calculated_root_hashes))
+    }
+
     /// This function computes a set of roots from a proof.
-    /// If some target's hashes are null, then it computes the roots after
-    /// those targets are deleted. In this context null means [NodeHash::default].
     ///
-    /// It's the caller's responsibility to null out the targets if desired by
-    /// passing a `NodeHash::empty()` instead of the actual hash.
+    /// Using the proof, we should be able to calculate a subset of the roots by hashing the
+    /// nodes in the path to the root. This function returns the calculated roots and the
+    /// hashes that were calculated in the process.
+    /// This differs from `calculate_hashes_delete` as this one is only used for verifying
+    /// proofs; it doesn't compute the roots after the deletion, only the roots that are
+    /// needed for verification (i.e. the current accumulator).
     pub(crate) fn calculate_hashes(
         &self,
         del_hashes: &[NodeHash],
         num_leaves: u64,
-    ) -> Result<EnumeratedTargetsAndHashPosition, String> {
+    ) -> Result<NodesAndRootsCurrent, String> {
         // Where all the root hashes that we've calculated will go to.
         let total_rows = util::tree_rows(num_leaves);
@@ -379,51 +476,71 @@ impl Proof {
 
         // Nodes must be sorted for finding siblings during hashing
         nodes.sort();
-        let mut i = 0;
-        while i < nodes.len() {
-            let (pos1, hash1) = nodes[i];
-            let next_to_prove = util::parent(pos1, total_rows);
-
-            // If the current position is a root, we add that to our result and don't go any further
-            if util::is_root_position(pos1, num_leaves, total_rows) {
-                calculated_root_hashes.push(hash1);
-                i += 1;
+        let mut computed = Vec::with_capacity(nodes.len() * 2);
+        let mut computed_index = 0;
+        let mut provided_index = 0;
+        loop {
+            let Some((next_pos, next_hash)) =
+                Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
+            else {
+                break;
+            };
+
+            if util::is_root_position(next_pos, num_leaves, total_rows) {
+                calculated_root_hashes.push(next_hash);
                 continue;
             }
 
-            let Some((pos2, hash2)) = nodes.get(i + 1) else {
-                return Err(format!(
-                    "Proof is too short. Expected at least {} elements, got {}",
-                    i + 1,
-                    nodes.len()
-                ));
-            };
+            let sibling = next_pos | 1;
+            let (sibling_pos, sibling_hash) =
+                Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
+                    .ok_or(format!("Missing sibling for {}", next_pos))?;
 
-            if pos1 != util::left_sibling(*pos2) {
-                return Err(format!(
-                    "Invalid proof. Expected left sibling of {} to be {}, got {}",
-                    pos2,
-                    util::left_sibling(*pos2),
-                    pos1
-                ));
+            if sibling_pos != sibling {
+                return Err(format!("Missing sibling for {}", next_pos));
             }
 
-            let parent_hash = match (hash1.is_empty(), hash2.is_empty()) {
-                (true, true) => NodeHash::empty(),
-                (true, false) => *hash2,
-                (false, true) => hash1,
-                (false, false) => NodeHash::parent_hash(&hash1, hash2),
-            };
-
-            Self::sorted_push(&mut nodes, (next_to_prove, parent_hash));
-            i += 2;
+            let parent_hash = NodeHash::parent_hash(&next_hash, &sibling_hash);
+            let parent = util::parent(next_pos, total_rows);
+            computed.push((parent, parent_hash));
         }
 
         // we shouldn't return the hashes in the proof
-        nodes.retain(|(pos, _)| !proof_positions.contains(pos));
-
+        nodes.extend(computed);
+        nodes.retain(|(pos, _)| proof_positions.binary_search(pos).is_err());
         Ok((nodes, calculated_root_hashes))
     }
+
+    fn get_next<T: Copy>(
+        computed: &[(u64, T)],
+        provided: &[(u64, T)],
+        computed_pos: &mut usize,
+        provided_pos: &mut usize,
+    ) -> Option<(u64, T)> {
+        let last_computed = computed.get(*computed_pos);
+        let last_provided = provided.get(*provided_pos);
+
+        match (last_computed, last_provided) {
+            (Some((pos1, hashes1)), Some((pos2, hashes2))) => {
+                if pos1 < pos2 {
+                    *computed_pos += 1;
+                    Some((*pos1, *hashes1))
+                } else {
+                    *provided_pos += 1;
+                    Some((*pos2, *hashes2))
+                }
+            }
+            (Some(node), None) => {
+                *computed_pos += 1;
+                Some(*node)
+            }
+            (None, Some(node)) => {
+                *provided_pos += 1;
+                Some(*node)
+            }
+            (None, None) => None,
+        }
+    }
+
     /// Uses the data passed in to update a proof, creating a valid proof for a given
     /// set of targets, after an update. This is useful for caching UTXOs. You grab a proof
     /// for it once and then keep updating it every block, yielding an always valid proof
@@ -699,12 +816,6 @@ impl Proof {
         new_positions.sort();
         Ok(new_positions)
     }
-    fn sorted_push(nodes: &mut Vec<(u64, NodeHash)>, to_add: (u64, NodeHash)) {
-        let pos = nodes
-            .binary_search_by(|(pos, _)| pos.cmp(&to_add.0))
-            .unwrap_or_else(|x| x);
-        nodes.insert(pos, to_add);
-    }
 }
 
 #[cfg(test)]
@@ -854,6 +965,37 @@ mod tests {
            assert_eq!(cached_hashes, expected_cached_hashes);
        }
    }
+
+    #[test]
+    fn test_get_next() {
+        use super::Proof;
+        let computed = vec![(1, NodeHash::empty()), (3, NodeHash::empty())];
+        let provided = vec![(2, NodeHash::empty()), (4, NodeHash::empty())];
+        let mut computed_pos = 0;
+        let mut provided_pos = 0;
+
+        assert_eq!(
+            Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
+            Some((1, NodeHash::empty()))
+        );
+        assert_eq!(
+            Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
+            Some((2, NodeHash::empty()))
+        );
+        assert_eq!(
+            Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
+            Some((3, NodeHash::empty()))
+        );
+        assert_eq!(
+            Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
+            Some((4, NodeHash::empty()))
+        );
+        assert_eq!(
+            Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
+            None
+        );
+    }
+
    #[test]
    fn test_calc_next_positions() {
        use super::Proof;
@@ -1078,6 +1220,59 @@ mod tests {
            }
        }
    }
+
+    #[test]
+    fn test_calculate_hashes_delete() {
+        let preimages = vec![0, 1, 2, 3, 4, 5, 6, 7];
+        let hashes = preimages.into_iter().map(hash_from_u8).collect::<Vec<_>>();
+
+        let del_hashes = vec![hashes[0]];
+        let proof = vec![
+            "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
+            "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
+            "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
+        ];
+
+        let proof_hashes = proof
+            .into_iter()
+            .map(|hash| NodeHash::from_str(hash).unwrap())
+            .collect();
+
+        let p = Proof::new(vec![0], proof_hashes);
+        let del_hashes = del_hashes
+            .into_iter()
+            .map(|hash| (hash, NodeHash::empty()))
+            .collect::<Vec<_>>();
+
+        let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap();
+        let expected_root_old =
+            NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42")
+                .unwrap();
+        let expected_root_new =
+            NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed")
+                .unwrap();
+
+        let computed_positions = [0_u64, 1, 9, 13, 8, 12, 14].to_vec();
+        let computed_hashes = [
+            "0000000000000000000000000000000000000000000000000000000000000000",
+            "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
+            "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
+            "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
+            "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
+            "2b77298feac78ab51bc5079099a074c6d789bd350442f5079fcba2b3402694e5",
+            "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed",
+        ]
+        .iter()
+        .map(|hash| NodeHash::from_str(hash).unwrap())
+        .collect::<Vec<_>>();
+        let expected_computed: Vec<_> = computed_positions
+            .into_iter()
+            .zip(computed_hashes)
+            .collect();
+        assert_eq!(roots, vec![(expected_root_old, expected_root_new)]);
+        assert_eq!(computed, expected_computed);
+    }
+
    #[test]
    fn test_serialize_rtt() {
        // Tests if the serialized proof can be deserialized again
diff --git a/src/accumulator/stump.rs b/src/accumulator/stump.rs
index b3dbab7..46613cf 100644
--- a/src/accumulator/stump.rs
+++ b/src/accumulator/stump.rs
@@ -36,7 +36,7 @@ use serde::Deserialize;
 use serde::Serialize;
 
 use super::node_hash::NodeHash;
-use super::proof::EnumeratedTargetsAndHashPosition;
+use super::proof::NodesAndRootsOldNew;
 use super::proof::Proof;
 use super::util;
 
@@ -103,31 +103,27 @@ impl Stump {
         del_hashes: &[NodeHash],
         proof: &Proof,
     ) -> Result<(Stump, UpdateData), String> {
-        let mut root_candidates = proof
-            .calculate_hashes(del_hashes, self.leaves)?
-            .1
-            .into_iter()
-            .rev()
-            .peekable();
-
-        let (intermediate, computed_roots) = self.remove(del_hashes, proof)?;
-        let mut computed_roots = computed_roots.into_iter().rev();
-
+        let (intermediate, mut computed_roots) = self.remove(del_hashes, proof)?;
         let mut new_roots = vec![];
 
         for root in self.roots.iter() {
-            if let Some(root_candidate) = root_candidates.peek() {
-                if *root_candidate == *root {
-                    if let Some(new_root) = computed_roots.next() {
-                        new_roots.push(new_root);
-                        root_candidates.next();
-                        continue;
-                    }
+            if let Some(pos) = computed_roots.iter().position(|(old, _new)| old == root) {
+                let (old_root, new_root) = computed_roots.remove(pos);
+                if old_root == *root {
+                    new_roots.push(new_root);
+                    continue;
                }
            }
 
            new_roots.push(*root);
        }
+
+        // If there are still computed roots left over, the proof is invalid, as we
+        // should have matched all of them against the current roots.
+        if !computed_roots.is_empty() {
+            return Err("Invalid proof".to_string());
+        }
+
        let (roots, updated, destroyed) = Stump::add(new_roots, utxos, self.leaves);
 
        let new_stump = Stump {
@@ -237,14 +233,21 @@ impl Stump {
        &self,
        del_hashes: &[NodeHash],
        proof: &Proof,
-    ) -> Result<EnumeratedTargetsAndHashPosition, String> {
+    ) -> Result<NodesAndRootsOldNew, String> {
        if del_hashes.is_empty() {
-            return Ok((vec![], self.roots.clone()));
+            return Ok((
+                vec![],
+                self.roots.iter().map(|root| (*root, *root)).collect(),
+            ));
        }
 
-        let del_hashes = vec![NodeHash::empty(); proof.targets()];
-        proof.calculate_hashes(&del_hashes, self.leaves)
+        let del_hashes = del_hashes
+            .iter()
+            .map(|hash| (*hash, NodeHash::empty()))
+            .collect::<Vec<_>>();
+        proof.calculate_hashes_delete(&del_hashes, self.leaves)
    }
+
    /// Adds new leaves into the root
    fn add(
        mut roots: Vec<NodeHash>,
@@ -413,7 +416,9 @@ mod test {
            assert_eq!(updated.prev_num_leaves, data.leaves);
            assert_eq!(updated.to_destroy, data.to_destroy);
            assert_eq!(updated.new_add, new_add);
-            assert_eq!(updated.new_del, new_del);
+            for del in new_del.iter() {
+                assert!(updated.new_del.contains(del));
+            }
        }
    }
    #[test]
@@ -466,10 +471,10 @@ mod test {
 
        let target_hashes = case
            .target_values
-            .clone()
+            .as_ref()
            .unwrap()
            .into_iter()
-            .map(|target| hash_from_u8(target as u8))
+            .map(|target| hash_from_u8(*target as u8))
            .collect::<Vec<_>>();
 
        let proof_hashes = case
@@ -491,7 +496,6 @@ mod test {
            .modify(&leaf_hashes, &[], &Proof::default())
            .expect("This stump is valid");
        let (stump, _) = stump.modify(&[], &target_hashes, &proof).unwrap();
-
        assert_eq!(stump.roots, roots);
    }
 
diff --git a/src/accumulator/util.rs b/src/accumulator/util.rs
index 4848248..b9a90b9 100644
--- a/src/accumulator/util.rs
+++ b/src/accumulator/util.rs
@@ -275,7 +275,7 @@ pub fn get_proof_positions(targets: &[u64], num_leaves: u64, forest_rows: u8) -> Vec<u64> {
    for row in 0..=forest_rows {
        let mut row_targets = computed_positions
            .iter()
-            .copied()
+            .cloned()
            .filter(|x| super::util::detect_row(*x, forest_rows) == row)
            .collect::<Vec<_>>()
            .into_iter()
            .peekable();
 
        while let Some(node) = row_targets.next() {
            if is_root_position(node, num_leaves, forest_rows) {
-                let idx = computed_positions.iter().position(|x| node == *x).unwrap();
-
-                computed_positions.remove(idx);
                continue;
            }
+
            if let Some(next) = row_targets.peek() {
                if !is_sibling(node, *next) {
                    proof_positions.push(node ^ 1);
@@ -299,8 +297,9 @@
            }
 
            computed_positions.push(parent(node, forest_rows));
-            computed_positions.sort();
        }
+
+        computed_positions.sort();
    }
 
    proof_positions
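
Not part of the patch: a minimal sketch, for reviewers, of how the two halves fit together end to end, mirroring `test_calculate_hashes_delete` and the stump tests above. It assumes the same test-module context (the `hash_from_u8` helper and `std::str::FromStr` import used by the existing tests) and that `Stump::new()` constructs an empty accumulator; the literal hashes are the ones from the test case in this patch.

```rust
#[test]
fn delete_roundtrip_sketch() {
    // Accumulator over eight leaves, as in the cases above.
    let hashes: Vec<NodeHash> = (0u8..8).map(hash_from_u8).collect();
    let (stump, _) = Stump::new()
        .modify(&hashes, &[], &Proof::default())
        .expect("adding leaves with an empty proof is valid");

    // Proof for leaf 0, using the sibling hashes from test_calculate_hashes_delete.
    let proof_hashes = vec![
        "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
        "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
        "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
    ]
    .into_iter()
    .map(|hash| NodeHash::from_str(hash).unwrap())
    .collect();
    let proof = Proof::new(vec![0], proof_hashes);

    // Pair every deleted leaf with its replacement (empty hash = deleted) and let
    // calculate_hashes_delete walk the tree once, yielding (old, new) root pairs.
    let del_hashes = vec![(hashes[0], NodeHash::empty())];
    let (_, roots) = proof.calculate_hashes_delete(&del_hashes, 8).unwrap();
    let (old_root, new_root) = roots[0];

    // The old side must match a root of the current accumulator (verification),
    // and the new side is what Stump::modify swaps in after the deletion.
    assert!(stump.roots.contains(&old_root));
    let (deleted, _) = stump.modify(&[], &[hashes[0]], &proof).unwrap();
    assert!(deleted.roots.contains(&new_root));
}
```

The (old, new) pairing is what lets `Stump::modify` match and replace roots in a single pass, instead of computing the current roots and the post-deletion roots in two separate `calculate_hashes` walks as the old code did.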