Skip to content

Commit

Permalink
Merge pull request #54 from Davidson-Souza/feat/calculate-hashes-delete
Browse files Browse the repository at this point in the history
perf: give a major rework on how our deleting works
  • Loading branch information
Davidson-Souza authored Jul 30, 2024
2 parents aebaf2f + 3e2564d commit e140c8a
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 80 deletions.
293 changes: 244 additions & 49 deletions src/accumulator/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,18 @@ pub struct Proof {
/// ```
pub hashes: Vec<NodeHash>,
}
/// We often need to return the targets paired with hashes, and the proof position.
/// Even not using full qualifications, it gets long and complex, and clippy doesn't like
/// it. This type alias helps with that.
pub(crate) type EnumeratedTargetsAndHashPosition = (Vec<(u64, NodeHash)>, Vec<NodeHash>);

// We often need to return the targets paired with hashes, and the proof position.
// Even not using full qualifications, it gets long and complex, and clippy doesn't like
// it. These type alias helps with that.

/// This alias is used when we need to return the nodes and roots for a proof
/// if we are not concerned with deleting those elements.
pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec<NodeHash>);
/// This is used when we need to return the nodes and roots for a proof
/// if we are concerned with deleting those elements. The difference is that
/// we need to retun the old and updatated roots in the accumulator.
pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>);

impl Proof {
/// Creates a proof from a vector of target and hashes.
Expand Down Expand Up @@ -340,17 +348,106 @@ impl Proof {
self.targets.len()
}

/// This function computes a set of roots from the proof
///
/// It will compute all roots that contains elements in the proof, by hasing the nodes
/// in the path to the root. This function returns the calculated roots and the hashes
/// that were calculated in the process.
/// This function is used for updating the accumulator **and** verifying proofs. It returns
/// the roots computed from the proof (that should be equal to some roots in the present
/// accumulator) and the hashes for a accumulator where the proof elements are removed.
/// If at least one returned element doesn't exist in the accumulator, the proof is invalid.
pub(crate) fn calculate_hashes_delete(
&self,
del_hashes: &[(NodeHash, NodeHash)],
num_leaves: u64,
) -> Result<NodesAndRootsOldNew, String> {
// Where all the root hashes that we've calculated will go to.
let total_rows = util::tree_rows(num_leaves);

// Where all the parent hashes we've calculated in a given row will go to.
let mut calculated_root_hashes =
Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves));

// the positions that should be passed as a proof
let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows);

// As we calculate nodes upwards, it accumulates here
let mut nodes: Vec<_> = self
.targets
.iter()
.copied()
.zip(del_hashes.to_owned())
.collect();

// add the proof positions to the nodes
nodes.extend(
proof_positions
.iter()
.copied()
.zip(self.hashes.iter().copied().map(|hash| (hash, hash))),
);

// Nodes must be sorted for finding siblings during hashing
nodes.sort();
let mut computed = Vec::with_capacity(nodes.len() * 2);
let mut computed_index = 0;
let mut provided_index = 0;
loop {
let Some((next_pos, (next_hash_old, next_hash_new))) =
Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
else {
break;
};

if util::is_root_position(next_pos, num_leaves, total_rows) {
calculated_root_hashes.push((next_hash_old, next_hash_new));
continue;
}

let sibling = next_pos | 1;
let (sibling_pos, (sibling_hash_old, sibling_hash_new)) =
Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
.ok_or(format!("Missing sibling for {}", next_pos))?;

if sibling_pos != sibling {
return Err(format!("Missing sibling for {}", next_pos));
}

let parent_hash = match (next_hash_new.is_empty(), sibling_hash_new.is_empty()) {
(true, true) => NodeHash::empty(),
(true, false) => sibling_hash_new,
(false, true) => next_hash_new,
(false, false) => NodeHash::parent_hash(&next_hash_new, &sibling_hash_new),
};

let parent = util::parent(next_pos, total_rows);
let old_parent_hash = NodeHash::parent_hash(&next_hash_old, &sibling_hash_old);
computed.push((parent, (old_parent_hash, parent_hash)));
}

// we shouldn't return the hashes in the proof
nodes.extend(computed);
let nodes = nodes
.into_iter()
.map(|(pos, (_, new_hash))| (pos, new_hash))
.collect();
Ok((nodes, calculated_root_hashes))
}

/// This function computes a set of roots from a proof.
/// If some target's hashes are null, then it computes the roots after
/// those targets are deleted. In this context null means [NodeHash::default].
///
/// It's the caller's responsibility to null out the targets if desired by
/// passing a `NodeHash::empty()` instead of the actual hash.
/// Using the proof, we should be able to calculate a subset of the roots, by hashing the
/// nodes in the path to the root. This function returns the calculated roots and the
/// hashes that were calculated in the process.
/// This differs from `calculate_hashes_delelte` as this one is only used for verifying
/// proofs, it doesn't compute the roots after the deletion, only the roots that are
/// needed for verification (i.e. the current accumulator).
pub(crate) fn calculate_hashes(
&self,
del_hashes: &[NodeHash],
num_leaves: u64,
) -> Result<EnumeratedTargetsAndHashPosition, String> {
) -> Result<NodesAndRootsCurrent, String> {
// Where all the root hashes that we've calculated will go to.
let total_rows = util::tree_rows(num_leaves);

Expand Down Expand Up @@ -379,51 +476,71 @@ impl Proof {

// Nodes must be sorted for finding siblings during hashing
nodes.sort();
let mut i = 0;
while i < nodes.len() {
let (pos1, hash1) = nodes[i];
let next_to_prove = util::parent(pos1, total_rows);

// If the current position is a root, we add that to our result and don't go any further
if util::is_root_position(pos1, num_leaves, total_rows) {
calculated_root_hashes.push(hash1);
i += 1;
let mut computed = Vec::with_capacity(nodes.len() * 2);
let mut computed_index = 0;
let mut provided_index = 0;
loop {
let Some((next_pos, next_hash)) =
Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
else {
break;
};

if util::is_root_position(next_pos, num_leaves, total_rows) {
calculated_root_hashes.push(next_hash);
continue;
}

let Some((pos2, hash2)) = nodes.get(i + 1) else {
return Err(format!(
"Proof is too short. Expected at least {} elements, got {}",
i + 1,
nodes.len()
));
};
let sibling = next_pos | 1;
let (sibling_pos, sibling_hash) =
Self::get_next(&computed, &nodes, &mut computed_index, &mut provided_index)
.ok_or(format!("Missing sibling for {}", next_pos))?;

if pos1 != util::left_sibling(*pos2) {
return Err(format!(
"Invalid proof. Expected left sibling of {} to be {}, got {}",
pos2,
util::left_sibling(*pos2),
pos1
));
if sibling_pos != sibling {
return Err(format!("Missing sibling for {}", next_pos));
}

let parent_hash = match (hash1.is_empty(), hash2.is_empty()) {
(true, true) => NodeHash::empty(),
(true, false) => *hash2,
(false, true) => hash1,
(false, false) => NodeHash::parent_hash(&hash1, hash2),
};

Self::sorted_push(&mut nodes, (next_to_prove, parent_hash));
i += 2;
let parent_hash = NodeHash::parent_hash(&next_hash, &sibling_hash);
let parent = util::parent(next_pos, total_rows);
computed.push((parent, parent_hash));
}

// we shouldn't return the hashes in the proof
nodes.retain(|(pos, _)| !proof_positions.contains(pos));

nodes.extend(computed);
nodes.retain(|(pos, _)| proof_positions.binary_search(pos).is_err());
Ok((nodes, calculated_root_hashes))
}

fn get_next<T: Copy>(
computed: &[(u64, T)],
provided: &[(u64, T)],
computed_pos: &mut usize,
provided_pos: &mut usize,
) -> Option<(u64, T)> {
let last_computed = computed.get(*computed_pos);
let last_provided = provided.get(*provided_pos);

match (last_computed, last_provided) {
(Some((pos1, hashes1)), Some((pos2, hashes2))) => {
if pos1 < pos2 {
*computed_pos += 1;
Some((*pos1, *hashes1))
} else {
*provided_pos += 1;
Some((*pos2, *hashes2))
}
}
(Some(node), None) => {
*computed_pos += 1;
Some(*node)
}
(None, Some(node)) => {
*provided_pos += 1;
Some(*node)
}
(None, None) => None,
}
}
/// Uses the data passed in to update a proof, creating a valid proof for a given
/// set of targets, after an update. This is useful for caching UTXOs. You grab a proof
/// for it once and then keep updating it every block, yielding an always valid proof
Expand Down Expand Up @@ -699,12 +816,6 @@ impl Proof {
new_positions.sort();
Ok(new_positions)
}
fn sorted_push(nodes: &mut Vec<(u64, NodeHash)>, to_add: (u64, NodeHash)) {
let pos = nodes
.binary_search_by(|(pos, _)| pos.cmp(&to_add.0))
.unwrap_or_else(|x| x);
nodes.insert(pos, to_add);
}
}

#[cfg(test)]
Expand Down Expand Up @@ -854,6 +965,37 @@ mod tests {
assert_eq!(cached_hashes, expected_cached_hashes);
}
}

#[test]
fn test_get_next() {
use super::Proof;
let computed = vec![(1, NodeHash::empty()), (3, NodeHash::empty())];
let provided = vec![(2, NodeHash::empty()), (4, NodeHash::empty())];
let mut computed_pos = 0;
let mut provided_pos = 0;

assert_eq!(
Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
Some((1, NodeHash::empty()))
);
assert_eq!(
Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
Some((2, NodeHash::empty()))
);
assert_eq!(
Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
Some((3, NodeHash::empty()))
);
assert_eq!(
Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
Some((4, NodeHash::empty()))
);
assert_eq!(
Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos),
None
);
}

#[test]
fn test_calc_next_positions() {
use super::Proof;
Expand Down Expand Up @@ -1078,6 +1220,59 @@ mod tests {
}
}
}

#[test]
fn test_calculate_hashes_delete() {
let preimages = vec![0, 1, 2, 3, 4, 5, 6, 7];
let hashes = preimages.into_iter().map(hash_from_u8).collect::<Vec<_>>();

let del_hashes = vec![hashes[0]];
let proof = vec![
"4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
"9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
"29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
];

let proof_hashes = proof
.into_iter()
.map(|hash| NodeHash::from_str(hash).unwrap())
.collect();

let p = Proof::new(vec![0], proof_hashes);
let del_hashes = del_hashes
.into_iter()
.map(|hash| (hash, NodeHash::empty()))
.collect::<Vec<_>>();

let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap();
let expected_root_old =
NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42")
.unwrap();
let expected_root_new =
NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed")
.unwrap();

let computed_positions = [0_u64, 1, 9, 13, 8, 12, 14].to_vec();
let computed_hashes = [
"0000000000000000000000000000000000000000000000000000000000000000",
"4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
"9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b",
"29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7",
"4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a",
"2b77298feac78ab51bc5079099a074c6d789bd350442f5079fcba2b3402694e5",
"726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed",
]
.iter()
.map(|hash| NodeHash::from_str(hash).unwrap())
.collect::<Vec<_>>();
let expected_computed: Vec<_> = computed_positions
.into_iter()
.zip(computed_hashes)
.collect();
assert_eq!(roots, vec![(expected_root_old, expected_root_new)]);
assert_eq!(computed, expected_computed);
}

#[test]
fn test_serialize_rtt() {
// Tests if the serialized proof can be deserialized again
Expand Down
Loading

0 comments on commit e140c8a

Please sign in to comment.