diff --git a/examples/custom-hash-type.rs b/examples/custom-hash-type.rs index ceefaaa..e696b07 100644 --- a/examples/custom-hash-type.rs +++ b/examples/custom-hash-type.rs @@ -9,17 +9,17 @@ //! for zero-knowledge proofs, and is used in projects like ZCash and StarkNet. //! If you want to work with utreexo proofs in zero-knowledge you may want to use this instead //! of our usual sha512-256 that we use by default, since that will give you smaller circuits. -//! This example shows how to use both the [Pollard](crate::accumulator::pollard::Pollard) and +//! This example shows how to use both the [MemForest](crate::accumulator::mem_forest::MemForest) and //! proofs with a custom hash type. The code here should be pretty much all you need to do to //! use your custom hashes, just tweak the implementation of //! [NodeHash](crate::accumulator::node_hash::NodeHash) for your hash type. +use rustreexo::accumulator::mem_forest::MemForest; use rustreexo::accumulator::node_hash::AccumulatorHash; -use rustreexo::accumulator::pollard::Pollard; use starknet_crypto::poseidon_hash_many; use starknet_crypto::Felt; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] /// We need a stateful wrapper around the actual hash, this is because we use those different /// values inside our accumulator. Here we use an enum to represent the different states, you /// may want to use a struct with more data, depending on your needs. @@ -34,6 +34,8 @@ enum PoseidonHash { /// returns sane values (that is, if we call [NodeHash::placeholder] calling [NodeHash::is_placeholder] /// on the result should return true). Placeholder, + + #[default] /// This is an empty value, it represents a node that was deleted from the accumulator.
/// /// Same as the placeholder, you can implement this the way you want, just make sure that @@ -123,17 +125,17 @@ impl AccumulatorHash for PoseidonHash { } fn main() { - // Create a vector with two utxos that will be added to the Pollard + // Create a vector with two utxos that will be added to the MemForest let elements = vec![ PoseidonHash::Hash(Felt::from(1)), PoseidonHash::Hash(Felt::from(2)), ]; - // Create a new Pollard, and add the utxos to it - let mut p = Pollard::::new_with_hash(); + // Create a new MemForest, and add the utxos to it + let mut p = MemForest::::new_with_hash(); p.modify(&elements, &[]).unwrap(); - // Create a proof that the first utxo is in the Pollard + // Create a proof that the first utxo is in the MemForest let proof = p.prove(&[elements[0]]).unwrap(); // check that the proof has exactly one target diff --git a/examples/full-accumulator.rs b/examples/full-accumulator.rs index 2d895d2..12160c0 100644 --- a/examples/full-accumulator.rs +++ b/examples/full-accumulator.rs @@ -4,8 +4,8 @@ use std::str::FromStr; +use rustreexo::accumulator::mem_forest::MemForest; use rustreexo::accumulator::node_hash::BitcoinNodeHash; -use rustreexo::accumulator::pollard::Pollard; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; @@ -20,11 +20,11 @@ fn main() { ) .unwrap(), ]; - // Create a new Pollard, and add the utxos to it - let mut p = Pollard::new(); + // Create a new MemForest, and add the utxos to it + let mut p = MemForest::new(); p.modify(&elements, &[]).unwrap(); - // Create a proof that the first utxo is in the Pollard + // Create a proof that the first utxo is in the MemForest let proof = p.prove(&[elements[0]]).unwrap(); // Verify the proof. Notice how we use the del_hashes returned by `prove` here. let s = Stump::new() @@ -32,7 +32,7 @@ fn main() { .unwrap() .0; assert_eq!(s.verify(&proof, &[elements[0]]), Ok(true)); - // Now we want to update the Pollard, by removing the first utxo, and adding a new one. 
+ // Now we want to update the MemForest, by removing the first utxo, and adding a new one. // This would be in case we received a new block with a transaction spending the first utxo, // and creating a new one. let new_utxo = BitcoinNodeHash::from_str( @@ -41,6 +41,6 @@ fn main() { .unwrap(); p.modify(&[new_utxo], &[elements[0]]).unwrap(); - // Now we can prove that the new utxo is in the Pollard. + // Now we can prove that the new utxo is in the MemForest. let _ = p.prove(&[new_utxo]).unwrap(); } diff --git a/src/accumulator/mem_forest.rs b/src/accumulator/mem_forest.rs new file mode 100644 index 0000000..ff324fa --- /dev/null +++ b/src/accumulator/mem_forest.rs @@ -0,0 +1,1088 @@ +//! A full MemForest accumulator implementation. This is a simple version of the forest, +//! that keeps every node in memory. This may require more memory, but is faster +//! to update, prove and verify. +//! +//! # Example +//! ``` +//! use rustreexo::accumulator::mem_forest::MemForest; +//! use rustreexo::accumulator::node_hash::AccumulatorHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! +//! let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; +//! let hashes: Vec = values +//! .into_iter() +//! .map(|i| BitcoinNodeHash::from([i; 32])) +//! .collect(); +//! +//! let mut p = MemForest::::new(); +//! +//! p.modify(&hashes, &[]).expect("MemForest should not fail"); +//! assert_eq!(p.get_roots().len(), 1); +//! +//! p.modify(&[], &hashes).expect("Still should not fail"); // Remove leaves from the accumulator +//! +//! assert_eq!(p.get_roots().len(), 1); +//! assert_eq!(p.get_roots()[0].get_data(), BitcoinNodeHash::default()); +//!
``` + +use std::cell::Cell; +use std::cell::RefCell; +use std::collections::HashMap; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::io::Read; +use std::io::Write; +use std::rc::Rc; +use std::rc::Weak; + +use super::node_hash::AccumulatorHash; +use super::node_hash::BitcoinNodeHash; +use super::proof::Proof; +use super::util::detect_offset; +use super::util::get_proof_positions; +use super::util::is_left_niece; +use super::util::is_root_populated; +use super::util::left_child; +use super::util::max_position_at_row; +use super::util::right_child; +use super::util::root_position; +use super::util::tree_rows; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum NodeType { + Branch, + Leaf, +} + +// A few type aliases to improve readability. + +/// A weak reference to a node in the forest. We use this when referring to a node's +/// parent, as we don't want to create a reference cycle. +type Parent = RefCell>>; +/// A reference to a node's children. We use this to store the left and right children +/// of a node. This will be the only long-lived strong reference to a node. If this gets +/// dropped, the node will be dropped as well. +type Children = RefCell>>; + +/// A forest node that can either be a leaf or a branch. +#[derive(Clone)] +pub struct Node { + /// The type of this node. + ty: NodeType, + /// The hash stored in this node. + data: Cell, + /// The parent of this node, if any. + parent: Parent>, + /// The left child of this node, if any. + left: Children>, + /// The right child of this node, if any. + right: Children>, +} + +impl Node { + /// Writes one node to the writer, this method will recursively write all children. + /// The primary use of this method is to serialize the accumulator. In this case, + /// you should call this method on each root in the forest.
+ pub fn write_one(&self, writer: &mut W) -> std::io::Result<()> { + match self.ty { + NodeType::Branch => writer.write_all(&0_u64.to_le_bytes())?, + NodeType::Leaf => writer.write_all(&1_u64.to_le_bytes())?, + } + self.data.get().write(writer)?; + self.left + .borrow() + .as_ref() + .map(|l| l.write_one(writer)) + .transpose()?; + + self.right + .borrow() + .as_ref() + .map(|r| r.write_one(writer)) + .transpose()?; + Ok(()) + } + /// Recomputes the hash of all nodes, up to the root. + fn recompute_hashes(&self) { + let left = self.left.borrow(); + let right = self.right.borrow(); + + if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { + self.data + .replace(Hash::parent_hash(&left.data.get(), &right.data.get())); + } + if let Some(ref parent) = *self.parent.borrow() { + if let Some(p) = parent.upgrade() { + p.recompute_hashes(); + } + } + } + + /// Reads one node from the reader, this method will recursively read all children. + /// The primary use of this method is to deserialize the accumulator. In this case, + /// you should call this method on each root in the forest, assuming you know how + /// many roots there are. 
+ #[allow(clippy::type_complexity)] + pub fn read_one( + reader: &mut R, + ) -> std::io::Result<(Rc>, HashMap>>)> { + fn _read_one( + ancestor: Option>>, + reader: &mut R, + index: &mut HashMap>>, + ) -> std::io::Result>> { + let mut ty = [0u8; 8]; + reader.read_exact(&mut ty)?; + let data = Hash::read(reader)?; + + let ty = match u64::from_le_bytes(ty) { + 0 => NodeType::Branch, + 1 => NodeType::Leaf, + _ => panic!("Invalid node type"), + }; + if ty == NodeType::Leaf { + let leaf = Rc::new(Node { + ty, + data: Cell::new(data), + parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), + left: RefCell::new(None), + right: RefCell::new(None), + }); + index.insert(leaf.data.get(), Rc::downgrade(&leaf)); + return Ok(leaf); + } + let node = Rc::new(Node { + ty: NodeType::Branch, + data: Cell::new(data), + parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), + left: RefCell::new(None), + right: RefCell::new(None), + }); + if !data.is_empty() { + let left = _read_one(Some(node.clone()), reader, index)?; + let right = _read_one(Some(node.clone()), reader, index)?; + node.left.replace(Some(left)); + node.right.replace(Some(right)); + } + node.left + .borrow() + .as_ref() + .map(|l| l.parent.replace(Some(Rc::downgrade(&node)))); + node.right + .borrow() + .as_ref() + .map(|r| r.parent.replace(Some(Rc::downgrade(&node)))); + + Ok(node) + } + let mut index = HashMap::new(); + let root = _read_one(None, reader, &mut index)?; + Ok((root, index)) + } + + /// Returns the data associated with this node. + pub fn get_data(&self) -> Hash { + self.data.get() + } +} + +/// The actual MemForest accumulator, it implements all methods required to update the forest +/// and to prove/verify membership. +#[derive(Default, Clone)] +pub struct MemForest { + /// The roots of the forest, all leaves are children of these roots, and therefore + /// owned by them. + roots: Vec>>, + /// The number of leaves in the forest. 
Actually, this is the number of leaves we ever + /// added to the forest. + pub leaves: u64, + /// A map of all nodes in the forest, indexed by their hash, this is used to lookup + /// leaves when proving membership. + map: HashMap>>, +} + +impl MemForest { + /// Creates a new empty [MemForest] with the default hash function. + /// + /// This will create an empty MemForest, using [BitcoinNodeHash] as the hash function. If you + /// want to use a different hash function, you can use [MemForest::new_with_hash]. + /// # Example + /// ``` + /// use rustreexo::accumulator::mem_forest::MemForest; + /// let mut mem_forest = MemForest::new(); + /// ``` + pub fn new() -> MemForest { + MemForest { + map: HashMap::new(), + roots: Vec::new(), + leaves: 0, + } + } +} + +impl MemForest { + /// Creates a new empty [MemForest] with a custom hash function. + /// # Example + /// ``` + /// use rustreexo::accumulator::mem_forest::MemForest; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let mut MemForest = MemForest::::new(); + /// ``` + pub fn new_with_hash() -> MemForest { + MemForest { + map: HashMap::new(), + roots: Vec::new(), + leaves: 0, + } + } + + /// Writes the MemForest to a writer. Used to send the accumulator over the wire + /// or to disk. + /// # Example + /// ``` + /// use rustreexo::accumulator::mem_forest::MemForest; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// + /// let mut mem_forest = MemForest::::new(); + /// let mut serialized = Vec::new(); + /// mem_forest.serialize(&mut serialized).unwrap(); + /// + /// assert_eq!( + /// serialized, + /// vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + /// ); + /// ``` + pub fn serialize(&self, mut writer: W) -> std::io::Result<()> { + writer.write_all(&self.leaves.to_le_bytes())?; + writer.write_all(&self.roots.len().to_le_bytes())?; + + for root in &self.roots { + root.write_one(&mut writer).unwrap(); + } + + Ok(()) + } + + /// Deserializes a MemForest from a reader. 
+ /// # Example + /// ``` + /// use std::io::Cursor; + /// + /// use rustreexo::accumulator::mem_forest::MemForest; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let mut serialized = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + /// let MemForest = MemForest::::deserialize(&mut serialized).unwrap(); + /// assert_eq!(MemForest.leaves, 0); + /// assert_eq!(MemForest.get_roots().len(), 0); + /// ``` + pub fn deserialize(mut reader: R) -> std::io::Result> { + fn read_u64(reader: &mut R) -> std::io::Result { + let mut buf = [0u8; 8]; + reader.read_exact(&mut buf)?; + Ok(u64::from_le_bytes(buf)) + } + let leaves = read_u64(&mut reader)?; + let roots_len = read_u64(&mut reader)?; + let mut roots = Vec::new(); + let mut map = HashMap::new(); + for _ in 0..roots_len { + let (root, _map) = Node::read_one(&mut reader)?; + map.extend(_map); + roots.push(root); + } + Ok(MemForest { roots, leaves, map }) + } + + /// Returns the hash of a given position in the tree. + fn get_hash(&self, pos: u64) -> Result { + let (node, _, _) = self.grab_node(pos)?; + Ok(node.data.get()) + } + + /// Proves that a given set of hashes is in the accumulator. It returns a proof + /// and the hashes that we what to prove, but sorted by position in the tree. + /// # Example + /// ``` + /// use rustreexo::accumulator::mem_forest::MemForest; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let mut mem_forest = MemForest::::new(); + /// let hashes = vec![0, 1, 2, 3, 4, 5, 6, 7] + /// .iter() + /// .map(|n| BitcoinNodeHash::from([*n; 32])) + /// .collect::>(); + /// mem_forest.modify(&hashes, &[]).unwrap(); + /// // We want to prove that the first two hashes are in the accumulator. 
+    /// let proof = mem_forest.prove(&[hashes[1], hashes[0]]).unwrap();
+    /// //TODO: Verify the proof
+    /// ```
+    ///
+    /// # Errors
+    /// Returns an error if a target is not in the accumulator, or if one of the hashes
+    /// needed by the proof cannot be found in the forest.
+    pub fn prove(&self, targets: &[Hash]) -> Result, String> {
+        let mut positions = Vec::new();
+        for target in targets {
+            let node = self.map.get(target).ok_or("Could not find node")?;
+            let position = self.get_pos(node);
+            positions.push(position);
+        }
+        let needed = get_proof_positions(&positions, self.leaves, tree_rows(self.leaves));
+        // Collect into a `Result` so a missing node is reported to the caller rather than
+        // panicking half-way through building the proof.
+        let proof = needed
+            .iter()
+            .map(|pos| self.get_hash(*pos))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(Proof::new_with_hash(positions, proof))
+    }
+
+    /// Returns a reference to the roots in this MemForest.
+    pub fn get_roots(&self) -> &[Rc>] {
+        &self.roots
+    }
+
+    /// Modify is the main API to a [MemForest]. Because order matters, you can only `modify`
+    /// a [MemForest], and internally it'll add and delete, in the correct order.
+    ///
+    /// This method accepts two slices as parameters, both holding [Hash] values. The
+    /// first one contains the leaf hashes for the newly created UTXOs. The second one contains
+    /// the leaf hashes of the UTXOs being spent in this block as inputs.
+ /// + /// # Example + /// ``` + /// use bitcoin_hashes::sha256::Hash as Data; + /// use bitcoin_hashes::Hash; + /// use bitcoin_hashes::HashEngine; + /// use rustreexo::accumulator::mem_forest::MemForest; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::node_hash::NodeHash; + /// let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + /// let hashes = values + /// .into_iter() + /// .map(|val| { + /// let mut engine = Data::engine(); + /// engine.input(&[val]); + /// BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) + /// }) + /// .collect::>(); + /// // Add 8 leaves to the MemForest + /// let mut p = MemForest::::new(); + /// p.modify(&hashes, &[]).expect("MemForest should not fail"); + /// + /// assert_eq!( + /// p.get_roots()[0].get_data().to_string(), + /// String::from("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") + /// ); + /// ``` + pub fn modify(&mut self, add: &[Hash], del: &[Hash]) -> Result<(), String> { + self.del(del)?; + self.add(add); + Ok(()) + } + + #[allow(clippy::type_complexity)] + pub fn grab_node( + &self, + pos: u64, + ) -> Result<(Rc>, Rc>, Rc>), String> { + let (tree, branch_len, bits) = detect_offset(pos, self.leaves); + let mut n = Some(self.roots[tree as usize].clone()); + let mut sibling = Some(self.roots[tree as usize].clone()); + let mut parent = sibling.clone(); + + for row in (0..(branch_len)).rev() { + // Parent is the sibling of the current node as each of the + // nodes point to their nieces. + parent.clone_from(&sibling); + + // Figure out which node we need to follow. 
+ let niece_pos = ((bits >> row) & 1) as u8; + + #[allow(clippy::assigning_clones)] + if let Some(node) = n { + if is_left_niece(niece_pos as u64) { + n = node.right.borrow().clone(); + sibling.clone_from(&*node.left.borrow()); + } else { + n = node.left.borrow().clone(); + sibling.clone_from(&*node.right.borrow()); + } + } else { + sibling = None; + } + } + if let (Some(node), Some(sibling), Some(parent)) = (n, sibling, parent) { + return Ok((node, sibling, parent)); + } + Err(format!("node {} not found", pos)) + } + + fn del(&mut self, targets: &[Hash]) -> Result<(), String> { + let mut pos = targets + .iter() + .flat_map(|target| self.map.get(target)) + .flat_map(|target| target.upgrade()) + .map(|target| { + ( + self.get_pos(self.map.get(&target.data.get()).unwrap()), + target.data.get(), + ) + }) + .collect::>(); + + pos.sort(); + let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); + for target in targets { + match self.map.remove(&target) { + Some(target) => { + self.del_single(&target.upgrade().unwrap()); + } + None => { + return Err(format!("node {} not in the forest", target)); + } + } + } + Ok(()) + } + + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { + let roots = self + .roots + .iter() + .map(|root| root.get_data()) + .collect::>(); + proof.verify(del_hashes, &roots, self.leaves) + } + + fn get_pos(&self, node: &Weak>) -> u64 { + // This indicates whether the node is a left or right child at each level + // When we go down the tree, we can use the indicator to know which + // child to take. 
+ let mut left_child_indicator = 0_u64; + let mut rows_to_top = 0; + let mut node = node.upgrade().unwrap(); + while let Some(parent) = node.parent.clone().into_inner() { + let parent_left = parent + .upgrade() + .and_then(|parent| parent.left.clone().into_inner()) + .unwrap() + .clone(); + + // If the current node is a left child, we left-shift the indicator + // and leave the LSB as 0 + if parent_left.get_data() == node.get_data() { + left_child_indicator <<= 1; + } else { + // If the current node is a right child, we left-shift the indicator + // and set the LSB to 1 + left_child_indicator <<= 1; + left_child_indicator |= 1; + } + rows_to_top += 1; + node = parent.upgrade().unwrap(); + } + let mut root_idx = self.roots.len() - 1; + let forest_rows = tree_rows(self.leaves); + let mut root_row = 0; + // Find the root of the tree that the node belongs to + for row in 0..forest_rows { + if is_root_populated(row, self.leaves) { + let root = &self.roots[root_idx]; + if root.get_data() == node.get_data() { + root_row = row; + break; + } + root_idx -= 1; + } + } + let mut pos = root_position(self.leaves, root_row, forest_rows); + for _ in 0..rows_to_top { + // If LSB is 0, go left, otherwise go right + match left_child_indicator & 1 { + 0 => { + pos = left_child(pos, forest_rows); + } + 1 => { + pos = right_child(pos, forest_rows); + } + _ => unreachable!(), + } + left_child_indicator >>= 1; + } + pos + } + + fn del_single(&mut self, node: &Node) -> Option<()> { + let parent = node.parent.borrow(); + // Deleting a root + let parent = match *parent { + Some(ref node) => node.upgrade()?, + None => { + let pos = self.roots.iter().position(|x| x.data == node.data).unwrap(); + self.roots[pos] = Rc::new(Node { + ty: NodeType::Branch, + parent: RefCell::new(None), + data: Cell::new(Hash::empty()), + left: RefCell::new(None), + right: RefCell::new(None), + }); + return None; + } + }; + + let me = parent.left.borrow(); + // Can unwrap because we know the sibling exists + let 
sibling = if me.as_deref()?.data == node.data { + parent.right.borrow().clone() + } else { + parent.left.borrow().clone() + }; + if let Some(ref sibling) = sibling { + let grandparent = parent.parent.borrow().clone(); + sibling.parent.replace(grandparent.clone()); + + if let Some(ref grandparent) = grandparent.and_then(|g| g.upgrade()) { + if grandparent.left.borrow().clone().as_ref().unwrap().data == parent.data { + grandparent.left.replace(Some(sibling.clone())); + } else { + grandparent.right.replace(Some(sibling.clone())); + } + sibling.recompute_hashes(); + } else { + let pos = self + .roots + .iter() + .position(|x| x.data == parent.data) + .unwrap(); + self.roots[pos] = sibling.clone(); + } + }; + + Some(()) + } + + fn add_single(&mut self, value: Hash) { + let mut node: Rc> = Rc::new(Node { + ty: NodeType::Leaf, + parent: RefCell::new(None), + data: Cell::new(value), + left: RefCell::new(None), + right: RefCell::new(None), + }); + self.map.insert(value, Rc::downgrade(&node)); + let mut leaves = self.leaves; + while leaves & 1 != 0 { + let root = self.roots.pop().unwrap(); + if root.get_data() == AccumulatorHash::empty() { + leaves >>= 1; + continue; + } + let new_node = Rc::new(Node { + ty: NodeType::Branch, + parent: RefCell::new(None), + data: Cell::new(AccumulatorHash::parent_hash( + &root.data.get(), + &node.data.get(), + )), + left: RefCell::new(Some(root.clone())), + right: RefCell::new(Some(node.clone())), + }); + root.parent.replace(Some(Rc::downgrade(&new_node))); + node.parent.replace(Some(Rc::downgrade(&new_node))); + + node = new_node; + leaves >>= 1; + } + self.roots.push(node); + self.leaves += 1; + } + + fn add(&mut self, values: &[Hash]) { + for value in values { + self.add_single(*value); + } + } + + /// to_string returns the full MemForest in a string for all forests less than 6 rows. 
+ fn string(&self) -> String { + if self.leaves == 0 { + return "empty".to_owned(); + } + let fh = tree_rows(self.leaves); + // The accumulator should be less than 6 rows. + if fh > 6 { + let s = format!("Can't print {} leaves. roots: \n", self.leaves); + return self.get_roots().iter().fold(s, |mut a, b| { + a.extend(format!("{}\n", b.get_data()).chars()); + a + }); + } + let mut output = vec!["".to_string(); (fh as usize * 2) + 1]; + let mut pos: u8 = 0; + for h in 0..=fh { + let row_len = 1 << (fh - h); + for _ in 0..row_len { + let max = max_position_at_row(h, fh, self.leaves).unwrap(); + if max >= pos as u64 { + match self.get_hash(pos as u64) { + Ok(val) => { + if pos >= 100 { + output[h as usize * 2].push_str( + format!("{:#02x}:{} ", pos, &val.to_string()[..2]).as_str(), + ); + } else { + output[h as usize * 2].push_str( + format!("{:0>2}:{} ", pos, &val.to_string()[..4]).as_str(), + ); + } + } + Err(_) => { + output[h as usize * 2].push_str(" "); + } + } + } + + if h > 0 { + output[(h as usize * 2) - 1].push_str("|-------"); + + for _ in 0..((1 << h) - 1) / 2 { + output[(h as usize * 2) - 1].push_str("--------"); + } + output[(h as usize * 2) - 1].push_str("\\ "); + + for _ in 0..((1 << h) - 1) / 2 { + output[(h as usize * 2) - 1].push_str(" "); + } + + for _ in 0..(1 << h) - 1 { + output[h as usize * 2].push_str(" "); + } + } + pos += 1; + } + } + + output.iter().rev().fold(String::new(), |mut a, b| { + a.push_str(b); + a.push('\n'); + a + }) + } +} + +impl Debug for MemForest { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "{}", self.string()) + } +} + +impl Display for MemForest { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "{}", self.string()) + } +} + +#[cfg(test)] +mod test { + use std::convert::TryFrom; + use std::rc::Rc; + use std::str::FromStr; + use std::vec; + + use bitcoin_hashes::sha256::Hash as Data; + use bitcoin_hashes::Hash; + use bitcoin_hashes::HashEngine; + use 
serde::Deserialize; + + use super::MemForest; + use crate::accumulator::mem_forest::Node; + use crate::accumulator::node_hash::AccumulatorHash; + use crate::accumulator::node_hash::BitcoinNodeHash; + use crate::accumulator::proof::Proof; + + fn hash_from_u8(value: u8) -> BitcoinNodeHash { + let mut engine = Data::engine(); + + engine.input(&[value]); + + BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) + } + + #[test] + fn test_grab_node() { + let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + let hashes = values.into_iter().map(hash_from_u8).collect::>(); + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("MemForest should not fail"); + let (found_target, found_sibling, _) = p.grab_node(4).unwrap(); + let target = BitcoinNodeHash::try_from( + "e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71", + ) + .unwrap(); + let sibling = BitcoinNodeHash::try_from( + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db", + ) + .unwrap(); + + assert_eq!(target, found_target.data.get()); + assert_eq!(sibling, found_sibling.data.get()); + } + + #[test] + fn test_delete() { + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes = values.into_iter().map(hash_from_u8).collect::>(); + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("MemForest should not fail"); + p.modify(&[], &[hashes[0]]).expect("msg"); + + let (node, _, _) = p.grab_node(8).unwrap(); + assert_eq!( + String::from("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a"), + node.data.get().to_string() + ); + } + + #[test] + fn test_proof_verify() { + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes = values.into_iter().map(hash_from_u8).collect::>(); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).unwrap(); + + let proof = p.prove(&[hashes[0], hashes[1]]).unwrap(); + assert!(p.verify(&proof, &[hashes[0], hashes[1]]).unwrap()); + } + + #[test] + fn test_add() { + let values = vec![0, 1, 2, 3, 4, 
5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + let hashes = values.into_iter().map(hash_from_u8).collect::>(); + + let mut acc = MemForest::new(); + acc.add(&hashes); + + assert_eq!( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + acc.roots[0].data.get().to_string().as_str(), + ); + assert_eq!( + "9c053db406c1a077112189469a3aca0573d3481bef09fa3d2eda3304d7d44be8", + acc.roots[1].data.get().to_string().as_str(), + ); + assert_eq!( + "55d0a0ef8f5c25a9da266b36c0c5f4b31008ece82df2512c8966bddcc27a66a0", + acc.roots[2].data.get().to_string().as_str(), + ); + assert_eq!( + "4d7b3ef7300acf70c892d8327db8272f54434adbc61a4e130a563cb59a0d0f47", + acc.roots[3].data.get().to_string().as_str(), + ); + } + + #[test] + fn test_delete_roots_child() { + // Assuming the following tree: + // + // 02 + // |---\ + // 00 01 + // If I delete `01`, then `00` will become a root, moving it's hash to `02` + let values = vec![0, 1]; + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("MemForest should not fail"); + p.del_single(&p.grab_node(1).unwrap().0); + assert_eq!(p.get_roots().len(), 1); + + let root = p.get_roots()[0].clone(); + assert_eq!(root.data.get(), hashes[0]); + } + + #[test] + fn test_delete_root() { + // Assuming the following tree: + // + // 02 + // |---\ + // 00 01 + // If I delete `02`, then `02` will become an empty root, it'll point to nothing + // and its data will be Data::default() + let values = vec![0, 1]; + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("MemForest should not fail"); + p.del_single(&p.grab_node(2).unwrap().0); + assert_eq!(p.get_roots().len(), 1); + let root = p.get_roots()[0].clone(); + assert_eq!(root.data.get(), BitcoinNodeHash::default()); + } + + #[test] + fn test_delete_non_root() { + // Assuming this tree, if we delete `01`, 00 will move up to 08's position + // 14 + 
// |-----------------\ + // 12 13 + // |-------\ |--------\ + // 08 09 10 11 + // |----\ |----\ |----\ |----\ + // 00 01 02 03 04 05 06 07 + + // 14 + // |-----------------\ + // 12 13 + // |-------\ |--------\ + // 08 09 10 11 + // |----\ |----\ |----\ |----\ + // 00 01 02 03 04 05 06 07 + + // Where 08's data is just 00's + + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("MemForest should not fail"); + p.modify(&[], &[hashes[1]]).expect("Still should not fail"); + + assert_eq!(p.roots.len(), 1); + let (node, _, _) = p.grab_node(8).expect("This tree should have pos 8"); + assert_eq!(node.data.get(), hashes[0]); + } + + #[derive(Debug, Deserialize)] + struct TestCase { + leaf_preimages: Vec, + target_values: Option>, + expected_roots: Vec, + } + + fn run_single_addition_case(case: TestCase) { + let hashes = case + .leaf_preimages + .iter() + .map(|preimage| hash_from_u8(*preimage)) + .collect::>(); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + assert_eq!(p.get_roots().len(), case.expected_roots.len()); + let expected_roots = case + .expected_roots + .iter() + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) + .collect::>(); + let roots = p + .get_roots() + .iter() + .map(|root| root.data.get()) + .collect::>(); + assert_eq!(expected_roots, roots, "Test case failed {:?}", case); + } + + fn run_case_with_deletion(case: TestCase) { + let hashes = case + .leaf_preimages + .iter() + .map(|preimage| hash_from_u8(*preimage)) + .collect::>(); + let dels = case + .target_values + .clone() + .unwrap() + .iter() + .map(|pos| hashes[*pos as usize]) + .collect::>(); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + p.modify(&[], &dels).expect("still should be valid"); + + assert_eq!(p.get_roots().len(), case.expected_roots.len()); + let 
expected_roots = case + .expected_roots + .iter() + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) + .collect::>(); + let roots = p + .get_roots() + .iter() + .map(|root| root.data.get()) + .collect::>(); + assert_eq!(expected_roots, roots, "Test case failed {:?}", case); + } + + #[test] + fn run_tests_from_cases() { + #[derive(Deserialize)] + struct TestsJSON { + insertion_tests: Vec, + deletion_tests: Vec, + } + + let contents = std::fs::read_to_string("test_values/test_cases.json") + .expect("Something went wrong reading the file"); + + let tests = serde_json::from_str::(contents.as_str()) + .expect("JSON deserialization error"); + + for i in tests.insertion_tests { + run_single_addition_case(i); + } + for i in tests.deletion_tests { + run_case_with_deletion(i); + } + } + + #[test] + fn test_to_string() { + let hashes = get_hash_vec_of(&(0..255).collect::>()); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + assert_eq!( + Some("Can't print 255 leaves. roots:"), + p.to_string().get(0..30) + ); + } + + #[test] + fn test_get_pos() { + macro_rules! 
test_get_pos { + ($p:ident, $pos:literal) => { + assert_eq!( + $p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)), + $pos + ); + }; + } + let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + test_get_pos!(p, 0); + test_get_pos!(p, 1); + test_get_pos!(p, 2); + test_get_pos!(p, 3); + test_get_pos!(p, 4); + test_get_pos!(p, 5); + test_get_pos!(p, 6); + test_get_pos!(p, 7); + test_get_pos!(p, 8); + test_get_pos!(p, 9); + test_get_pos!(p, 10); + test_get_pos!(p, 11); + test_get_pos!(p, 12); + + assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28); + assert_eq!( + p.get_pos(&Rc::downgrade( + p.get_roots()[0].left.borrow().as_ref().unwrap() + )), + 24 + ); + assert_eq!( + p.get_pos(&Rc::downgrade( + p.get_roots()[0].right.borrow().as_ref().unwrap() + )), + 25 + ); + } + + #[test] + fn test_serialize_one() { + let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + p.modify(&[], &[hashes[0]]).expect("can remove 0"); + let mut writer = std::io::Cursor::new(Vec::new()); + p.get_roots()[0].write_one(&mut writer).unwrap(); + let (deserialized, _) = + Node::::read_one(&mut std::io::Cursor::new(writer.into_inner())) + .unwrap(); + assert_eq!(deserialized.get_data(), p.get_roots()[0].get_data()); + } + + #[test] + fn test_serialization() { + let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + p.modify(&[], &[hashes[0]]).expect("can remove 0"); + let mut writer = std::io::Cursor::new(Vec::new()); + p.serialize(&mut writer).unwrap(); + let deserialized = MemForest::::deserialize(&mut std::io::Cursor::new( + writer.into_inner(), + )) + .unwrap(); + assert_eq!( + deserialized.get_roots()[0].get_data(), + p.get_roots()[0].get_data() + ); + 
assert_eq!(deserialized.leaves, p.leaves); + assert_eq!(deserialized.map.len(), p.map.len()); + } + + #[test] + fn test_proof() { + let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); + let del_hashes = [hashes[2], hashes[1], hashes[4], hashes[6]]; + + let mut p = MemForest::new(); + p.modify(&hashes, &[]).expect("Test mem_forests are valid"); + + let proof = p.prove(&del_hashes).expect("Should be able to prove"); + + let expected_proof = Proof::new( + [2, 1, 4, 6].to_vec(), + vec![ + "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d" + .parse() + .unwrap(), + "084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5" + .parse() + .unwrap(), + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db" + .parse() + .unwrap(), + "ca358758f6d27e6cf45272937977a748fd88391db679ceda7dc7bf1f005ee879" + .parse() + .unwrap(), + ], + ); + assert_eq!(proof, expected_proof); + assert!(p.verify(&proof, &del_hashes).unwrap()); + } + + fn get_hash_vec_of(elements: &[u8]) -> Vec { + elements.iter().map(|el| hash_from_u8(*el)).collect() + } + + #[test] + fn test_display_empty() { + let p = MemForest::new(); + let _ = p.to_string(); + } + + #[test] + fn test_serialization_roundtrip() { + let mut p = MemForest::::new(); + let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + let hashes: Vec = values + .into_iter() + .map(|i| BitcoinNodeHash::from([i; 32])) + .collect(); + p.modify(&hashes, &[]).expect("modify should work"); + assert_eq!(p.get_roots().len(), 1); + assert!(!p.get_roots()[0].get_data().is_empty()); + assert_eq!(p.leaves, 16); + p.modify(&[], &hashes).expect("modify should work"); + assert_eq!(p.get_roots().len(), 1); + assert!(p.get_roots()[0].get_data().is_empty()); + assert_eq!(p.leaves, 16); + let mut serialized = Vec::::new(); + p.serialize(&mut serialized).expect("serialize should work"); + let deserialized = MemForest::::deserialize(&*serialized) + .expect("deserialize should work"); + 
assert_eq!(deserialized.get_roots().len(), 1); + assert!(deserialized.get_roots()[0].get_data().is_empty()); + assert_eq!(deserialized.leaves, 16); + } +} diff --git a/src/accumulator/mod.rs b/src/accumulator/mod.rs index 08c3624..9dcf76c 100644 --- a/src/accumulator/mod.rs +++ b/src/accumulator/mod.rs @@ -7,6 +7,7 @@ //! a lightweight implementation of utreexo that only stores roots. Although the [stump::Stump] //! only keeps the accumulator's roots, it still trustlessly update this state, not requiring //! a trusted third party to learn about the current state. +pub mod mem_forest; pub mod node_hash; pub mod pollard; pub mod proof; diff --git a/src/accumulator/node_hash.rs b/src/accumulator/node_hash.rs index bb5b440..d64f020 100644 --- a/src/accumulator/node_hash.rs +++ b/src/accumulator/node_hash.rs @@ -62,7 +62,7 @@ use serde::Deserialize; use serde::Serialize; pub trait AccumulatorHash: - Copy + Clone + Ord + Debug + Display + std::hash::Hash + 'static + Copy + Clone + Ord + Debug + Display + std::hash::Hash + Default + 'static { fn is_empty(&self) -> bool; fn empty() -> Self; diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index d7a522f..4c9f38e 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -1,604 +1,673 @@ -//! A full Pollard accumulator implementation. This is a simple version of the forest, -//! that keeps every node in memory. This is may require more memory, but is faster -//! to update, prove and verify. -//! -//! # Example -//! ``` -//! use rustreexo::accumulator::node_hash::AccumulatorHash; -//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; -//! use rustreexo::accumulator::pollard::Pollard; -//! -//! let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; -//! let hashes: Vec = values -//! .into_iter() -//! .map(|i| BitcoinNodeHash::from([i; 32])) -//! .collect(); -//! -//! let mut p = Pollard::::new(); -//! -//! p.modify(&hashes, &[]).expect("Pollard should not fail"); -//! 
assert_eq!(p.get_roots().len(), 1); -//! -//! p.modify(&[], &hashes).expect("Still should not fail"); // Remove leaves from the accumulator -//! -//! assert_eq!(p.get_roots().len(), 1); -//! assert_eq!(p.get_roots()[0].get_data(), BitcoinNodeHash::default()); -//! ``` - +/// Pollard is an efficient implementation of the accumulator for keeping track of a subset of the +/// whole tree. Instead of storing a proof for some leaves, it is more efficient to hold them in a +/// tree structure, and add/remove elements as needed. The main use-case for a Pollard is to keep +/// track of unconfirmed transactions' proofs, in the mempool. As you get new transactions through +/// the p2p network, you check the proofs and add them to the Pollard. When a block is mined, we +/// can remove the confirmed transactions from the Pollard, and keep the unconfirmed ones. We can +/// also serve proofs for specific transactions as requested, allowing efficient transaction relay. +/// +/// This implementation is close to the one in `MemForest`, but it is specialized in keeping track +/// of subsets of the whole tree, allowing you to cache and uncache elements as needed. The +/// MemForest, by contrast, keeps everything in the accumulator, and may take a lot of memory. +/// +/// Nodes are kept in memory, and they hold their hashes, a reference to their **aunt** (not +/// parent!), and their nieces (not children!). We do this to allow for proof generation, while +/// pruning as much as possible. In a merkle proof, we only need the sibling of the path to the +/// root, the parent is always computed on the fly as we walk up the tree. So there's no need to +/// keep the parent. But we need the aunt (the sibling of the parent) to generate the proof. +/// +/// Every node is owned by exactly one other node, its ancestor - with the only exception being the +/// roots, which are owned by the Pollard itself.
This almost guarantees that we can't have a memory +/// leak, as deleting one node will delete all of its descendants. The only way to have a memory +/// leak is if we have a cycle in the tree, which we avoid by only allowing Weak references everywhere, +/// except for the owner of the node. Things are kept in an [Rc] to allow for multiple references to +/// the same node, as we may need to operate on it, and also to allow the nieces to have a reference +/// to their aunt. It could be done with pointers, but it would be more complex and error-prone. The +/// [Rc]s live inside a [RefCell], to allow for interior mutability, as we may need to change the +/// values inside a node. Make sure to avoid leaking a reference to the inner [RefCell] to the outside +/// world, as it may cause race conditions and panics. Every time we use a reference to the inner +/// [RefCell], we make sure to drop it as soon as possible, and that we are the only ones operating +/// on it at that time. For this reason, a [Pollard] is not [Sync], and you'll need to use a [Mutex] +/// or something similar to share it between threads. But it is [Send], as it is safe to send it to +/// another thread - everything is owned by the Pollard and lives on the heap.
+/// +/// ## Usage +/// +/// //TODO: Add usage examples use std::cell::Cell; use std::cell::RefCell; -use std::collections::HashMap; use std::fmt::Debug; use std::fmt::Display; -use std::fmt::Formatter; -use std::io::Read; -use std::io::Write; use std::rc::Rc; use std::rc::Weak; use super::node_hash::AccumulatorHash; -use super::node_hash::BitcoinNodeHash; use super::proof::Proof; -use super::util::detect_offset; +use super::util::detect_row; +use super::util::detwin; use super::util::get_proof_positions; -use super::util::is_left_niece; -use super::util::is_root_populated; -use super::util::left_child; +use super::util::is_root_position; use super::util::max_position_at_row; -use super::util::right_child; +use super::util::parent; use super::util::root_position; use super::util::tree_rows; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum NodeType { - Branch, - Leaf, +#[derive(Default, Clone)] +/// A node in the Pollard tree +struct PollardNode { + /// Whether we should remember this node or not + /// + /// If this is set, we keep this node in memory, as well as all of its ancestors needed to + /// generate a proof for it. If this is not set, we can delete this node and all of its + /// descendants, as we don't need them anymore. For internal nodes, remember is based on + /// whether any of the nieces have remember set. For leaves, the user sets this value. + remember: bool, + /// The hash of this node + /// + /// This is the hash used in the merkle proof. For leaves, this is the hash of the value + /// committed to. For internal nodes, this is the hash of the concatenation of the hashes of + /// the children. This value is stored in a [Cell] to allow for interior mutability, as we may + /// need to change it if some descendant is deleted. + hash: Cell, + /// This node's aunt + /// + /// The aunt is the sibling of the parent. This is the only node that is not owned by this + /// node, as it is owned by some ancestor. 
This is a [Weak] reference to avoid cycles in the tree. + /// If a node is a root, this value is `None`, as it doesn't have an aunt. If this node's + /// parent is a root, then it actually points to its parent, as the parent is a root, and + /// there's no aunt. + aunt: RefCell>>>, + /// This node's left niece + /// + /// The left niece is the left child of this node's sibling. We use an actual [Rc] here, to + /// make this node own the niece. This is the only place where an [Rc] can be stored past some + /// function's scope, as it may create cycles in the tree. This is a [RefCell] because we may + /// need to either prune the nieces, or swap them if this node is a root. If this node is a + /// leaf, this value is `None`, as it doesn't have any descendants. + left_niece: RefCell>>>, + /// This node's right niece + /// + /// The right niece is the right child of this node's sibling. We use an actual [Rc] here, to + /// make this node own the niece. This is the only place where an [Rc] can be stored past some + /// function's scope, as it may create cycles in the tree. This is a [RefCell] because we may + /// need to either prune the nieces, or swap them if this node is a root. If this node is a + /// leaf, this value is `None`, as it doesn't have any descendants. + right_niece: RefCell>>>, } -// A few type aliases to improve readability. - -/// A weak reference to a node in the forest. We use this when referencing to a node's -/// parent, as we don't want to create a reference cycle. -type Parent = RefCell>>; -/// A reference to a node's children. We use this to store the left and right children -/// of a node. This will be the only long-lived strong reference to a node. If this gets -/// dropped, the node will be dropped as well. -type Children = RefCell>>; +impl Debug for PollardNode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.hash().to_string()) + } +} -/// A forest node that can either be a leaf or a branch. 
-#[derive(Clone)] -pub struct Node { - /// The type of this node. - ty: NodeType, - /// The hash of the stored in this node. - data: Cell, - /// The parent of this node, if any. - parent: Parent>, - /// The left and right children of this node, if any. - left: Children>, - /// The left and right children of this node, if any. - right: Children>, +impl PartialEq for PollardNode { + fn eq(&self, other: &Self) -> bool { + self.hash() == other.hash() + } } -impl Node { - /// Writes one node to the writer, this method will recursively write all children. - /// The primary use of this method is to serialize the accumulator. In this case, - /// you should call this method on each root in the forest. - pub fn write_one(&self, writer: &mut W) -> std::io::Result<()> { - match self.ty { - NodeType::Branch => writer.write_all(&0_u64.to_le_bytes())?, - NodeType::Leaf => writer.write_all(&1_u64.to_le_bytes())?, - } - self.data.get().write(writer)?; - self.left - .borrow() - .as_ref() - .map(|l| l.write_one(writer)) - .transpose()?; +impl Eq for PollardNode {} + +impl PollardNode { + /// Creates a new PollardNode with the given hash and remember value + fn new(hash: Hash, remember: bool) -> Rc> { + Rc::new(PollardNode { + remember, + hash: Cell::new(hash), + aunt: RefCell::new(None), + left_niece: RefCell::new(None), + right_niece: RefCell::new(None), + }) + } - self.right - .borrow() - .as_ref() - .map(|r| r.write_one(writer)) - .transpose()?; - Ok(()) + /// Returns the hash of this node + fn hash(&self) -> Hash { + self.hash.get() } - /// Recomputes the hash of all nodes, up to the root. 
- fn recompute_hashes(&self) { - let left = self.left.borrow(); - let right = self.right.borrow(); - - if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { - self.data - .replace(Hash::parent_hash(&left.data.get(), &right.data.get())); - } - if let Some(ref parent) = *self.parent.borrow() { - if let Some(p) = parent.upgrade() { - p.recompute_hashes(); - } + + /// Whether we should remember this node or not + fn should_remember(&self) -> bool { + let left = self.left_niece(); + let right = self.right_niece(); + + match (left, right) { + (Some(left), Some(right)) => left.should_remember() || right.should_remember(), + (Some(left), None) => left.should_remember(), + (None, Some(right)) => right.should_remember(), + (None, None) => self.remember, } } - /// Reads one node from the reader, this method will recursively read all children. - /// The primary use of this method is to deserialize the accumulator. In this case, - /// you should call this method on each root in the forest, assuming you know how - /// many roots there are. 
- #[allow(clippy::type_complexity)] - pub fn read_one( - reader: &mut R, - ) -> std::io::Result<(Rc>, HashMap>>)> { - fn _read_one( - ancestor: Option>>, - reader: &mut R, - index: &mut HashMap>>, - ) -> std::io::Result>> { - let mut ty = [0u8; 8]; - reader.read_exact(&mut ty)?; - let data = Hash::read(reader)?; - - let ty = match u64::from_le_bytes(ty) { - 0 => NodeType::Branch, - 1 => NodeType::Leaf, - _ => panic!("Invalid node type"), - }; - if ty == NodeType::Leaf { - let leaf = Rc::new(Node { - ty, - data: Cell::new(data), - parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), - left: RefCell::new(None), - right: RefCell::new(None), - }); - index.insert(leaf.data.get(), Rc::downgrade(&leaf)); - return Ok(leaf); - } - let node = Rc::new(Node { - ty: NodeType::Branch, - data: Cell::new(data), - parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), - left: RefCell::new(None), - right: RefCell::new(None), - }); - if !data.is_empty() { - let left = _read_one(Some(node.clone()), reader, index)?; - let right = _read_one(Some(node.clone()), reader, index)?; - node.left.replace(Some(left)); - node.right.replace(Some(right)); - } - node.left - .borrow() - .as_ref() - .map(|l| l.parent.replace(Some(Rc::downgrade(&node)))); - node.right - .borrow() - .as_ref() - .map(|r| r.parent.replace(Some(Rc::downgrade(&node)))); - - Ok(node) + fn children(&self) -> Option> { + if self.aunt().is_none() { + return Some((self.left_niece()?, self.right_niece()?)); } - let mut index = HashMap::new(); - let root = _read_one(None, reader, &mut index)?; - Ok((root, index)) - } - /// Returns the data associated with this node. - pub fn get_data(&self) -> Hash { - self.data.get() - } -} + let sibling = self.sibling().unwrap(); -/// The actual Pollard accumulator, it implements all methods required to update the forest -/// and to prove/verify membership. 
-#[derive(Default, Clone)] -pub struct Pollard { - /// The roots of the forest, all leaves are children of these roots, and therefore - /// owned by them. - roots: Vec>>, - /// The number of leaves in the forest. Actually, this is the number of leaves we ever - /// added to the forest. - pub leaves: u64, - /// A map of all nodes in the forest, indexed by their hash, this is used to lookup - /// leaves when proving membership. - map: HashMap>>, -} + Some((sibling.left_niece()?, sibling.right_niece()?)) + } -impl Pollard { - /// Creates a new empty [Pollard] with the default hash function. + /// Returns this node's sibling /// - /// This will create an empty Pollard, using [BitcoinNodeHash] as the hash function. If you - /// want to use a different hash function, you can use [Pollard::new_with_hash]. - /// # Example - /// ``` - /// use rustreexo::accumulator::pollard::Pollard; - /// let mut pollard = Pollard::new(); - /// ``` - pub fn new() -> Pollard { - Pollard { - map: HashMap::new(), - roots: Vec::new(), - leaves: 0, + /// This function should return an [Rc] containing the sibling of this node. If this node is a + /// root, it should return `None`, as roots don't have siblings. + fn sibling(&self) -> Option>> { + let aunt = self.aunt()?; + if aunt.left_niece()?.hash() == self.hash() { + aunt.right_niece() + } else { + aunt.left_niece() } } -} -impl Pollard { - /// Creates a new empty [Pollard] with a custom hash function. - /// # Example - /// ``` - /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; - /// use rustreexo::accumulator::pollard::Pollard; - /// let mut pollard = Pollard::::new(); - /// ``` - pub fn new_with_hash() -> Pollard { - Pollard { - map: HashMap::new(), - roots: Vec::new(), - leaves: 0, - } + /// Returns this node's aunt + /// + /// This function should return an [Rc] containing the aunt of this node. If this node is a + /// root, it should return `None`, as roots don't have aunts. 
+ fn aunt(&self) -> Option>> { + self.aunt.borrow().as_ref()?.upgrade() } - /// Writes the Pollard to a writer. Used to send the accumulator over the wire - /// or to disk. - /// # Example - /// ``` - /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; - /// use rustreexo::accumulator::pollard::Pollard; + /// Returns this node's grandparent /// - /// let mut pollard = Pollard::::new(); - /// let mut serialized = Vec::new(); - /// pollard.serialize(&mut serialized).unwrap(); + /// This function should return an [Rc] containing the grandparent of this node (i.e. the + /// parent of this node's parent). If this node is a root, it should return `None`, as roots + /// don't have grandparents. + fn grandparent(&self) -> Option>> { + self.aunt()?.aunt() + } + + /// Recomputes the hashes of this node and all of its ancestors /// - /// assert_eq!( - /// serialized, - /// vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - /// ); - /// ``` - pub fn serialize(&self, mut writer: W) -> std::io::Result<()> { - writer.write_all(&self.leaves.to_le_bytes())?; - writer.write_all(&self.roots.len().to_le_bytes())?; + /// This function will walk up the tree and recompute the hashes for each node. We may need + /// this if we delete a node, and we need to update the hashes of the ancestors. + fn recompute_hashes(&self) -> Option<()> { + if let Some((left, right)) = self.children() { + let new_hash = Hash::parent_hash(&left.hash(), &right.hash()); + self.hash.set(new_hash); + } + + if let Some(aunt) = self.aunt() { + if let Some(parent) = aunt.sibling() { + return parent.recompute_hashes(); + } - for root in &self.roots { - root.write_one(&mut writer).unwrap(); + return aunt.recompute_hashes(); } - Ok(()) + Some(()) } - /// Deserializes a pollard from a reader. 
- /// # Example + fn recompute_hashes_down(&self) -> Option<()> { + let left = self.left_niece()?; + let right = self.right_niece()?; + let new_hash = Hash::parent_hash(&left.hash(), &right.hash()); + self.hash.set(new_hash); + left.recompute_hashes_down()?; + right.recompute_hashes_down()?; + Some(()) + } + + /// Migrates this node up the tree + /// + /// The deletion algorithm for utreexo works like this: let's say we have the following tree: + /// + /// ```! + /// 06 + /// |---------\ + /// 04 05 + /// |-----\ |-----\ + /// 00 01 02 03 /// ``` - /// use std::io::Cursor; /// - /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; - /// use rustreexo::accumulator::pollard::Pollard; - /// let mut serialized = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - /// let pollard = Pollard::::deserialize(&mut serialized).unwrap(); - /// assert_eq!(pollard.leaves, 0); - /// assert_eq!(pollard.get_roots().len(), 0); + /// to delete `03`, we simply move `02` up to `05`'s position, so now we have: + /// ```! + /// 06 + /// |---------\ + /// 04 02 + /// |-----\ |-----\ + /// 00 01 -- -- /// ``` - pub fn deserialize(mut reader: R) -> std::io::Result> { - fn read_u64(reader: &mut R) -> std::io::Result { - let mut buf = [0u8; 8]; - reader.read_exact(&mut buf)?; - Ok(u64::from_le_bytes(buf)) - } - let leaves = read_u64(&mut reader)?; - let roots_len = read_u64(&mut reader)?; - let mut roots = Vec::new(); - let mut map = HashMap::new(); - for _ in 0..roots_len { - let (root, _map) = Node::read_one(&mut reader)?; - map.extend(_map); - roots.push(root); + /// + /// This function does exactly that. It moves this node up the tree, and updates the hashes + /// of the ancestors to reflect the new subtree (in the example above, the hash of `06` would + /// be updated to the hash of 04 and 02).
+ fn migrate_up(&self) -> Option<()> { + let aunt = self.aunt().unwrap(); + let grandparent = aunt.aunt()?; + let parent = aunt.sibling()?; + + let _self = if aunt.left_niece()?.hash() == self.hash() { + aunt.left_niece()? + } else { + aunt.right_niece()? + }; + + let (left_niece, right_niece) = if grandparent.left_niece()?.hash() == aunt.hash() { + (grandparent.left_niece()?, _self.clone()) + } else { + (_self.clone(), grandparent.right_niece()?) + }; + + // place myself and my parent's sibling as my grandancestor's nieces + grandparent.set_niece(Some(left_niece), Some(right_niece)); + + // update my own aunt + self.set_aunt(Rc::downgrade(&grandparent)); + + aunt.prune(); + // I'm now my aunt's sibling, so I should have their children. + // Update my nieces's aunt to be me + if let Some(x) = parent.left_niece() { + x.set_aunt(Rc::downgrade(&_self)) + }; + if let Some(x) = parent.right_niece() { + x.set_aunt(Rc::downgrade(&_self)) } - Ok(Pollard { roots, leaves, map }) + + // take my parent's nieces, as they are still needed + self.swap_nieces(&parent); + + _self.recompute_hashes(); + Some(()) } - /// Returns the hash of a given position in the tree. - fn get_hash(&self, pos: u64) -> Result { - let (node, _, _) = self.grab_node(pos)?; - Ok(node.data.get()) + /// Sets the nieces of this nodes to the provided values + fn set_niece(&self, left: Option>>, right: Option>>) { + *self.left_niece.borrow_mut() = left; + *self.right_niece.borrow_mut() = right; } - /// Proves that a given set of hashes is in the accumulator. It returns a proof - /// and the hashes that we what to prove, but sorted by position in the tree. 
- /// # Example - /// ``` - /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; - /// use rustreexo::accumulator::pollard::Pollard; - /// let mut pollard = Pollard::::new(); - /// let hashes = vec![0, 1, 2, 3, 4, 5, 6, 7] - /// .iter() - /// .map(|n| BitcoinNodeHash::from([*n; 32])) - /// .collect::>(); - /// pollard.modify(&hashes, &[]).unwrap(); - /// // We want to prove that the first two hashes are in the accumulator. - /// let proof = pollard.prove(&[hashes[1], hashes[0]]).unwrap(); - /// //TODO: Verify the proof - /// ``` - pub fn prove(&self, targets: &[Hash]) -> Result, String> { - let mut positions = Vec::new(); - for target in targets { - let node = self.map.get(target).ok_or("Could not find node")?; - let position = self.get_pos(node); - positions.push(position); - } - let needed = get_proof_positions(&positions, self.leaves, tree_rows(self.leaves)); - let proof = needed - .iter() - .map(|pos| self.get_hash(*pos).unwrap()) - .collect::>(); + /// Sets the aunt of this node to the provided value + fn set_aunt(&self, aunt: Weak>) { + *self.aunt.borrow_mut() = Some(aunt); + } - Ok(Proof::new_with_hash(positions, proof)) + fn prune(&self) { + self.left_niece.replace(None); + self.right_niece.replace(None); } - /// Returns a reference to the roots in this Pollard. - pub fn get_roots(&self) -> &[Rc>] { - &self.roots + /// Swaps the nieces of this node with the nieces of the provided node + /// + /// We use this function during addition (or undoing an addition) because roots points to their + /// children, but when we add another node on top of that root, it now should point to the new + /// node's children. This function swaps the nieces of this node with the nieces of the provided + /// node. 
+ fn swap_nieces(&self, other: &PollardNode) { + std::mem::swap( + &mut *self.left_niece.borrow_mut(), + &mut *other.left_niece.borrow_mut(), + ); + std::mem::swap( + &mut *self.right_niece.borrow_mut(), + &mut *other.right_niece.borrow_mut(), + ); } - /// Modify is the main API to a [Pollard]. Because order matters, you can only `modify` - /// a [Pollard], and internally it'll add and delete, in the correct order. + /// Returns the left niece of this node /// - /// This method accepts two vectors as parameter, a vec of [Hash] and a vec of [u64]. The - /// first one is a vec of leaf hashes for the newly created UTXOs. The second one is the position - /// for the UTXOs being spent in this block as inputs. + /// If this node is a leaf, this function should return `None`, as leaves don't have nieces. + fn left_niece(&self) -> Option>> { + self.left_niece.borrow().clone() + } + + /// Returns the right niece of this node /// - /// # Example - /// ``` - /// use bitcoin_hashes::sha256::Hash as Data; - /// use bitcoin_hashes::Hash; - /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; - /// use rustreexo::accumulator::pollard::Pollard; - /// let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - /// let hashes = values - /// .into_iter() - /// .map(|val| { - /// let mut engine = Data::engine(); - /// engine.input(&[val]); - /// BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) - /// }) - /// .collect::>(); - /// // Add 8 leaves to the pollard - /// let mut p = Pollard::::new(); - /// p.modify(&hashes, &[]).expect("Pollard should not fail"); + /// If this node is a leaf, this function should return `None`, as leaves don't have nieces. + fn right_niece(&self) -> Option>> { + self.right_niece.borrow().clone() + } +} + +#[derive(Clone, Copy)] +/// A new node to be added to the [Pollard] +/// +/// This node contains the data that should be added to the accumulator. 
It contains the hash of +/// the node, and whether we should remember this node or not. If remember is set, we keep this +/// node in our forest and we can generate proofs for it. If remember is not set, we can delete +/// this node and all of its descendants, as we don't need them anymore. +pub struct PollardAddition { + /// The hash of the node to be added + pub hash: Hash, + /// Whether we should remember this node or not + pub remember: bool, +} + +#[derive(Clone)] +pub struct Pollard { + /// The roots of our [Pollard]. They are the top nodes of the tree, and they are the only nodes + /// that are owned by the [Pollard] itself. All other nodes are owned by their ancestors. /// - /// assert_eq!( - /// p.get_roots()[0].get_data().to_string(), - /// String::from("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - /// ); - /// ``` - pub fn modify(&mut self, add: &[Hash], del: &[Hash]) -> Result<(), String> { - self.del(del)?; - self.add(add); - Ok(()) + /// The roots are stored in an array, where the index is the row of the tree where the root is + /// located. The first root is at index 0, and so on. The roots are stored in an array on the + /// stack to make it more efficient to access and move them around. At any given time, a row may + /// or may not have a root. If a row doesn't have a root, the value at that index is `None`. + roots: [Option>>; 64], + /// How many leaves have been added to the tree + /// + /// We use this value all the time, since everything about the structure of the tree is + /// reflected in the number of leaves. This value is how many leaves we ever added, so if we + /// add 5 leaves and delete 4, this value will still be 5. Moreover, the position of a leaf is + /// the number of leaves when it was added, so we can always find a leaf by its position.
+ leaves: u64, +} + +impl PartialEq for Pollard { + fn eq(&self, other: &Self) -> bool { + self.roots + .iter() + .zip(other.roots.iter()) + .all(|(a, b)| match (a, b) { + (Some(a), Some(b)) => a.hash() == b.hash(), + (None, None) => true, + _ => false, + }) } +} - #[allow(clippy::type_complexity)] - pub fn grab_node( - &self, - pos: u64, - ) -> Result<(Rc>, Rc>, Rc>), String> { - let (tree, branch_len, bits) = detect_offset(pos, self.leaves); - let mut n = Some(self.roots[tree as usize].clone()); - let mut sibling = Some(self.roots[tree as usize].clone()); - let mut parent = sibling.clone(); - - for row in (0..(branch_len)).rev() { - // Parent is the sibling of the current node as each of the - // nodes point to their nieces. - parent.clone_from(&sibling); - - // Figure out which node we need to follow. - let niece_pos = ((bits >> row) & 1) as u8; - - #[allow(clippy::assigning_clones)] - if let Some(node) = n { - if is_left_niece(niece_pos as u64) { - n = node.right.borrow().clone(); - sibling.clone_from(&*node.left.borrow()); - } else { - n = node.left.borrow().clone(); - sibling.clone_from(&*node.right.borrow()); +impl Eq for Pollard {} + +impl Debug for Pollard { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.string()) + } +} + +impl Display for Pollard { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.string()) + } +} + +impl Default for Pollard { + fn default() -> Self { + Self::new() + } +} + +// public methods + +impl Pollard { + /// Return how many leaves are in the [Pollard] + pub fn leaves(&self) -> u64 { + self.leaves + } + + /// Ingests a proof into the [Pollard], caching the nodes in the proof + /// + /// This function takes a proof and a list of hashes for the nodes in that proof. It will + /// take all the nodes in the proof and add them to the [Pollard], so we can generate proofs + /// for them later. 
This function doesn't check the validity of the proof, so you should do + /// that before calling this function. If the proof is not valid, this function will return an + /// error. + pub fn ingest_proof( + &mut self, + proof: Proof, + del_hashes: &[Hash], + remembers: &[u64], + ) -> Result<(), String> { + self.do_ingest_proof(proof, del_hashes, remembers, false) + } + + pub fn verify_and_ingest( + &mut self, + proof: Proof, + del_hashes: &[Hash], + remembers: &[u64], + ) -> Result<(), String> { + let roots = self.roots(); + proof + .verify(del_hashes, &roots, self.leaves) + .map(|valid| { + if !valid { + return Err("Proof is not valid".to_owned()); } - } else { - sibling = None; - } - } - if let (Some(node), Some(sibling), Some(parent)) = (n, sibling, parent) { - return Ok((node, sibling, parent)); + + Ok(()) + })??; + + self.do_ingest_proof(proof, del_hashes, remembers, false) + } + + pub fn prune(&self, positions: &[u64]) -> Result<(), &'static str> { + let positions = detwin(positions.to_vec(), tree_rows(self.leaves)); + let nodes = positions + .into_iter() + .map(|pos| self.grab_position(pos)) + .collect::>(); + + for node in nodes { + let node = node.ok_or("Position not found")?; + node.0.prune(); } - Err(format!("node {} not found", pos)) + + Ok(()) } - fn del(&mut self, targets: &[Hash]) -> Result<(), String> { - let mut pos = targets + /// Returns the hash of all roots in the [Pollard] + /// + /// The returned array contains all roots, in ascending order. You can see the row that each + /// root occupies by looking at which bits are set in the number of leaves in the [Pollard]. 
+ pub fn roots(&self) -> Vec { + self.roots .iter() - .flat_map(|target| self.map.get(target)) - .flat_map(|target| target.upgrade()) - .map(|target| { - ( - self.get_pos(self.map.get(&target.data.get()).unwrap()), - target.data.get(), - ) + .filter_map(|x| x.as_ref().map(|x| x.hash())) + .collect() + } + + /// Proves the inclusion of the nodes at the given positions + /// + /// This function takes a list of positions and returns a list of proofs for each position. + pub fn batch_proof(&self, targets: &[u64]) -> Result, &'static str> { + let targets = detwin(targets.to_vec(), tree_rows(self.leaves)); + let positions = get_proof_positions(&targets, self.leaves, tree_rows(self.leaves)); + let mut proof_hashes = Vec::new(); + + for pos in positions.iter() { + let hash = self + .grab_position(*pos) + .ok_or("Position not found")? + .0 + .hash(); + + proof_hashes.push(hash); + } + + Ok(Proof:: { + hashes: proof_hashes, + targets: positions, + }) + } + + pub fn prove_single(&self, pos: u64) -> Result, &'static str> { + let hashes = self.prove_single_inner(pos)?; + let targets = vec![pos]; + + Ok(Proof { hashes, targets }) + } + + /// Applies the changes to the [Pollard] for a new block + /// + /// Since the order of the operations is important, the API can't expose adding and deleting + /// directly. Instead, the user should call this function with the additions and deletions they + /// want to make. You should pass in the additions as [PollardAddition]s, telling what should + /// be added to the accumulator, and whether it should be remembered or not. + /// The deletions should be passed as a list of target positions, telling which nodes should be + /// deleted from the accumulator. Positions that are not cached will be ignored. You should check + /// the validity of the proof before calling this function, as it will blindly apply the changes + /// to the [Pollard] without validating anything. 
+ pub fn modify( + &mut self, + adds: &[PollardAddition], + del_hashes: &[Hash], + proof: Proof, + ) -> Result<(), String> { + let targets = proof.targets.clone(); + self.ingest_proof(proof.clone(), del_hashes, &targets) + .unwrap(); + let targets = detwin(targets, tree_rows(self.leaves)); + let targets = targets + .into_iter() + .map(|pos| { + self.grab_position(pos) + .ok_or(format!("Position {pos} not found")) }) .collect::>(); - pos.sort(); - let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); - for target in targets { - match self.map.remove(&target) { - Some(target) => { - self.del_single(&target.upgrade().unwrap()); - } - None => { - return Err(format!("node {} not in the forest", target)); - } - } + for del in targets { + self.delete_single(del?.0)? + } + + let mut add_nodes = Vec::new(); + let mut roots_destroyed = Vec::new(); + for node in adds { + let (_new_nodes, _roots_destroyed) = self.add_single(*node)?; + add_nodes.extend(_new_nodes); + roots_destroyed.extend(_roots_destroyed); } + Ok(()) } - pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { - let roots = self - .roots - .iter() - .map(|root| root.get_data()) - .collect::>(); - proof.verify(del_hashes, &roots, self.leaves) + /// Creates a new empty [Pollard] + pub fn new() -> Pollard { + let roots: [Option>>; 64] = [const { None }; 64]; + Pollard:: { roots, leaves: 0 } } +} + +// private methods + +/// The result from add_single +type AddSingleResult = (Vec<(u64, T)>, Vec); +type ChildrenTuple = (Rc>, Rc>); - fn get_pos(&self, node: &Weak>) -> u64 { - // This indicates whether the node is a left or right child at each level - // When we go down the tree, we can use the indicator to know which - // child to take. 
- let mut left_child_indicator = 0_u64; - let mut rows_to_top = 0; - let mut node = node.upgrade().unwrap(); - while let Some(parent) = node.parent.clone().into_inner() { - let parent_left = parent - .upgrade() - .and_then(|parent| parent.left.clone().into_inner()) - .unwrap() - .clone(); - - // If the current node is a left child, we left-shift the indicator - // and leave the LSB as 0 - if parent_left.get_data() == node.get_data() { - left_child_indicator <<= 1; +impl Pollard { + fn grab_position(&self, pos: u64) -> Option> { + let (root, depth, bits) = Self::detect_offset(pos, self.leaves); + let mut node = self.roots[root as usize].clone()?; + + if depth == 0 { + return Some((node.clone(), node)); + } + + for row in 0..(depth - 1) { + let next = if pos >> (depth - row - 1) & 1 == 1 { + node.left_niece()? } else { - // If the current node is a right child, we left-shift the indicator - // and set the LSB to 1 - left_child_indicator <<= 1; - left_child_indicator |= 1; - } - rows_to_top += 1; - node = parent.upgrade().unwrap(); + node.right_niece()? + }; + node = next; } - let mut root_idx = self.roots.len() - 1; + + Some(if bits & 1 == 0 { + (node.left_niece()?, node.right_niece()?) + } else { + (node.right_niece()?, node.left_niece()?) 
+ }) + } + + fn ingest_positions( + &mut self, + mut iter: impl Iterator, + ) -> Result<(), String> { let forest_rows = tree_rows(self.leaves); - let mut root_row = 0; - // Find the root of the tree that the node belongs to - for row in 0..forest_rows { - if is_root_populated(row, self.leaves) { - let root = &self.roots[root_idx]; - if root.get_data() == node.get_data() { - root_row = row; - break; - } - root_idx -= 1; + while let Some((pos1, hash1)) = iter.next() { + if is_root_position(pos1, self.leaves, forest_rows) { + let root = detect_row(pos1, forest_rows); + self.roots[root as usize] = Some(PollardNode::new(hash1, true)); + continue; } - } - let mut pos = root_position(self.leaves, root_row, forest_rows); - for _ in 0..rows_to_top { - // If LSB is 0, go left, otherwise go right - match left_child_indicator & 1 { - 0 => { - pos = left_child(pos, forest_rows); - } - 1 => { - pos = right_child(pos, forest_rows); - } - _ => unreachable!(), + + let (pos2, hash2) = iter.next().ok_or("Proof is not valid")?; + if pos1 != (pos2 ^ 1) { + return Err(format!("Proof is not valid, missing pos {}", pos2 ^ 1)); + } + + let aunt = parent(pos1, forest_rows); + let aunt = self + .grab_position(aunt) + .ok_or(format!("can't find aunt for {pos1} {self:?}"))? 
+ .1; + + if aunt.left_niece().is_some() { + continue; } - left_child_indicator >>= 1; + + let new_node = PollardNode::new(hash1, true); + let new_sibling = PollardNode::new(hash2, true); + + new_node.set_aunt(Rc::downgrade(&aunt)); + new_sibling.set_aunt(Rc::downgrade(&aunt)); + + aunt.set_niece(Some(new_sibling), Some(new_node)); } - pos + + Ok(()) } - fn del_single(&mut self, node: &Node) -> Option<()> { - let parent = node.parent.borrow(); - // Deleting a root - let parent = match *parent { - Some(ref node) => node.upgrade()?, - None => { - let pos = self.roots.iter().position(|x| x.data == node.data).unwrap(); - self.roots[pos] = Rc::new(Node { - ty: NodeType::Branch, - parent: RefCell::new(None), - data: Cell::new(Hash::empty()), - left: RefCell::new(None), - right: RefCell::new(None), - }); - return None; - } - }; + fn do_ingest_proof( + &mut self, + proof: Proof, + del_hashes: &[Hash], + remembers: &[u64], + recompute: bool, + ) -> Result<(), String> { + let forest_rows = tree_rows(self.leaves); + let (mut all_nodes, _) = proof.calculate_hashes(del_hashes, self.leaves)?; + let proof_positions = get_proof_positions(&proof.targets, self.leaves, forest_rows); - let me = parent.left.borrow(); - // Can unwrap because we know the sibling exists - let sibling = if me.as_deref()?.data == node.data { - parent.right.borrow().clone() - } else { - parent.left.borrow().clone() - }; - if let Some(ref sibling) = sibling { - let grandparent = parent.parent.borrow().clone(); - sibling.parent.replace(grandparent.clone()); + all_nodes.extend(proof_positions.into_iter().zip(proof.hashes.clone())); + all_nodes.sort(); + let iter = all_nodes.into_iter().rev(); + self.ingest_positions(iter)?; - if let Some(ref grandparent) = grandparent.and_then(|g| g.upgrade()) { - if grandparent.left.borrow().clone().as_ref().unwrap().data == parent.data { - grandparent.left.replace(Some(sibling.clone())); - } else { - grandparent.right.replace(Some(sibling.clone())); - } - 
sibling.recompute_hashes(); - } else { - let pos = self - .roots - .iter() - .position(|x| x.data == parent.data) - .unwrap(); - self.roots[pos] = sibling.clone(); + let pruned = proof + .targets + .iter() + .filter(|x| !remembers.contains(x)) + .copied() + .collect::>(); + + self.prune(&pruned)?; + + if recompute { + for root in self.roots.iter().filter_map(|x| x.as_ref()) { + root.recompute_hashes_down(); } - }; + } - Some(()) + Ok(()) } - fn add_single(&mut self, value: Hash) { - let mut node: Rc> = Rc::new(Node { - ty: NodeType::Leaf, - parent: RefCell::new(None), - data: Cell::new(value), - left: RefCell::new(None), - right: RefCell::new(None), - }); - self.map.insert(value, Rc::downgrade(&node)); - let mut leaves = self.leaves; - while leaves & 1 != 0 { - let root = self.roots.pop().unwrap(); - if root.get_data() == AccumulatorHash::empty() { - leaves >>= 1; - continue; - } - let new_node = Rc::new(Node { - ty: NodeType::Branch, - parent: RefCell::new(None), - data: Cell::new(AccumulatorHash::parent_hash( - &root.data.get(), - &node.data.get(), - )), - left: RefCell::new(Some(root.clone())), - right: RefCell::new(Some(node.clone())), - }); - root.parent.replace(Some(Rc::downgrade(&new_node))); - node.parent.replace(Some(Rc::downgrade(&new_node))); + fn detect_offset(pos: u64, num_leaves: u64) -> (u8, u8, u64) { + let mut tr = tree_rows(num_leaves); + let nr = detect_row(pos, tr); + + let mut bigger_trees = tr; + let mut marker = pos; - node = new_node; - leaves >>= 1; + while ((marker << nr) & ((2 << tr) - 1)) >= ((1 << tr) & num_leaves) { + let tree_size = (1 << tr) & num_leaves; + marker -= tree_size; + bigger_trees -= 1; + + tr -= 1; } - self.roots.push(node); - self.leaves += 1; + (bigger_trees, (tr - nr), marker) } - fn add(&mut self, values: &[Hash]) { - for value in values { - self.add_single(*value); + fn get_hash(&self, pos: u64) -> Result { + match self.grab_position(pos) { + Some(node) => Ok(node.0.hash()), + None => Err("Position not found"), } } 
- /// to_string returns the full pollard in a string for all forests less than 6 rows. + /// to_string returns the full mem_forest in a string for all forests less than 6 rows. fn string(&self) -> String { if self.leaves == 0 { return "empty".to_owned(); @@ -607,11 +676,17 @@ impl Pollard { // The accumulator should be less than 6 rows. if fh > 6 { let s = format!("Can't print {} leaves. roots: \n", self.leaves); - return self.get_roots().iter().fold(s, |mut a, b| { - a.extend(format!("{}\n", b.get_data()).chars()); + return self.roots.iter().fold(s, |mut a, b| { + a.push_str( + &b.as_ref() + .map(|b| b.hash()) + .unwrap_or(Hash::empty()) + .to_string(), + ); a }); } + let mut output = vec!["".to_string(); (fh as usize * 2) + 1]; let mut pos: u8 = 0; for h in 0..=fh { @@ -663,219 +738,233 @@ impl Pollard { a }) } -} -impl Debug for Pollard { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(f, "{}", self.string()) - } -} + fn prove_single_inner(&self, pos: u64) -> Result, &'static str> { + let (node, sibling) = self.grab_position(pos).ok_or("Position not found")?; + let mut proof = vec![sibling.hash()]; + let mut current = node; + + while let Some(aunt) = current.aunt() { + // don't push roots + if aunt.aunt().is_some() { + proof.push(aunt.hash()); + } + current = aunt; + } -impl Display for Pollard { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - write!(f, "{}", self.string()) + Ok(proof) } -} -#[cfg(test)] -mod test { - use std::convert::TryFrom; - use std::rc::Rc; - use std::str::FromStr; - use std::vec; + fn add_single(&mut self, node: PollardAddition) -> Result, String> { + let mut row = 0; + let mut new_node = PollardNode::new(node.hash, node.remember); - use bitcoin_hashes::sha256::Hash as Data; - use bitcoin_hashes::Hash; - use bitcoin_hashes::HashEngine; - use serde::Deserialize; + let mut add_positions = Vec::new(); + let mut roots_to_destroy = Vec::new(); - use super::Pollard; - use 
crate::accumulator::node_hash::AccumulatorHash; - use crate::accumulator::node_hash::BitcoinNodeHash; - use crate::accumulator::pollard::Node; - use crate::accumulator::proof::Proof; + while self.leaves >> row & 1 == 1 { + let old_root = std::mem::take(&mut self.roots[row as usize]).expect("Root not found"); + let pos = root_position(self.leaves(), row, tree_rows(self.leaves())); - fn hash_from_u8(value: u8) -> BitcoinNodeHash { - let mut engine = Data::engine(); + add_positions.push((pos, old_root.hash())); - engine.input(&[value]); + if old_root.hash().is_empty() { + let pos = row as usize; + self.roots[pos] = None; + roots_to_destroy.push(pos); + row += 1; + continue; + } - BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) - } + let new_root_hash = Hash::parent_hash(&old_root.hash.get(), &new_node.hash.get()); + let new_root_rc = Rc::new(PollardNode { + remember: old_root.should_remember() || new_node.should_remember(), + hash: Cell::new(new_root_hash), + aunt: RefCell::new(None), + left_niece: RefCell::new(None), + right_niece: RefCell::new(None), + }); - #[test] - fn test_grab_node() { - let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - let hashes = values.into_iter().map(hash_from_u8).collect::>(); - - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Pollard should not fail"); - let (found_target, found_sibling, _) = p.grab_node(4).unwrap(); - let target = BitcoinNodeHash::try_from( - "e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71", - ) - .unwrap(); - let sibling = BitcoinNodeHash::try_from( - "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db", - ) - .unwrap(); - - assert_eq!(target, found_target.data.get()); - assert_eq!(sibling, found_sibling.data.get()); - } + // swap nieces + new_node.swap_nieces(&old_root); - #[test] - fn test_delete() { - let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let hashes = values.into_iter().map(hash_from_u8).collect::>(); + //FIXME: This should be 
a method in PollardNode + if let Some(x) = new_node.left_niece() { + x.set_aunt(Rc::downgrade(&new_node)) + } + if let Some(x) = new_node.right_niece() { + x.set_aunt(Rc::downgrade(&new_node)) + } - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Pollard should not fail"); - p.modify(&[], &[hashes[0]]).expect("msg"); + if let Some(x) = old_root.left_niece() { + x.set_aunt(Rc::downgrade(&old_root)) + } + if let Some(x) = old_root.right_niece() { + x.set_aunt(Rc::downgrade(&old_root)) + } - let (node, _, _) = p.grab_node(8).unwrap(); - assert_eq!( - String::from("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a"), - node.data.get().to_string() - ); + // update aunts for the old nodes + let new_root_weak = Rc::downgrade(&new_root_rc); + old_root.set_aunt(new_root_weak.clone()); + new_node.set_aunt(new_root_weak); + + // update nieces for the new root + let (left_niece, right_niece) = + if old_root.should_remember() || new_node.should_remember() { + (Some(old_root), Some(new_node)) + } else { + (None, None) + }; + + new_root_rc.set_niece(left_niece, right_niece); + + // keep doing this until we find a row with an empty spot + new_node = new_root_rc; + row += 1; + } + + self.roots[row as usize] = Some(new_node); + self.leaves += 1; + + Ok((add_positions, roots_to_destroy)) } - #[test] - fn test_proof_verify() { - let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let hashes = values.into_iter().map(hash_from_u8).collect::>(); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).unwrap(); - - let proof = p.prove(&[hashes[0], hashes[1]]).unwrap(); - assert!(p.verify(&proof, &[hashes[0], hashes[1]]).unwrap()); + fn delete_single(&mut self, node: Rc>) -> Result<(), String> { + // we are deleting a root, just write an empty hash where it was + if node.aunt.borrow().is_none() { + for i in 0..64 { + if self.roots[i].eq(&Some(node.clone())) { + self.roots[i] = Some(Rc::new(PollardNode::default())); + return Ok(()); + } + } + + return Err("Root not 
found".to_string()); + } + + let sibling = node + .sibling() + .ok_or(format!("Sibling for {} not found", node.hash()))?; + if node.grandparent().is_none() { + // my parent is a root, I'm a root now + for i in 0..64 { + let aunt = node.aunt().unwrap(); + let Some(root) = self.roots[i].as_ref() else { + continue; + }; + if root.hash() == aunt.hash() { + self.roots[i] = Some(sibling); + return Ok(()); + } + } + + return Err("Root not found".to_string()); + }; + sibling.migrate_up().unwrap(); + Ok(()) } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use serde::Deserialize; + + use super::*; + use crate::accumulator::node_hash::BitcoinNodeHash; + use crate::accumulator::util::hash_from_u8; #[test] fn test_add() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; - let hashes = values.into_iter().map(hash_from_u8).collect::>(); + let hashes = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect::>(); - let mut acc = Pollard::new(); - acc.add(&hashes); + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); assert_eq!( "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", - acc.roots[0].data.get().to_string().as_str(), + acc.roots[3].as_ref().unwrap().hash().to_string(), ); assert_eq!( "9c053db406c1a077112189469a3aca0573d3481bef09fa3d2eda3304d7d44be8", - acc.roots[1].data.get().to_string().as_str(), + acc.roots[2].as_ref().unwrap().hash().to_string(), ); assert_eq!( "55d0a0ef8f5c25a9da266b36c0c5f4b31008ece82df2512c8966bddcc27a66a0", - acc.roots[2].data.get().to_string().as_str(), + acc.roots[1].as_ref().unwrap().hash().to_string() ); assert_eq!( "4d7b3ef7300acf70c892d8327db8272f54434adbc61a4e130a563cb59a0d0f47", - acc.roots[3].data.get().to_string().as_str(), + acc.roots[0].as_ref().unwrap().hash().to_string() ); } - #[test] - fn test_delete_roots_child() { - // Assuming the following tree: - // - // 
02 - // |---\ - // 00 01 - // If I delete `01`, then `00` will become a root, moving it's hash to `02` - let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); - - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Pollard should not fail"); - p.del_single(&p.grab_node(1).unwrap().0); - assert_eq!(p.get_roots().len(), 1); - - let root = p.get_roots()[0].clone(); - assert_eq!(root.data.get(), hashes[0]); - } - - #[test] - fn test_delete_root() { - // Assuming the following tree: - // - // 02 - // |---\ - // 00 01 - // If I delete `02`, then `02` will become an empty root, it'll point to nothing - // and its data will be Data::default() - let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); - - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Pollard should not fail"); - p.del_single(&p.grab_node(2).unwrap().0); - assert_eq!(p.get_roots().len(), 1); - let root = p.get_roots()[0].clone(); - assert_eq!(root.data.get(), BitcoinNodeHash::default()); - } - - #[test] - fn test_delete_non_root() { - // Assuming this tree, if we delete `01`, 00 will move up to 08's position - // 14 - // |-----------------\ - // 12 13 - // |-------\ |--------\ - // 08 09 10 11 - // |----\ |----\ |----\ |----\ - // 00 01 02 03 04 05 06 07 - - // 14 - // |-----------------\ - // 12 13 - // |-------\ |--------\ - // 08 09 10 11 - // |----\ |----\ |----\ |----\ - // 00 01 02 03 04 05 06 07 - - // Where 08's data is just 00's - - let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); - - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Pollard should not fail"); - p.modify(&[], &[hashes[1]]).expect("Still should not fail"); - - assert_eq!(p.roots.len(), 1); - let (node, _, _) = p.grab_node(8).expect("This tree should have pos 8"); - assert_eq!(node.data.get(), hashes[0]); - } - #[derive(Debug, Deserialize)] struct TestCase { 
leaf_preimages: Vec, target_values: Option>, expected_roots: Vec, + proofhashes: Option>, + } + + #[test] + fn run_tests_from_cases() { + #[derive(Deserialize)] + struct TestsJSON { + insertion_tests: Vec, + deletion_tests: Vec, + } + + let contents = std::fs::read_to_string("test_values/test_cases.json") + .expect("Something went wrong reading the file"); + + let tests = serde_json::from_str::(contents.as_str()) + .expect("JSON deserialization error"); + + for i in tests.insertion_tests { + run_single_addition_case(i); + } + + for i in tests.deletion_tests { + run_case_with_deletion(i); + } } fn run_single_addition_case(case: TestCase) { let hashes = case .leaf_preimages .iter() - .map(|preimage| hash_from_u8(*preimage)) + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: true, + } + }) .collect::>(); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - assert_eq!(p.get_roots().len(), case.expected_roots.len()); + + let mut p = Pollard::::new(); + p.modify(&hashes, &[], Proof::default()).unwrap(); + let expected_roots = case .expected_roots .iter() .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); - let roots = p - .get_roots() - .iter() - .map(|root| root.data.get()) - .collect::>(); + let roots = p.roots().iter().copied().rev().collect::>(); + + assert_eq!(roots.len(), case.expected_roots.len()); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } @@ -883,204 +972,144 @@ mod test { let hashes = case .leaf_preimages .iter() - .map(|preimage| hash_from_u8(*preimage)) + .map(|preimage| { + let hash = hash_from_u8(*preimage); + PollardAddition { + hash, + remember: false, + } + }) .collect::>(); - let dels = case + + let target_hashes = case .target_values - .clone() + .as_ref() .unwrap() .iter() - .map(|pos| hashes[*pos as usize]) + .map(|target| hash_from_u8(*target as u8)) .collect::>(); - let mut p = Pollard::new(); - p.modify(&hashes, 
&[]).expect("Test pollards are valid"); - p.modify(&[], &dels).expect("still should be valid"); - assert_eq!(p.get_roots().len(), case.expected_roots.len()); + let proof_hashes = case + .proofhashes + .clone() + .unwrap_or_default() + .into_iter() + .map(|hash| { + BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hashes are valid") + }) + .collect::>(); + + let proof = Proof::new(case.target_values.clone().unwrap(), proof_hashes); + + let mut p = Pollard::::new(); + p.modify(&hashes, &[], Proof::default()).unwrap(); + p.modify(&[], &target_hashes, proof).unwrap(); + let expected_roots = case .expected_roots .iter() .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); - let roots = p - .get_roots() - .iter() - .map(|root| root.data.get()) - .collect::>(); + + let roots = p.roots().iter().copied().rev().collect::>(); + assert_eq!(roots.len(), case.expected_roots.len()); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } #[test] - fn run_tests_from_cases() { - #[derive(Deserialize)] - struct TestsJSON { - insertion_tests: Vec, - deletion_tests: Vec, - } - - let contents = std::fs::read_to_string("test_values/test_cases.json") - .expect("Something went wrong reading the file"); - - let tests = serde_json::from_str::(contents.as_str()) - .expect("JSON deserialization error"); + fn test_delete_roots_child() { + // Assuming the following tree: + // + // 02 + // |---\ + // 00 01 + // If I delete `01`, then `00` will become a root, moving it's hash to `02` + let values = vec![0, 1]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + PollardAddition { + hash, + remember: true, + } + }) + .collect(); - for i in tests.insertion_tests { - run_single_addition_case(i); - } - for i in tests.deletion_tests { - run_case_with_deletion(i); - } - } + let mut p = Pollard::::new(); + p.modify(&hashes, &[], Proof::default()).unwrap(); + p.delete_single(p.grab_position(1).unwrap().0) + 
.expect("Failed to delete"); - #[test] - fn test_to_string() { - let hashes = get_hash_vec_of(&(0..255).collect::>()); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - assert_eq!( - Some("Can't print 255 leaves. roots:"), - p.to_string().get(0..30) - ); + let root = p.roots[1].clone(); + assert_eq!(root.unwrap().hash(), hashes[0].hash); } #[test] - fn test_get_pos() { - macro_rules! test_get_pos { - ($p:ident, $pos:literal) => { - assert_eq!( - $p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)), - $pos - ); - }; - } - let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - test_get_pos!(p, 0); - test_get_pos!(p, 1); - test_get_pos!(p, 2); - test_get_pos!(p, 3); - test_get_pos!(p, 4); - test_get_pos!(p, 5); - test_get_pos!(p, 6); - test_get_pos!(p, 7); - test_get_pos!(p, 8); - test_get_pos!(p, 9); - test_get_pos!(p, 10); - test_get_pos!(p, 11); - test_get_pos!(p, 12); - - assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28); - assert_eq!( - p.get_pos(&Rc::downgrade( - p.get_roots()[0].left.borrow().as_ref().unwrap() - )), - 24 - ); - assert_eq!( - p.get_pos(&Rc::downgrade( - p.get_roots()[0].right.borrow().as_ref().unwrap() - )), - 25 - ); - } + fn test_prove_single() { + let values = vec![0, 1, 2, 3, 4, 5]; + let hashes: Vec<_> = values + .into_iter() + .map(|preimage| { + let hash = hash_from_u8(preimage); + let remember = true; + PollardAddition { hash, remember } + }) + .collect(); - #[test] - fn test_serialize_one() { - let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - p.modify(&[], &[hashes[0]]).expect("can remove 0"); - let mut writer = std::io::Cursor::new(Vec::new()); - p.get_roots()[0].write_one(&mut writer).unwrap(); - let (deserialized, _) = - Node::::read_one(&mut 
std::io::Cursor::new(writer.into_inner())) - .unwrap(); - assert_eq!(deserialized.get_data(), p.get_roots()[0].get_data()); - } + let mut acc = Pollard::::new(); + acc.modify(&hashes, &[], Proof::default()).unwrap(); + + let proof = acc.prove_single(3).unwrap(); + let expected_hashes = [ + "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", + "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", + ] + .iter() + .map(|x| BitcoinNodeHash::from_str(x).unwrap()) + .collect::>(); + + let expected_proof = Proof { + hashes: expected_hashes, + targets: vec![3], + }; - #[test] - fn test_serialization() { - let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - p.modify(&[], &[hashes[0]]).expect("can remove 0"); - let mut writer = std::io::Cursor::new(Vec::new()); - p.serialize(&mut writer).unwrap(); - let deserialized = - Pollard::::deserialize(&mut std::io::Cursor::new(writer.into_inner())) - .unwrap(); - assert_eq!( - deserialized.get_roots()[0].get_data(), - p.get_roots()[0].get_data() - ); - assert_eq!(deserialized.leaves, p.leaves); - assert_eq!(deserialized.map.len(), p.map.len()); + assert_eq!(proof, expected_proof); } #[test] - fn test_proof() { - let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); - let del_hashes = [hashes[2], hashes[1], hashes[4], hashes[6]]; - - let mut p = Pollard::new(); - p.modify(&hashes, &[]).expect("Test pollards are valid"); - - let proof = p.prove(&del_hashes).expect("Should be able to prove"); - - let expected_proof = Proof::new( - [2, 1, 4, 6].to_vec(), - vec![ - "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d" - .parse() - .unwrap(), - "084fed08b978af4d7d196a7446a86b58009e636b611db16211b65a9aadff29c5" - .parse() - .unwrap(), - "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db" - .parse() - .unwrap(), - 
"ca358758f6d27e6cf45272937977a748fd88391db679ceda7dc7bf1f005ee879" - .parse() - .unwrap(), - ], - ); - assert_eq!(proof, expected_proof); - assert!(p.verify(&proof, &del_hashes).unwrap()); - } + fn test_ingest_proof() { + let values = [0, 1, 2, 3, 4, 5, 6, 7] + .iter() + .map(|pos| { + let hash = hash_from_u8(*pos); + PollardAddition { + hash, + remember: false, + } + }) + .collect::>(); - fn get_hash_vec_of(elements: &[u8]) -> Vec { - elements.iter().map(|el| hash_from_u8(*el)).collect() - } + let proof = Proof { + targets: [3].to_vec(), + hashes: [ + "dbc1b4c900ffe48d575b5da5c638040125f65db0fe3e24494b76ea986457d986", + "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de", + "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", + ] + .iter() + .map(|x| BitcoinNodeHash::from_str(x).unwrap()) + .collect::>(), + }; - #[test] - fn test_display_empty() { - let p = Pollard::new(); - let _ = p.to_string(); - } + let mut acc = Pollard::::new(); + acc.modify(&values, &[], Proof::default()).unwrap(); + acc.ingest_proof(proof.clone(), &[hash_from_u8(3)], &[3]) + .unwrap(); + let new_proof = acc.prove_single(3).unwrap(); + assert_eq!(new_proof, proof); - #[test] - fn test_serialization_roundtrip() { - let mut p = Pollard::::new(); - let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let hashes: Vec = values - .into_iter() - .map(|i| BitcoinNodeHash::from([i; 32])) - .collect(); - p.modify(&hashes, &[]).expect("modify should work"); - assert_eq!(p.get_roots().len(), 1); - assert!(!p.get_roots()[0].get_data().is_empty()); - assert_eq!(p.leaves, 16); - p.modify(&[], &hashes).expect("modify should work"); - assert_eq!(p.get_roots().len(), 1); - assert!(p.get_roots()[0].get_data().is_empty()); - assert_eq!(p.leaves, 16); - let mut serialized = Vec::::new(); - p.serialize(&mut serialized).expect("serialize should work"); - let deserialized = - Pollard::::deserialize(&*serialized).expect("deserialize should work"); - 
assert_eq!(deserialized.get_roots().len(), 1); - assert!(deserialized.get_roots()[0].get_data().is_empty()); - assert_eq!(deserialized.leaves, 16); + let node = acc.grab_position(3).unwrap().0; + assert_eq!(node.hash(), hash_from_u8(3)); } } diff --git a/src/accumulator/stump.rs b/src/accumulator/stump.rs index 211a2db..631739c 100644 --- a/src/accumulator/stump.rs +++ b/src/accumulator/stump.rs @@ -371,7 +371,7 @@ mod test { #[test] fn test_custom_hash_type() { - #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] + #[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] struct CustomHash([u8; 32]); impl Display for CustomHash { @@ -581,7 +581,7 @@ mod test { .target_values .as_ref() .unwrap() - .into_iter() + .iter() .map(|target| hash_from_u8(*target as u8)) .collect::>(); diff --git a/src/accumulator/util.rs b/src/accumulator/util.rs index 5a62809..18e7660 100644 --- a/src/accumulator/util.rs +++ b/src/accumulator/util.rs @@ -237,7 +237,8 @@ pub fn parent_many(pos: u64, rise: u8, forest_rows: u8) -> Result { rise, forest_rows )); } - let mask: u64 = (2_u64 << forest_rows) - 1_u64; + + let mask = (2_u64 << forest_rows) - 1; Ok((pos >> rise | (mask << (forest_rows - (rise - 1)) as u64)) & mask) }