From c37ab75f2257c6884e056ed79b62c369c475646d Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Mon, 7 Oct 2024 18:01:21 -0300 Subject: [PATCH] feat: make the hash type generic Currently, we only support using our internal hash type, that is the 32-bytes representation of a sha512-256 hash digest. However, we may want to give callers the possibility to use other hash types, or store them in other ways. This commit modifies `Pollard`, `Proof` and `Stump` to take in a generic parameter `Hash` that implements the trait `NodeHash` and defaults to `BitcoinNodeHash`, the one used by Bitcoin consensus as defined by the utreexo spec. This is part of a project to support a Cairo prover for Bitcoin. To reduce the circuit size, we need to use an algebraic hash function like Poseidon, which tends to reduce the circuit significantly. With this commit the caller can use our data structures with Poseidon without needing to change anything in rustreexo. --- Cargo.toml | 4 + examples/custom-hash-type.rs | 140 ++++++++++++++ examples/full-accumulator.rs | 21 +- examples/proof-update.rs | 12 +- examples/simple-stump-update.rs | 21 +- src/accumulator/node_hash.rs | 198 +++++++++++-------- src/accumulator/pollard.rs | 245 +++++++++++++++--------- src/accumulator/proof.rs | 326 +++++++++++++++++++++----------- src/accumulator/stump.rs | 315 ++++++++++++++++++++---------- src/accumulator/util.rs | 30 ++- 10 files changed, 912 insertions(+), 400 deletions(-) create mode 100644 examples/custom-hash-type.rs diff --git a/Cargo.toml b/Cargo.toml index b7f2637..95c8044 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.81" +starknet-crypto = "0.7.2" [features] with-serde = ["serde"] @@ -27,3 +28,6 @@ name = "simple-stump-update" [[example]] name = "proof-update" + +[[example]] +name = "custom-hash-type" diff --git a/examples/custom-hash-type.rs b/examples/custom-hash-type.rs new file mode 100644 index 0000000..91614dc --- /dev/null +++ b/examples/custom-hash-type.rs @@ -0,0 +1,140 @@ +//! All data structures in this library are generic over the hash type used, defaulting to +//! [BitcoinNodeHash](crate::accumulator::node_hash::BitcoinNodeHash), the one used by Bitcoin +//! as defined by the utreexo spec. However, if you need to use a different hash type, you can +//! implement the [NodeHash](crate::accumulator::node_hash::NodeHash) trait for it, and use it +//! with the accumulator data structures. +//! +//! This example shows how to use a custom hash type based on the Poseidon hash function. The +//! [Poseidon Hash](https://eprint.iacr.org/2019/458.pdf) is a hash function that is optmized +//! for zero-knowledge proofs, and is used in projects like ZCash and StarkNet. +//! If you want to work with utreexo proofs in zero-knowledge you may want to use this instead +//! of our usual sha512-256 that we use by default, since that will give you smaller circuits. +//! This example shows how to use both the [Pollard](crate::accumulator::pollard::Pollard) and +//! proofs with a custom hash type. The code here should be pretty much all you need to do to +//! use your custom hashes, just tweak the implementation of +//! [NodeHash](crate::accumulator::node_hash::NodeHash) for your hash type. + +use rustreexo::accumulator::{node_hash::NodeHash, pollard::Pollard}; +use starknet_crypto::{poseidon_hash_many, Felt}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// We need a stateful wrapper around the actual hash, this is because we use those different +/// values inside our accumulator. Here we use an enum to represent the different states, you +/// may want to use a struct with more data, depending on your needs. +enum PoseidonHash { + /// This means this holds an actual value + /// + /// It usually represents a node in the accumulator that haven't been deleted. + Hash(Felt), + /// Placeholder is a value that haven't been deleted, but we don't have the actual value. + /// The only thing that matters about it is that it's not empty. You can implement this + /// the way you want, just make sure that [NodeHash::is_placeholder] and [NodeHash::placeholder] + /// returns sane values (that is, if we call [NodeHash::placeholder] calling [NodeHash::is_placeholder] + /// on the result should return true). + Placeholder, + /// This is an empty value, it represents a node that was deleted from the accumulator. + /// + /// Same as the placeholder, you can implement this the way you want, just make sure that + /// [NodeHash::is_empty] and [NodeHash::empty] returns sane values. + Empty, +} + +// you'll need to implement Display for your hash type, so you can print it. +impl std::fmt::Display for PoseidonHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PoseidonHash::Hash(h) => write!(f, "Hash({})", h), + PoseidonHash::Placeholder => write!(f, "Placeholder"), + PoseidonHash::Empty => write!(f, "Empty"), + } + } +} + +// this is the implementation of the NodeHash trait for our custom hash type. And it's the only +// thing you need to do to use your custom hash type with the accumulator data structures. +impl NodeHash for PoseidonHash { + // returns a new placeholder type such that is_placeholder returns true + fn placeholder() -> Self { + PoseidonHash::Placeholder + } + + // returns an empty hash such that is_empty returns true + fn empty() -> Self { + PoseidonHash::Empty + } + + // returns true if this is a placeholder. This should be true iff this type was created by + // calling placeholder. + fn is_placeholder(&self) -> bool { + matches!(self, PoseidonHash::Placeholder) + } + + // returns true if this is an empty hash. This should be true iff this type was created by + // calling empty. + fn is_empty(&self) -> bool { + matches!(self, PoseidonHash::Empty) + } + + // used for serialization, writes the hash to the writer + // + // if you don't want to use serialization, you can just return an error here. + fn write(&self, writer: &mut W) -> std::io::Result<()> + where + W: std::io::Write { + match self { + PoseidonHash::Hash(h) => writer.write_all(&h.to_bytes_be()), + PoseidonHash::Placeholder => writer.write_all(&[0u8; 32]), + PoseidonHash::Empty => writer.write_all(&[0u8; 32]), + } + } + + // used for deserialization, reads the hash from the reader + // + // if you don't want to use serialization, you can just return an error here. + fn read(reader: &mut R) -> std::io::Result + where + R: std::io::Read { + let mut buf = [0u8; 32]; + reader.read_exact(&mut buf)?; + if buf.iter().all(|&b| b == 0) { + Ok(PoseidonHash::Empty) + } else { + Ok(PoseidonHash::Hash(Felt::from_bytes_be(&buf))) + } + } + + // the main thing about the hash type, it returns the next node's hash, given it's children. + // The implementation of this method is highly consensus critical, so everywhere should use the + // exact same algorithm to calculate the next hash. Rustreexo won't call this method, unless + // **both** children are not empty. + fn parent_hash(left: &Self, right: &Self) -> Self { + if let (PoseidonHash::Hash(left), PoseidonHash::Hash(right)) = (left, right) { + return PoseidonHash::Hash(poseidon_hash_many(&[*left, *right])); + } + + // This should never happen, since rustreexo won't call this method unless both children + // are not empty. + unreachable!() + } +} + +fn main() { + // Create a vector with two utxos that will be added to the Pollard + let elements = vec![ + PoseidonHash::Hash(Felt::from(1)), + PoseidonHash::Hash(Felt::from(2)), + ]; + + // Create a new Pollard, and add the utxos to it + let mut p = Pollard::::new_with_hash(); + p.modify(&elements, &[]).unwrap(); + + // Create a proof that the first utxo is in the Pollard + let proof = p.prove(&[elements[0]]).unwrap(); + + // check that the proof has exactly one target + assert_eq!(proof.n_targets(), 1); + // check that the proof is what we expect + assert!(p.verify(&proof, &[elements[0]]).unwrap()); + +} diff --git a/examples/full-accumulator.rs b/examples/full-accumulator.rs index 7e67c1f..2d895d2 100644 --- a/examples/full-accumulator.rs +++ b/examples/full-accumulator.rs @@ -4,17 +4,21 @@ use std::str::FromStr; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::pollard::Pollard; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; fn main() { let elements = vec![ - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(), - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(), + BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(), + BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(), ]; // Create a new Pollard, and add the utxos to it let mut p = Pollard::new(); @@ -31,9 +35,10 @@ fn main() { // Now we want to update the Pollard, by removing the first utxo, and adding a new one. // This would be in case we received a new block with a transaction spending the first utxo, // and creating a new one. - let new_utxo = - NodeHash::from_str("cac74661f4944e6e1fed35df40da951c6e151e7b0c8d65c3ee37d6dfd3bc3ef7") - .unwrap(); + let new_utxo = BitcoinNodeHash::from_str( + "cac74661f4944e6e1fed35df40da951c6e151e7b0c8d65c3ee37d6dfd3bc3ef7", + ) + .unwrap(); p.modify(&[new_utxo], &[elements[0]]).unwrap(); // Now we can prove that the new utxo is in the Pollard. diff --git a/examples/proof-update.rs b/examples/proof-update.rs index da4fa6d..cb02cfd 100644 --- a/examples/proof-update.rs +++ b/examples/proof-update.rs @@ -12,7 +12,7 @@ use std::str::FromStr; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; @@ -36,7 +36,7 @@ fn main() { .update(vec![], utxos.clone(), vec![], vec![0, 1], update_data) .unwrap(); // This should be a valid proof over 0 and 1. - assert_eq!(p.targets(), 2); + assert_eq!(p.n_targets(), 2); assert_eq!(s.verify(&p, &cached_hashes), Ok(true)); // Get a subset of the proof, for the first UTXO only @@ -65,7 +65,7 @@ fn main() { /// Returns the hashes for UTXOs in the first block in this fictitious example, there's nothing /// special about them, they are just the first 8 integers hashed as u8s. -fn get_utxo_hashes1() -> Vec { +fn get_utxo_hashes1() -> Vec { let hashes = [ "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d", "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", @@ -78,11 +78,11 @@ fn get_utxo_hashes1() -> Vec { ]; hashes .iter() - .map(|h| NodeHash::from_str(h).unwrap()) + .map(|h| BitcoinNodeHash::from_str(h).unwrap()) .collect() } /// Returns the hashes for UTXOs in the second block. -fn get_utxo_hashes2() -> Vec { +fn get_utxo_hashes2() -> Vec { let utxo_hashes = [ "bf4aff60ee0f3b2d82b47b94f6eff3018d1a47d1b0bc5dfbf8d3a95a2836bf5b", "2e6adf10ab3174629fc388772373848bbe277ffee1f72568e6d06e823b39d2dd", @@ -91,6 +91,6 @@ fn get_utxo_hashes2() -> Vec { ]; utxo_hashes .iter() - .map(|h| NodeHash::from_str(h).unwrap()) + .map(|h| BitcoinNodeHash::from_str(h).unwrap()) .collect() } diff --git a/examples/simple-stump-update.rs b/examples/simple-stump-update.rs index 9017724..139b1b9 100644 --- a/examples/simple-stump-update.rs +++ b/examples/simple-stump-update.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use std::vec; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; @@ -15,10 +15,14 @@ fn main() { // If we assume this is the very first block, then the Stump is empty, and we can just add // the utxos to it. Assuming a coinbase with two outputs, we would have the following utxos: let utxos = vec![ - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(), - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(), + BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(), + BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(), ]; // Create a new Stump, and add the utxos to it. Notice how we don't use the full return here, // but only the Stump. To understand what is the second return value, see the documentation @@ -34,9 +38,10 @@ fn main() { // Now we want to update the Stump, by removing the first utxo, and adding a new one. // This would be in case we received a new block with a transaction spending the first utxo, // and creating a new one. - let new_utxo = - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(); + let new_utxo = BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(); let s = s.modify(&[new_utxo], &[utxos[0]], &proof).unwrap().0; // Now we can verify that the new utxo is in the Stump, and the old one is not. let new_proof = Proof::new(vec![2], vec![new_utxo]); diff --git a/src/accumulator/node_hash.rs b/src/accumulator/node_hash.rs index edb5988..75b84e5 100644 --- a/src/accumulator/node_hash.rs +++ b/src/accumulator/node_hash.rs @@ -5,10 +5,11 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; -//! let hash = -//! NodeHash::from_str("0000000000000000000000000000000000000000000000000000000000000000") -//! .unwrap(); +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! let hash = BitcoinNodeHash::from_str( +//! "0000000000000000000000000000000000000000000000000000000000000000", +//! ) +//! .unwrap(); //! assert_eq!( //! hash.to_string().as_str(), //! "0000000000000000000000000000000000000000000000000000000000000000" @@ -18,10 +19,10 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; -//! let hash1 = NodeHash::new([0; 32]); +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! let hash1 = BitcoinNodeHash::new([0; 32]); //! // ... or ... -//! let hash2 = NodeHash::from([0; 32]); +//! let hash2 = BitcoinNodeHash::from([0; 32]); //! assert_eq!(hash1, hash2); //! assert_eq!( //! hash1.to_string().as_str(), @@ -33,13 +34,15 @@ //! ``` //! use std::str::FromStr; //! +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::node_hash::NodeHash; -//! let left = NodeHash::new([0; 32]); -//! let right = NodeHash::new([1; 32]); -//! let parent = NodeHash::parent_hash(&left, &right); -//! let expected_parent = -//! NodeHash::from_str("34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12") -//! .unwrap(); +//! let left = BitcoinNodeHash::new([0; 32]); +//! let right = BitcoinNodeHash::new([1; 32]); +//! let parent = BitcoinNodeHash::parent_hash(&left, &right); +//! let expected_parent = BitcoinNodeHash::from_str( +//! "34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12", +//! ) +//! .unwrap(); //! assert_eq!(parent, expected_parent); //! ``` use std::convert::TryFrom; @@ -58,39 +61,54 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; +pub trait NodeHash: Copy + Clone + Ord + Debug + Display + std::hash::Hash + 'static { + fn is_empty(&self) -> bool; + fn empty() -> Self; + fn is_placeholder(&self) -> bool; + fn placeholder() -> Self; + fn parent_hash(left: &Self, right: &Self) -> Self; + fn write(&self, writer: &mut W) -> std::io::Result<()> + where + W: std::io::Write; + fn read(reader: &mut R) -> std::io::Result + where + R: std::io::Read; +} + #[derive(Eq, PartialEq, Copy, Clone, Hash, PartialOrd, Ord)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] /// NodeHash is a wrapper around a 32 byte array that represents a hash of a node in the tree. /// # Example /// ``` -/// use rustreexo::accumulator::node_hash::NodeHash; -/// let hash = NodeHash::new([0; 32]); +/// use rustreexo::accumulator::node_hash::BitcoinNodeHash; +/// let hash = BitcoinNodeHash::new([0; 32]); /// assert_eq!( /// hash.to_string().as_str(), /// "0000000000000000000000000000000000000000000000000000000000000000" /// ); /// ``` #[derive(Default)] -pub enum NodeHash { +pub enum BitcoinNodeHash { #[default] Empty, Placeholder, Some([u8; 32]), } -impl Deref for NodeHash { +impl Deref for BitcoinNodeHash { type Target = [u8; 32]; fn deref(&self) -> &Self::Target { match self { - NodeHash::Some(ref inner) => inner, + BitcoinNodeHash::Some(ref inner) => inner, _ => &[0; 32], } } } -impl Display for NodeHash { + +impl Display for BitcoinNodeHash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - if let NodeHash::Some(ref inner) = self { + if let BitcoinNodeHash::Some(ref inner) = self { let mut s = String::new(); for byte in inner.iter() { s.push_str(&format!("{:02x}", byte)); @@ -101,137 +119,158 @@ impl Display for NodeHash { } } } -impl Debug for NodeHash { + +impl Debug for BitcoinNodeHash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - if let NodeHash::Some(ref inner) = self { - let mut s = String::new(); - for byte in inner.iter() { - s.push_str(&format!("{:02x}", byte)); + match self { + BitcoinNodeHash::Empty => write!(f, "empty"), + BitcoinNodeHash::Placeholder => write!(f, "placeholder"), + BitcoinNodeHash::Some(ref inner) => { + let mut s = String::new(); + for byte in inner.iter() { + s.push_str(&format!("{:02x}", byte)); + } + write!(f, "{}", s) } - write!(f, "{}", s) - } else { - write!(f, "empty") } } } -impl From for NodeHash { + +impl From for BitcoinNodeHash { fn from(hash: sha512_256::Hash) -> Self { - NodeHash::Some(hash.to_byte_array()) + BitcoinNodeHash::Some(hash.to_byte_array()) } } -impl From<[u8; 32]> for NodeHash { + +impl From<[u8; 32]> for BitcoinNodeHash { fn from(hash: [u8; 32]) -> Self { - NodeHash::Some(hash) + BitcoinNodeHash::Some(hash) } } -impl From<&[u8; 32]> for NodeHash { + +impl From<&[u8; 32]> for BitcoinNodeHash { fn from(hash: &[u8; 32]) -> Self { - NodeHash::Some(*hash) + BitcoinNodeHash::Some(*hash) } } + #[cfg(test)] -impl TryFrom<&str> for NodeHash { +impl TryFrom<&str> for BitcoinNodeHash { type Error = hex::HexToArrayError; fn try_from(hash: &str) -> Result { // This implementation is useful for testing, as it allows to create empty hashes // from the string of 64 zeros. Without this, it would be impossible to express this // hash in the test vectors. if hash == "0000000000000000000000000000000000000000000000000000000000000000" { - return Ok(NodeHash::Empty); + return Ok(BitcoinNodeHash::Empty); } + let hash = hex::FromHex::from_hex(hash)?; - Ok(NodeHash::Some(hash)) + Ok(BitcoinNodeHash::Some(hash)) } } #[cfg(not(test))] -impl TryFrom<&str> for NodeHash { +impl TryFrom<&str> for BitcoinNodeHash { type Error = hex::HexToArrayError; fn try_from(hash: &str) -> Result { let inner = hex::FromHex::from_hex(hash)?; - Ok(NodeHash::Some(inner)) + Ok(BitcoinNodeHash::Some(inner)) } } -impl From<&[u8]> for NodeHash { + +impl From<&[u8]> for BitcoinNodeHash { fn from(hash: &[u8]) -> Self { let mut inner = [0; 32]; inner.copy_from_slice(hash); - NodeHash::Some(inner) + BitcoinNodeHash::Some(inner) } } -impl From for NodeHash { +impl From for BitcoinNodeHash { fn from(hash: sha256::Hash) -> Self { - NodeHash::Some(hash.to_byte_array()) + BitcoinNodeHash::Some(hash.to_byte_array()) } } -impl FromStr for NodeHash { + +impl FromStr for BitcoinNodeHash { fn from_str(s: &str) -> Result { - NodeHash::try_from(s) + BitcoinNodeHash::try_from(s) } + type Err = hex::HexToArrayError; } -impl NodeHash { - /// Tells whether this hash is empty. We use empty hashes throughout the code to represent - /// leaves we want to delete. - pub fn is_empty(&self) -> bool { - if let NodeHash::Empty = self { - return true; - } - false - } + +impl BitcoinNodeHash { /// Creates a new NodeHash from a 32 byte array. /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; - /// let hash = NodeHash::new([0; 32]); + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let hash = BitcoinNodeHash::new([0; 32]); /// assert_eq!( /// hash.to_string().as_str(), /// "0000000000000000000000000000000000000000000000000000000000000000" /// ); /// ``` pub fn new(inner: [u8; 32]) -> Self { - NodeHash::Some(inner) + BitcoinNodeHash::Some(inner) } +} + +impl NodeHash for BitcoinNodeHash { + /// Tells whether this hash is empty. We use empty hashes throughout the code to represent + /// leaves we want to delete. + fn is_empty(&self) -> bool { + matches!(self, BitcoinNodeHash::Empty) + } + /// Creates an empty hash. This is used to represent leaves we want to delete. /// # Example /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::node_hash::NodeHash; - /// let hash = NodeHash::empty(); + /// let hash = BitcoinNodeHash::empty(); /// assert!(hash.is_empty()); /// ``` - pub fn empty() -> Self { - NodeHash::Empty + fn empty() -> Self { + BitcoinNodeHash::Empty } + /// parent_hash return the merkle parent of the two passed in nodes. /// # Example /// ``` /// use std::str::FromStr; /// + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::node_hash::NodeHash; - /// let left = NodeHash::new([0; 32]); - /// let right = NodeHash::new([1; 32]); - /// let parent = NodeHash::parent_hash(&left, &right); - /// let expected_parent = - /// NodeHash::from_str("34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12") - /// .unwrap(); + /// let left = BitcoinNodeHash::new([0; 32]); + /// let right = BitcoinNodeHash::new([1; 32]); + /// let parent = BitcoinNodeHash::parent_hash(&left, &right); + /// let expected_parent = BitcoinNodeHash::from_str( + /// "34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12", + /// ) + /// .unwrap(); /// assert_eq!(parent, expected_parent); /// ``` - pub fn parent_hash(left: &NodeHash, right: &NodeHash) -> NodeHash { + fn parent_hash(left: &Self, right: &Self) -> Self { let mut hash = sha512_256::Hash::engine(); hash.input(&**left); hash.input(&**right); sha512_256::Hash::from_engine(hash).into() } + fn is_placeholder(&self) -> bool { + matches!(self, BitcoinNodeHash::Placeholder) + } + /// Returns a arbitrary placeholder hash that is unlikely to collide with any other hash. /// We use this while computing roots to destroy. Don't confuse this with an empty hash. - pub const fn placeholder() -> Self { - NodeHash::Placeholder + fn placeholder() -> Self { + BitcoinNodeHash::Placeholder } /// write to buffer - pub(super) fn write(&self, writer: &mut W) -> std::io::Result<()> + fn write(&self, writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { @@ -246,7 +285,7 @@ impl NodeHash { } /// Read from buffer - pub(super) fn read(reader: &mut R) -> std::io::Result + fn read(reader: &mut R) -> std::io::Result where R: std::io::Read, { @@ -276,6 +315,7 @@ mod test { use std::str::FromStr; use super::NodeHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::util::hash_from_u8; #[test] @@ -283,7 +323,7 @@ mod test { let hash1 = hash_from_u8(0); let hash2 = hash_from_u8(1); - let parent_hash = NodeHash::parent_hash(&hash1, &hash2); + let parent_hash = BitcoinNodeHash::parent_hash(&hash1, &hash2); assert_eq!( parent_hash.to_string().as_str(), "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de" @@ -291,17 +331,19 @@ mod test { } #[test] fn test_hash_from_str() { - let hash = - NodeHash::from_str("6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d") - .unwrap(); + let hash = BitcoinNodeHash::from_str( + "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d", + ) + .unwrap(); assert_eq!(hash, hash_from_u8(0)); } #[test] fn test_empty_hash() { // Only relevant for tests - let hash = - NodeHash::from_str("0000000000000000000000000000000000000000000000000000000000000000") - .unwrap(); + let hash = BitcoinNodeHash::from_str( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(); assert_eq!(hash, NodeHash::empty()); } } diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 6bc2af0..3c02e58 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -4,15 +4,17 @@ //! //! # Example //! ``` +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::node_hash::NodeHash; //! use rustreexo::accumulator::pollard::Pollard; +//! //! let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; -//! let hashes: Vec = values +//! let hashes: Vec = values //! .into_iter() -//! .map(|i| NodeHash::from([i; 32])) +//! .map(|i| BitcoinNodeHash::from([i; 32])) //! .collect(); //! -//! let mut p = Pollard::new(); +//! let mut p = Pollard::::new(); //! //! p.modify(&hashes, &[]).expect("Pollard should not fail"); //! assert_eq!(p.get_roots().len(), 1); @@ -20,10 +22,9 @@ //! p.modify(&[], &hashes).expect("Still should not fail"); // Remove leaves from the accumulator //! //! assert_eq!(p.get_roots().len(), 1); -//! assert_eq!(p.get_roots()[0].get_data(), NodeHash::default()); +//! assert_eq!(p.get_roots()[0].get_data(), BitcoinNodeHash::default()); //! ``` -use core::fmt; use std::cell::Cell; use std::cell::RefCell; use std::collections::HashMap; @@ -35,6 +36,7 @@ use std::io::Write; use std::rc::Rc; use std::rc::Weak; +use super::node_hash::BitcoinNodeHash; use super::node_hash::NodeHash; use super::proof::Proof; use super::util::detect_offset; @@ -46,42 +48,39 @@ use super::util::max_position_at_row; use super::util::right_child; use super::util::root_position; use super::util::tree_rows; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum NodeType { Branch, Leaf, } +// A few type aliases to improve readability. + +/// A weak reference to a node in the forest. We use this when referencing to a node's +/// parent, as we don't want to create a reference cycle. +type Parent = RefCell>>; +/// A reference to a node's children. We use this to store the left and right children +/// of a node. This will be the only long-lived strong reference to a node. If this gets +/// dropped, the node will be dropped as well. +type Children = RefCell>>; + /// A forest node that can either be a leaf or a branch. #[derive(Clone)] -pub struct Node { +pub struct Node { /// The type of this node. ty: NodeType, /// The hash of the stored in this node. - data: Cell, + data: Cell, /// The parent of this node, if any. - parent: RefCell>>, + parent: Parent>, /// The left and right children of this node, if any. - left: RefCell>>, + left: Children>, /// The left and right children of this node, if any. - right: RefCell>>, + right: Children>, } -impl Node { - /// Recomputes the hash of all nodes, up to the root. - fn recompute_hashes(&self) { - let left = self.left.borrow(); - let right = self.right.borrow(); - if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { - self.data - .replace(NodeHash::parent_hash(&left.data.get(), &right.data.get())); - } - if let Some(ref parent) = *self.parent.borrow() { - if let Some(p) = parent.upgrade() { - p.recompute_hashes(); - } - } - } +impl Node { /// Writes one node to the writer, this method will recursively write all children. /// The primary use of this method is to serialize the accumulator. In this case, /// you should call this method on each root in the forest. @@ -104,6 +103,22 @@ impl Node { .transpose()?; Ok(()) } + /// Recomputes the hash of all nodes, up to the root. + fn recompute_hashes(&self) { + let left = self.left.borrow(); + let right = self.right.borrow(); + + if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { + self.data + .replace(Hash::parent_hash(&left.data.get(), &right.data.get())); + } + if let Some(ref parent) = *self.parent.borrow() { + if let Some(p) = parent.upgrade() { + p.recompute_hashes(); + } + } + } + /// Reads one node from the reader, this method will recursively read all children. /// The primary use of this method is to deserialize the accumulator. In this case, /// you should call this method on each root in the forest, assuming you know how @@ -111,15 +126,15 @@ impl Node { #[allow(clippy::type_complexity)] pub fn read_one( reader: &mut R, - ) -> std::io::Result<(Rc, HashMap>)> { - fn _read_one( - ancestor: Option>, + ) -> std::io::Result<(Rc>, HashMap>>)> { + fn _read_one( + ancestor: Option>>, reader: &mut R, - index: &mut HashMap>, - ) -> std::io::Result> { + index: &mut HashMap>>, + ) -> std::io::Result>> { let mut ty = [0u8; 8]; reader.read_exact(&mut ty)?; - let data = NodeHash::read(reader)?; + let data = Hash::read(reader)?; let ty = match u64::from_le_bytes(ty) { 0 => NodeType::Branch, @@ -165,33 +180,33 @@ impl Node { let root = _read_one(None, reader, &mut index)?; Ok((root, index)) } + /// Returns the data associated with this node. - pub fn get_data(&self) -> NodeHash { + pub fn get_data(&self) -> Hash { self.data.get() } } -impl Debug for Node { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{:02x}{:02x}", self.data.get()[0], self.data.get()[1]) - } -} /// The actual Pollard accumulator, it implements all methods required to update the forest /// and to prove/verify membership. #[derive(Default, Clone)] -pub struct Pollard { +pub struct Pollard { /// The roots of the forest, all leaves are children of these roots, and therefore /// owned by them. - roots: Vec>, + roots: Vec>>, /// The number of leaves in the forest. Actually, this is the number of leaves we ever /// added to the forest. pub leaves: u64, /// A map of all nodes in the forest, indexed by their hash, this is used to lookup /// leaves when proving membership. - map: HashMap>, + map: HashMap>>, } + impl Pollard { - /// Creates a new empty [Pollard]. + /// Creates a new empty [Pollard] with the default hash function. + /// + /// This will create an empty Pollard, using [BitcoinNodeHash] as the hash function. If you + /// want to use a different hash function, you can use [Pollard::new_with_hash]. /// # Example /// ``` /// use rustreexo::accumulator::pollard::Pollard; @@ -204,13 +219,32 @@ impl Pollard { leaves: 0, } } +} + +impl Pollard { + /// Creates a new empty [Pollard] with a custom hash function. + /// # Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::pollard::Pollard; + /// let mut pollard = Pollard::::new(); + /// ``` + pub fn new_with_hash() -> Pollard { + Pollard { + map: HashMap::new(), + roots: Vec::new(), + leaves: 0, + } + } + /// Writes the Pollard to a writer. Used to send the accumulator over the wire /// or to disk. /// # Example /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// - /// let mut pollard = Pollard::new(); + /// let mut pollard = Pollard::::new(); /// let mut serialized = Vec::new(); /// pollard.serialize(&mut serialized).unwrap(); /// @@ -229,14 +263,16 @@ impl Pollard { Ok(()) } + /// Deserializes a pollard from a reader. /// # Example /// ``` /// use std::io::Cursor; /// + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// let mut serialized = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - /// let pollard = Pollard::deserialize(&mut serialized).unwrap(); + /// let pollard = Pollard::::deserialize(&mut serialized).unwrap(); /// assert_eq!(pollard.leaves, 0); /// assert_eq!(pollard.get_roots().len(), 0); /// ``` @@ -257,28 +293,30 @@ impl Pollard { } Ok(Pollard { roots, leaves, map }) } + /// Returns the hash of a given position in the tree. - fn get_hash(&self, pos: u64) -> Result { + fn get_hash(&self, pos: u64) -> Result { let (node, _, _) = self.grab_node(pos)?; Ok(node.data.get()) } + /// Proves that a given set of hashes is in the accumulator. It returns a proof /// and the hashes that we what to prove, but sorted by position in the tree. /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; - /// let mut pollard = Pollard::new(); + /// let mut pollard = Pollard::::new(); /// let hashes = vec![0, 1, 2, 3, 4, 5, 6, 7] /// .iter() - /// .map(|n| NodeHash::from([*n; 32])) + /// .map(|n| BitcoinNodeHash::from([*n; 32])) /// .collect::>(); /// pollard.modify(&hashes, &[]).unwrap(); /// // We want to prove that the first two hashes are in the accumulator. /// let proof = pollard.prove(&[hashes[1], hashes[0]]).unwrap(); /// //TODO: Verify the proof /// ``` - pub fn prove(&self, targets: &[NodeHash]) -> Result { + pub fn prove(&self, targets: &[Hash]) -> Result, String> { let mut positions = Vec::new(); for target in targets { let node = self.map.get(target).ok_or("Could not find node")?; @@ -290,12 +328,15 @@ impl Pollard { .iter() .map(|pos| self.get_hash(*pos).unwrap()) .collect::>(); - Ok(Proof::new(positions, proof)) + + Ok(Proof::new_with_hash(positions, proof)) } + /// Returns a reference to the roots in this Pollard. - pub fn get_roots(&self) -> &[Rc] { + pub fn get_roots(&self) -> &[Rc>] { &self.roots } + /// Modify is the main API to a [Pollard]. Because order matters, you can only `modify` /// a [Pollard], and internally it'll add and delete, in the correct order. /// @@ -308,7 +349,7 @@ impl Pollard { /// use bitcoin_hashes::sha256::Hash as Data; /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; /// let hashes = values @@ -316,11 +357,11 @@ impl Pollard { /// .map(|val| { /// let mut engine = Data::engine(); /// engine.input(&[val]); - /// NodeHash::from(Data::from_engine(engine).as_byte_array()) + /// BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) /// }) /// .collect::>(); /// // Add 8 leaves to the pollard - /// let mut p = Pollard::new(); + /// let mut p = Pollard::::new(); /// p.modify(&hashes, &[]).expect("Pollard should not fail"); /// /// assert_eq!( @@ -328,13 +369,17 @@ impl Pollard { /// String::from("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") /// ); /// ``` - pub fn modify(&mut self, add: &[NodeHash], del: &[NodeHash]) -> Result<(), String> { + pub fn modify(&mut self, add: &[Hash], del: &[Hash]) -> Result<(), String> { self.del(del)?; self.add(add); Ok(()) } + #[allow(clippy::type_complexity)] - pub fn grab_node(&self, pos: u64) -> Result<(Rc, Rc, Rc), String> { + pub fn grab_node( + &self, + pos: u64, + ) -> Result<(Rc>, Rc>, Rc>), String> { let (tree, branch_len, bits) = detect_offset(pos, self.leaves); let mut n = Some(self.roots[tree as usize].clone()); let mut sibling = Some(self.roots[tree as usize].clone()); @@ -366,7 +411,8 @@ impl Pollard { } Err(format!("node {} not found", pos)) } - fn del(&mut self, targets: &[NodeHash]) -> Result<(), String> { + + fn del(&mut self, targets: &[Hash]) -> Result<(), String> { let mut pos = targets .iter() .flat_map(|target| self.map.get(target)) @@ -380,7 +426,7 @@ impl Pollard { .collect::>(); pos.sort(); - let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); + let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); for target in targets { match self.map.remove(&target) { Some(target) => { @@ -393,7 +439,8 @@ impl Pollard { } Ok(()) } - pub fn verify(&self, proof: &Proof, del_hashes: &[NodeHash]) -> Result { + + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { let roots = self .roots .iter() @@ -401,7 +448,8 @@ impl Pollard { .collect::>(); proof.verify(del_hashes, &roots, self.leaves) } - fn get_pos(&self, node: &Weak) -> u64 { + + fn get_pos(&self, node: &Weak>) -> u64 { // This indicates whether the node is a left or right child at each level // When we go down the tree, we can use the indicator to know which // child to take. @@ -458,7 +506,8 @@ impl Pollard { } pos } - fn del_single(&mut self, node: &Node) -> Option<()> { + + fn del_single(&mut self, node: &Node) -> Option<()> { let parent = node.parent.borrow(); // Deleting a root let parent = match *parent { @@ -468,7 +517,7 @@ impl Pollard { self.roots[pos] = Rc::new(Node { ty: NodeType::Branch, parent: RefCell::new(None), - data: Cell::new(NodeHash::default()), + data: Cell::new(Hash::empty()), left: RefCell::new(None), right: RefCell::new(None), }); @@ -506,8 +555,9 @@ impl Pollard { Some(()) } - fn add_single(&mut self, value: NodeHash) { - let mut node: Rc = Rc::new(Node { + + fn add_single(&mut self, value: Hash) { + let mut node: Rc> = Rc::new(Node { ty: NodeType::Leaf, parent: RefCell::new(None), data: Cell::new(value), @@ -538,11 +588,13 @@ impl Pollard { self.roots.push(node); self.leaves += 1; } - fn add(&mut self, values: &[NodeHash]) { + + fn add(&mut self, values: &[Hash]) { for value in values { self.add_single(*value); } } + /// to_string returns the full pollard in a string for all forests less than 6 rows. fn string(&self) -> String { if self.leaves == 0 { @@ -615,11 +667,13 @@ impl Debug for Pollard { write!(f, "{}", self.string()) } } + impl Display for Pollard { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { write!(f, "{}", self.string()) } } + #[cfg(test)] mod test { use std::convert::TryFrom; @@ -633,17 +687,19 @@ mod test { use serde::Deserialize; use super::Pollard; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::node_hash::NodeHash; use crate::accumulator::pollard::Node; use crate::accumulator::proof::Proof; - fn hash_from_u8(value: u8) -> NodeHash { + fn hash_from_u8(value: u8) -> BitcoinNodeHash { let mut engine = Data::engine(); engine.input(&[value]); - NodeHash::from(Data::from_engine(engine).as_byte_array()) + BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) } + #[test] fn test_grab_node() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; @@ -652,12 +708,14 @@ mod test { let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); let (found_target, found_sibling, _) = p.grab_node(4).unwrap(); - let target = - NodeHash::try_from("e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71") - .unwrap(); - let sibling = - NodeHash::try_from("e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db") - .unwrap(); + let target = BitcoinNodeHash::try_from( + "e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71", + ) + .unwrap(); + let sibling = BitcoinNodeHash::try_from( + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db", + ) + .unwrap(); assert_eq!(target, found_target.data.get()); assert_eq!(sibling, found_sibling.data.get()); @@ -678,6 +736,7 @@ mod test { node.data.get().to_string() ); } + #[test] fn test_proof_verify() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; @@ -688,6 +747,7 @@ mod test { let proof = p.prove(&[hashes[0], hashes[1]]).unwrap(); assert!(p.verify(&proof, &[hashes[0], hashes[1]]).unwrap()); } + #[test] fn test_add() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; @@ -713,6 +773,7 @@ mod test { acc.roots[3].data.get().to_string().as_str(), ); } + #[test] fn test_delete_roots_child() { // Assuming the following tree: @@ -722,7 +783,7 @@ mod test { // 00 01 // If I delete `01`, then `00` will become a root, moving it's hash to `02` let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); @@ -743,15 +804,16 @@ mod test { // If I delete `02`, then `02` will become an empty root, it'll point to nothing // and its data will be Data::default() let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); p.del_single(&p.grab_node(2).unwrap().0); assert_eq!(p.get_roots().len(), 1); let root = p.get_roots()[0].clone(); - assert_eq!(root.data.get(), NodeHash::default()); + assert_eq!(root.data.get(), BitcoinNodeHash::default()); } + #[test] fn test_delete_non_root() { // Assuming this tree, if we delete `01`, 00 will move up to 08's position @@ -774,7 +836,7 @@ mod test { // Where 08's data is just 00's let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); @@ -784,12 +846,14 @@ mod test { let (node, _, _) = p.grab_node(8).expect("This tree should have pos 8"); assert_eq!(node.data.get(), hashes[0]); } + #[derive(Debug, Deserialize)] struct TestCase { leaf_preimages: Vec, target_values: Option>, expected_roots: Vec, } + fn run_single_addition_case(case: TestCase) { let hashes = case .leaf_preimages @@ -802,7 +866,7 @@ mod test { let expected_roots = case .expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); let roots = p .get_roots() @@ -811,6 +875,7 @@ mod test { .collect::>(); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } + fn run_case_with_deletion(case: TestCase) { let hashes = case .leaf_preimages @@ -832,7 +897,7 @@ mod test { let expected_roots = case .expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); let roots = p .get_roots() @@ -841,6 +906,7 @@ mod test { .collect::>(); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } + #[test] fn run_tests_from_cases() { #[derive(Deserialize)] @@ -862,6 +928,7 @@ mod test { run_case_with_deletion(i); } } + #[test] fn test_to_string() { let hashes = get_hash_vec_of(&(0..255).collect::>()); @@ -872,6 +939,7 @@ mod test { p.to_string().get(0..30) ); } + #[test] fn test_get_pos() { macro_rules! test_get_pos { @@ -913,6 +981,7 @@ mod test { 25 ); } + #[test] fn test_serialize_one() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); @@ -922,9 +991,11 @@ mod test { let mut writer = std::io::Cursor::new(Vec::new()); p.get_roots()[0].write_one(&mut writer).unwrap(); let (deserialized, _) = - Node::read_one(&mut std::io::Cursor::new(writer.into_inner())).unwrap(); + Node::::read_one(&mut std::io::Cursor::new(writer.into_inner())) + .unwrap(); assert_eq!(deserialized.get_data(), p.get_roots()[0].get_data()); } + #[test] fn test_serialization() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); @@ -934,7 +1005,8 @@ mod test { let mut writer = std::io::Cursor::new(Vec::new()); p.serialize(&mut writer).unwrap(); let deserialized = - Pollard::deserialize(&mut std::io::Cursor::new(writer.into_inner())).unwrap(); + Pollard::::deserialize(&mut std::io::Cursor::new(writer.into_inner())) + .unwrap(); assert_eq!( deserialized.get_roots()[0].get_data(), p.get_roots()[0].get_data() @@ -942,6 +1014,7 @@ mod test { assert_eq!(deserialized.leaves, p.leaves); assert_eq!(deserialized.map.len(), p.map.len()); } + #[test] fn test_proof() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); @@ -972,7 +1045,8 @@ mod test { assert_eq!(proof, expected_proof); assert!(p.verify(&proof, &del_hashes).unwrap()); } - fn get_hash_vec_of(elements: &[u8]) -> Vec { + + fn get_hash_vec_of(elements: &[u8]) -> Vec { elements.iter().map(|el| hash_from_u8(*el)).collect() } @@ -984,11 +1058,11 @@ mod test { #[test] fn test_serialization_roundtrip() { - let mut p = Pollard::new(); + let mut p = Pollard::::new(); let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let hashes: Vec = values + let hashes: Vec = values .into_iter() - .map(|i| NodeHash::from([i; 32])) + .map(|i| BitcoinNodeHash::from([i; 32])) .collect(); p.modify(&hashes, &[]).expect("modify should work"); assert_eq!(p.get_roots().len(), 1); @@ -1000,7 +1074,8 @@ mod test { assert_eq!(p.leaves, 16); let mut serialized = Vec::::new(); p.serialize(&mut serialized).expect("serialize should work"); - let deserialized = Pollard::deserialize(&*serialized).expect("deserialize should work"); + let deserialized = + Pollard::::deserialize(&*serialized).expect("deserialize should work"); assert_eq!(deserialized.get_roots().len(), 1); assert!(deserialized.get_roots()[0].get_data().is_empty()); assert_eq!(deserialized.leaves, 16); diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index 86771c3..f242677 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -9,7 +9,7 @@ //! use bitcoin_hashes::sha256; //! use bitcoin_hashes::Hash; //! use bitcoin_hashes::HashEngine; -//! use rustreexo::accumulator::node_hash::NodeHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::proof::Proof; //! use rustreexo::accumulator::stump::Stump; //! let s = Stump::new(); @@ -22,16 +22,22 @@ //! let mut proof_hashes = Vec::new(); //! //! proof_hashes.push( -//! NodeHash::from_str("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", +//! ) +//! .unwrap(), //! ); //! proof_hashes.push( -//! NodeHash::from_str("9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", +//! ) +//! .unwrap(), //! ); //! proof_hashes.push( -//! NodeHash::from_str("29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", +//! ) +//! .unwrap(), //! ); //! //! // Hashes of the leaves UTXOs we'll add to the accumulator @@ -57,19 +63,21 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; +use super::node_hash::BitcoinNodeHash; use super::node_hash::NodeHash; use super::stump::UpdateData; +use super::util; use super::util::get_proof_positions; use super::util::read_u64; use super::util::tree_rows; -use super::util::{self}; -#[derive(Clone, Debug, Default, Eq, PartialEq)] + +#[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] /// A proof is a collection of hashes and positions. Each target position /// points to a leaf to be proven. Hashes are all /// hashes that can't be calculated from the data itself. /// Proofs are generated elsewhere. -pub struct Proof { +pub struct Proof { /// Targets are the i'th of leaf locations to delete and they are the bottommost leaves. /// With the tree below, the Targets can only consist of one of these: 02, 03, 04. ///```! @@ -92,7 +100,17 @@ pub struct Proof { /// // |---\ |---\ /// // 00 01 02 03 /// ``` - pub hashes: Vec, + pub hashes: Vec, +} + +// the default hash type for a proof is BitcoinNodeHash +impl Default for Proof { + fn default() -> Self { + Proof { + targets: Vec::new(), + hashes: Vec::new(), + } + } } // We often need to return the targets paired with hashes, and the proof position. @@ -101,11 +119,11 @@ pub struct Proof { /// This alias is used when we need to return the nodes and roots for a proof /// if we are not concerned with deleting those elements. -pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec); +pub(crate) type NodesAndRootsCurrent = (Vec<(u64, Hash)>, Vec); /// This is used when we need to return the nodes and roots for a proof /// if we are concerned with deleting those elements. The difference is that /// we need to retun the old and updatated roots in the accumulator. -pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>); +pub(crate) type NodesAndRootsOldNew = (Vec<(u64, Hash)>, Vec<(Hash, Hash)>); impl Proof { /// Creates a proof from a vector of target and hashes. @@ -132,7 +150,6 @@ impl Proof { /// ``` /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; /// use rustreexo::accumulator::proof::Proof; /// let targets = vec![0]; /// @@ -142,9 +159,53 @@ impl Proof { /// // Fill `proof_hashes` up with all hashes /// Proof::new(targets, proof_hashes); /// ``` - pub fn new(targets: Vec, hashes: Vec) -> Self { + pub fn new(targets: Vec, hashes: Vec) -> Proof { Proof { targets, hashes } } +} + +impl Proof { + /// Creates a proof from a vector of target and hashes, using a different hash type. + /// `targets` are u64s and indicates the position of the leaves we are + /// trying to prove. + /// `hashes` are of type `NodeHash` and are all hashes we need for computing the roots. + /// + /// Different from `new`, this function allows for the proof to be created with a different + /// hash type, as long as it implements `NodeHash`. + /// + /// Assuming a tree with leaf values [0, 1, 2, 3, 4, 5, 6, 7], we should see something like this: + ///```! + /// // 14 + /// // |-----------------\ + /// // 12 13 + /// // |---------\ |--------\ + /// // 08 09 10 11 + /// // |----\ |----\ |----\ |----\ + /// // 00 01 02 03 04 05 06 07 + /// ``` + /// If we are proving `00` (i.e. 00 is our target), then we need 01, + /// 09 and 13's hashes, so we can compute 14 by hashing both siblings + /// in each level (00 and 01, 08 and 09 and 12 and 13). Note that + /// some hashes we can compute by ourselves, and are not present in the + /// proof, in this case 00, 08, 12 and 14. + /// # Example + /// ``` + /// use bitcoin_hashes::Hash; + /// use bitcoin_hashes::HashEngine; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// let targets = vec![0]; + /// + /// let mut proof_hashes = Vec::new(); + /// let targets = vec![0]; + /// // For proving 0, we need 01, 09 and 13's hashes. 00, 08, 12 and 14 can be calculated + /// // Fill `proof_hashes` up with all hashes + /// Proof::::new_with_hash(targets, proof_hashes); + /// ``` + pub fn new_with_hash(targets: Vec, hashes: Vec) -> Self { + Proof { targets, hashes } + } + /// Public interface for verifying proofs. Returns a result with a bool or an Error /// True means the proof is true given the current stump, false means the proof is /// not valid given the current stump. @@ -155,7 +216,7 @@ impl Proof { /// use bitcoin_hashes::sha256; /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let s = Stump::new(); @@ -175,16 +236,22 @@ impl Proof { /// // 00 01 02 03 04 05 06 07 /// // For proving 0, we need 01, 09 and 13's hashes. 00, 08, 12 and 14 can be calculated /// proof_hashes.push( - /// NodeHash::from_str("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + /// ) + /// .unwrap(), /// ); /// proof_hashes.push( - /// NodeHash::from_str("9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", + /// ) + /// .unwrap(), /// ); /// proof_hashes.push( - /// NodeHash::from_str("29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", + /// ) + /// .unwrap(), /// ); /// /// let mut hashes = Vec::new(); @@ -199,15 +266,15 @@ impl Proof { /// ``` pub fn verify( &self, - del_hashes: &[NodeHash], - roots: &[NodeHash], + del_hashes: &[Hash], + roots: &[Hash], num_leaves: u64, ) -> Result { if self.targets.is_empty() { return Ok(true); } - let mut calculated_roots: std::iter::Peekable> = self + let mut calculated_roots: std::iter::Peekable> = self .calculate_hashes(del_hashes, num_leaves)? .1 .into_iter() @@ -227,8 +294,10 @@ impl Proof { if calculated_roots.len() != number_matched_roots && calculated_roots.len() != 0 { return Ok(false); } + Ok(true) } + /// Returns the elements needed to prove a subset of targets. For example, a tree with /// 8 leaves, if we cache `[0, 2, 6, 7]`, and we need to prove `[2, 7]` only, we have to remove /// elements for 0 and 7. The original proof is `[1, 3, 10]`, and we can compute `[8, 9, 11, 12, 13, 14]`. @@ -244,10 +313,10 @@ impl Proof { /// ``` pub fn get_proof_subset( &self, - del_hashes: &[NodeHash], + del_hashes: &[Hash], new_targets: &[u64], num_leaves: u64, - ) -> Result { + ) -> Result, String> { let forest_rows = tree_rows(num_leaves); let old_proof_positions = get_proof_positions(&self.targets, num_leaves, forest_rows); let needed_positions = get_proof_positions(new_targets, num_leaves, forest_rows); @@ -257,7 +326,7 @@ impl Proof { .iter() .copied() .zip(self.hashes.iter().copied()) - .collect::>(); + .collect::>(); old_proof.extend(intermediate_positions); @@ -271,7 +340,7 @@ impl Proof { } } new_proof.sort(); - let (_, new_proof): (Vec, Vec) = new_proof.into_iter().unzip(); + let (_, new_proof): (Vec, Vec) = new_proof.into_iter().unzip(); Ok(Proof { targets: new_targets.to_vec(), hashes: new_proof, @@ -286,7 +355,7 @@ impl Proof { /// - hashes (32 bytes) /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// @@ -309,20 +378,21 @@ impl Proof { writer.write_all(&self.hashes.len().to_le_bytes())?; for hash in &self.hashes { len += 32; - writer.write_all(&**hash)?; + hash.write(&mut writer)?; } Ok(len) } + /// Deserializes a proof from a byte array. /// # Example /// ``` /// use std::io::Cursor; /// - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let proof = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - /// let deserialized_proof = Proof::deserialize(proof).unwrap(); + /// let deserialized_proof = Proof::::deserialize(proof).unwrap(); /// // An empty proof is only 16 bytes of zeros, meaning no targets and no hashes /// assert_eq!(Proof::default(), deserialized_proof); /// ``` @@ -336,15 +406,14 @@ impl Proof { let hashes_len = read_u64(&mut buf)? as usize; let mut hashes = Vec::with_capacity(hashes_len); for _ in 0..hashes_len { - let mut hash = [0u8; 32]; - buf.read_exact(&mut hash) - .map_err(|_| "Failed to read hash")?; - hashes.push(hash.into()); + let hash = Hash::read(&mut buf).map_err(|_| "Failed to parse hash")?; + hashes.push(hash); } Ok(Proof { targets, hashes }) } + /// Returns how many targets this proof has - pub fn targets(&self) -> usize { + pub fn n_targets(&self) -> usize { self.targets.len() } @@ -359,15 +428,15 @@ impl Proof { /// If at least one returned element doesn't exist in the accumulator, the proof is invalid. pub(crate) fn calculate_hashes_delete( &self, - del_hashes: &[(NodeHash, NodeHash)], + del_hashes: &[(Hash, Hash)], num_leaves: u64, - ) -> Result { + ) -> Result, String> { // Where all the root hashes that we've calculated will go to. let total_rows = util::tree_rows(num_leaves); // Where all the parent hashes we've calculated in a given row will go to. let mut calculated_root_hashes = - Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves)); + Vec::<(Hash, Hash)>::with_capacity(util::num_roots(num_leaves)); // the positions that should be passed as a proof let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); @@ -445,15 +514,14 @@ impl Proof { /// needed for verification (i.e. the current accumulator). pub(crate) fn calculate_hashes( &self, - del_hashes: &[NodeHash], + del_hashes: &[Hash], num_leaves: u64, - ) -> Result { + ) -> Result, String> { // Where all the root hashes that we've calculated will go to. let total_rows = util::tree_rows(num_leaves); // Where all the parent hashes we've calculated in a given row will go to. - let mut calculated_root_hashes = - Vec::::with_capacity(util::num_roots(num_leaves)); + let mut calculated_root_hashes = Vec::::with_capacity(util::num_roots(num_leaves)); // the positions that should be passed as a proof let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); @@ -541,18 +609,19 @@ impl Proof { (None, None) => None, } } + /// Uses the data passed in to update a proof, creating a valid proof for a given /// set of targets, after an update. This is useful for caching UTXOs. You grab a proof /// for it once and then keep updating it every block, yielding an always valid proof /// over those UTXOs. pub fn update( self, - cached_hashes: Vec, - add_hashes: Vec, + cached_hashes: Vec, + add_hashes: Vec, block_targets: Vec, remembers: Vec, - update_data: UpdateData, - ) -> Result<(Proof, Vec), String> { + update_data: UpdateData, + ) -> Result<(Proof, Vec), String> { let (proof_after_deletion, cached_hashes) = self.update_proof_remove( block_targets, cached_hashes, @@ -571,17 +640,18 @@ impl Proof { Ok(data_after_addition) } + fn update_proof_add( self, - adds: Vec, - cached_del_hashes: Vec, + adds: Vec, + cached_del_hashes: Vec, remembers: Vec, - new_nodes: Vec<(u64, NodeHash)>, + new_nodes: Vec<(u64, Hash)>, before_num_leaves: u64, to_destroy: Vec, - ) -> Result<(Proof, Vec), String> { + ) -> Result<(Proof, Vec), String> { // Combine the hashes with the targets. - let orig_targets_with_hash: Vec<(u64, NodeHash)> = self + let orig_targets_with_hash: Vec<(u64, Hash)> = self .targets .iter() .copied() @@ -616,7 +686,7 @@ impl Proof { // remembers is an index telling what newly created UTXO should be cached for remember in remembers { - let remember_target: Option<&NodeHash> = adds.get(remember as usize); + let remember_target: Option<&Hash> = adds.get(remember as usize); if let Some(remember_target) = remember_target { let node = new_nodes_iter.find(|(_, hash)| *hash == *remember_target); if let Some((pos, hash)) = node { @@ -655,22 +725,23 @@ impl Proof { } new_proof.sort(); - let (_, hashes): (Vec, Vec) = new_proof.into_iter().unzip(); + let (_, hashes): (Vec, Vec) = new_proof.into_iter().unzip(); Ok(( - Proof { + Proof:: { hashes, targets: new_target_pos, }, target_hashes, )) } + /// maybe_remap remaps the passed in hash and pos if the tree_rows increase after /// adding the new nodes. fn maybe_remap( num_leaves: u64, num_adds: u64, - positions: Vec<(u64, NodeHash)>, - ) -> Vec<(u64, NodeHash)> { + positions: Vec<(u64, Hash)>, + ) -> Vec<(u64, Hash)> { let new_forest_rows = util::tree_rows(num_leaves + num_adds); let old_forest_rows = util::tree_rows(num_leaves); let tree_rows = util::tree_rows(num_leaves); @@ -697,13 +768,13 @@ impl Proof { fn update_proof_remove( self, block_targets: Vec, - cached_hashes: Vec, - updated: Vec<(u64, NodeHash)>, + cached_hashes: Vec, + updated: Vec<(u64, Hash)>, num_leaves: u64, - ) -> Result<(Proof, Vec), String> { + ) -> Result<(Proof, Vec), String> { let total_rows = util::tree_rows(num_leaves); - let targets_with_hash: Vec<(u64, NodeHash)> = self + let targets_with_hash: Vec<(u64, Hash)> = self .targets .iter() .cloned() @@ -767,7 +838,7 @@ impl Proof { proof_elements.sort(); // Grab the hashes for the proof - let (_, hashes): (Vec, Vec) = proof_elements.into_iter().unzip(); + let (_, hashes): (Vec, Vec) = proof_elements.into_iter().unzip(); // Gets all proof targets, but with their new positions after delete let (targets, target_hashes) = Proof::calc_next_positions(&block_targets, &targets_with_hash, num_leaves, true)? @@ -779,10 +850,10 @@ impl Proof { fn calc_next_positions( block_targets: &Vec, - old_positions: &Vec<(u64, NodeHash)>, + old_positions: &Vec<(u64, Hash)>, num_leaves: u64, append_roots: bool, - ) -> Result, String> { + ) -> Result, String> { let total_rows = util::tree_rows(num_leaves); let mut new_positions = vec![]; @@ -825,9 +896,11 @@ mod tests { use serde::Deserialize; use super::Proof; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::node_hash::NodeHash; use crate::accumulator::stump::Stump; use crate::accumulator::util::hash_from_u8; + #[derive(Deserialize)] struct TestCase { numleaves: usize, @@ -837,6 +910,7 @@ mod tests { proofhashes: Vec, expected: bool, } + /// This test checks whether our update proof works for different scenarios. We start /// with a (valid) cached proof, then we receive `blocks`, like we would in normal Bitcoin /// but for this test, block is just random data. For each block we update our Stump and @@ -893,19 +967,19 @@ mod tests { .cached_proof .hashes .iter() - .map(|val| NodeHash::from_str(val).unwrap()) + .map(|val| BitcoinNodeHash::from_str(val).unwrap()) .collect(); let cached_hashes: Vec<_> = case_values .cached_hashes .iter() - .map(|val| NodeHash::from_str(val).unwrap()) + .map(|val| BitcoinNodeHash::from_str(val).unwrap()) .collect(); let cached_proof = Proof::new(case_values.cached_proof.targets, proof_hashes); let roots = case_values .initial_roots .into_iter() - .map(|hash| NodeHash::from_str(&hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(&hash).unwrap()) .collect(); let stump = Stump { @@ -923,14 +997,14 @@ mod tests { .update .del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let block_proof_hashes = case_values .update .proof .hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let block_proof = @@ -951,13 +1025,13 @@ mod tests { let expected_roots: Vec<_> = case_values .expected_roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let expected_cached_hashes: Vec<_> = case_values .expected_cached_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); assert_eq!(res, Ok(true)); assert_eq!(cached_proof.targets, case_values.expected_targets); @@ -969,29 +1043,54 @@ mod tests { #[test] fn test_get_next() { use super::Proof; - let computed = vec![(1, NodeHash::empty()), (3, NodeHash::empty())]; - let provided = vec![(2, NodeHash::empty()), (4, NodeHash::empty())]; + let computed = vec![(1, BitcoinNodeHash::empty()), (3, NodeHash::empty())]; + let provided = vec![(2, BitcoinNodeHash::empty()), (4, NodeHash::empty())]; let mut computed_pos = 0; let mut provided_pos = 0; assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), Some((1, NodeHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), Some((2, NodeHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), Some((3, NodeHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), Some((4, NodeHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), None ); } @@ -1004,11 +1103,11 @@ mod tests { struct Test { name: &'static str, block_targets: Vec, - old_positions: Vec<(u64, NodeHash)>, + old_positions: Vec<(u64, BitcoinNodeHash)>, num_leaves: u64, num_adds: u64, append_roots: bool, - expected: Vec<(u64, NodeHash)>, + expected: Vec<(u64, BitcoinNodeHash)>, } let tests = vec![Test { @@ -1017,28 +1116,28 @@ mod tests { old_positions: vec![ ( 1, - NodeHash::from_str( + BitcoinNodeHash::from_str( "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", ) .unwrap(), ), ( 13, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9d1e0e2d9459d06523ad13e28a4093c2316baafe7aec5b25f30eba2e113599c4", ) .unwrap(), ), ( 17, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", ) .unwrap(), ), ( 25, - NodeHash::from_str( + BitcoinNodeHash::from_str( "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", ) .unwrap(), @@ -1050,28 +1149,28 @@ mod tests { expected: (vec![ ( 1, - NodeHash::from_str( + BitcoinNodeHash::from_str( "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", ) .unwrap(), ), ( 17, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", ) .unwrap(), ), ( 21, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9d1e0e2d9459d06523ad13e28a4093c2316baafe7aec5b25f30eba2e113599c4", ) .unwrap(), ), ( 25, - NodeHash::from_str( + BitcoinNodeHash::from_str( "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", ) .unwrap(), @@ -1091,6 +1190,7 @@ mod tests { assert_eq!(res, test.expected, "testcase: \"{}\" fail", test.name); } } + #[test] fn test_update_proof_delete() { let preimages = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; @@ -1107,7 +1207,7 @@ mod tests { ]; let proof_hashes = proof_hashes .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let cached_proof_hashes = [ @@ -1117,7 +1217,7 @@ mod tests { ]; let cached_proof_hashes = cached_proof_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let cached_proof = Proof::new(vec![0, 1, 7], cached_proof_hashes); @@ -1142,6 +1242,7 @@ mod tests { let res = stump.verify(&new_proof, &[hash_from_u8(0), hash_from_u8(7)]); assert_eq!(res, Ok(true)); } + #[test] fn test_calculate_hashes() { // Tests if the calculated roots and nodes are correct. @@ -1164,7 +1265,7 @@ mod tests { ]; let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0, 2, 4, 6], proof_hashes); @@ -1191,12 +1292,12 @@ mod tests { let expected_roots: Vec<_> = expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect(); let mut expected_computed = expected_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .zip(&expected_pos); let calculated = p.calculate_hashes(&del_hashes, s.leaves); @@ -1235,22 +1336,24 @@ mod tests { let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0], proof_hashes); let del_hashes = del_hashes .into_iter() - .map(|hash| (hash, NodeHash::empty())) + .map(|hash| (hash, BitcoinNodeHash::empty())) .collect::>(); let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap(); - let expected_root_old = - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(); - let expected_root_new = - NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed") - .unwrap(); + let expected_root_old = BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(); + let expected_root_new = BitcoinNodeHash::from_str( + "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed", + ) + .unwrap(); let computed_positions = [0_u64, 1, 9, 13, 8, 12, 14].to_vec(); let computed_hashes = [ @@ -1263,7 +1366,7 @@ mod tests { "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed", ] .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let expected_computed: Vec<_> = computed_positions .into_iter() @@ -1282,6 +1385,7 @@ mod tests { let deserialized = Proof::deserialize(&mut serialized.as_slice()).unwrap(); assert_eq!(p, deserialized); } + #[test] fn test_get_proof_subset() { // Tests if the calculated roots and nodes are correct. @@ -1304,7 +1408,7 @@ mod tests { ]; let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0, 2, 4, 6], proof_hashes); @@ -1314,6 +1418,7 @@ mod tests { assert_eq!(s.verify(&subset, &[del_hashes[0]]), Ok(true)); assert_eq!(s.verify(&subset, &[del_hashes[2]]), Ok(false)); } + #[test] #[cfg(feature = "with-serde")] fn test_serde_rtt() { @@ -1324,12 +1429,13 @@ mod tests { serde_json::from_str(&serialized).expect("Deserialization failed"); assert_eq!(proof, deserialized); } + fn run_single_case(case: &serde_json::Value) { let case = serde_json::from_value::(case.clone()).expect("Invalid test case"); let roots = case .roots .into_iter() - .map(|root| NodeHash::from_str(root.as_str()).expect("Test case hash is valid")) + .map(|root| BitcoinNodeHash::from_str(root.as_str()).expect("Test case hash is valid")) .collect(); let s = Stump { @@ -1347,7 +1453,7 @@ mod tests { let proof_hashes = case .proofhashes .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hash is valid")) + .map(|hash| BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hash is valid")) .collect(); let p = Proof::new(targets, proof_hashes); @@ -1357,16 +1463,17 @@ mod tests { assert!(Ok(expected) == res); // Test getting proof subset (only if the original proof is valid) if expected { - let (subset, _) = p.targets.split_at(p.targets() / 2); + let (subset, _) = p.targets.split_at(p.n_targets() / 2); let proof = p.get_proof_subset(&del_hashes, subset, s.leaves).unwrap(); let set_hashes = subset .iter() .map(|preimage| hash_from_u8(*preimage as u8)) - .collect::>(); + .collect::>(); assert_eq!(s.verify(&proof, &set_hashes), Ok(true)); } } + #[test] fn test_proof_verify() { let contents = std::fs::read_to_string("test_values/test_cases.json") @@ -1411,6 +1518,7 @@ mod bench { bencher.iter(|| proof.calculate_hashes(&cached_hashes, stump.leaves)) } + #[bench] fn bench_proof_update(bencher: &mut Bencher) { let preimages = [0_u8, 1, 2, 3, 4, 5]; diff --git a/src/accumulator/stump.rs b/src/accumulator/stump.rs index 46613cf..4ccd7ea 100644 --- a/src/accumulator/stump.rs +++ b/src/accumulator/stump.rs @@ -6,13 +6,13 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::proof::Proof; //! use rustreexo::accumulator::stump::Stump; //! // Create a new empty Stump //! let s = Stump::new(); //! // The newly create outputs -//! let utxos = vec![NodeHash::from_str( +//! let utxos = vec![BitcoinNodeHash::from_str( //! "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", //! ) //! .unwrap()]; @@ -21,7 +21,7 @@ //! // Modify the Stump, adding the new outputs and removing the spent ones, notice how //! // it returns a new Stump, instead of modifying the old one. This is due to the fact //! // that modify is a pure function that doesn't modify the old Stump. -//! let s = s.modify(&utxos, &stxos, &Proof::default()); +//! let s = s.modify(&utxos, &stxos, &Proof::::default()); //! assert!(s.is_ok()); //! assert_eq!(s.unwrap().0.roots, utxos); //! ``` @@ -35,28 +35,35 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; +use super::node_hash::BitcoinNodeHash; use super::node_hash::NodeHash; use super::proof::NodesAndRootsOldNew; use super::proof::Proof; use super::util; -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] -pub struct Stump { +pub struct Stump { pub leaves: u64, - pub roots: Vec, + pub roots: Vec, +} + +impl Default for Stump { + fn default() -> Self { + Stump::new() + } } #[derive(Debug, Clone, Default)] -pub struct UpdateData { +pub struct UpdateData { /// to_destroy is the positions of the empty roots removed after the add. pub(crate) to_destroy: Vec, /// pre_num_leaves is the numLeaves of the stump before the add. pub(crate) prev_num_leaves: u64, /// new_add are the new hashes for the newly created roots after the addition. - pub(crate) new_add: Vec<(u64, NodeHash)>, + pub(crate) new_add: Vec<(u64, Hash)>, /// new_del are the new hashes after the deletion. - pub(crate) new_del: Vec<(u64, NodeHash)>, + pub(crate) new_del: Vec<(u64, Hash)>, } impl Stump { @@ -72,9 +79,87 @@ impl Stump { roots: Vec::new(), } } - pub fn verify(&self, proof: &Proof, del_hashes: &[NodeHash]) -> Result { + + /// Serialize the Stump into a byte array + /// # Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// use rustreexo::accumulator::stump::Stump; + /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] + /// .iter() + /// .map(|&el| BitcoinNodeHash::from([el; 32])) + /// .collect::>(); + /// let (stump, _) = Stump::new() + /// .modify(&hashes, &[], &Proof::default()) + /// .unwrap(); + /// let mut writer = Vec::new(); + /// stump.serialize(&mut writer).unwrap(); + /// assert_eq!( + /// writer, + /// vec![ + /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 150, 124, 244, 241, 98, 69, 217, + /// 222, 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, + /// 66, 30, 63, 113, 246, 74, + /// ] + /// ); + /// ``` + pub fn serialize(&self, mut writer: &mut Dst) -> std::io::Result { + let mut len = 8; + writer.write_all(&self.leaves.to_le_bytes())?; + writer.write_all(&self.roots.len().to_le_bytes())?; + for root in self.roots.iter() { + len += 32; + root.write(&mut writer)?; + } + Ok(len) + } + + /// Rewinds old tree state, this should be used in case of reorgs. + /// Takes the ownership over `old_state`. + ///# Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// use rustreexo::accumulator::stump::Stump; + /// + /// let s_old = Stump::new(); + /// let mut s_new = Stump::new(); + /// + /// let s_old = s_old.modify(&vec![], &vec![], &Proof::default()).unwrap().0; + /// s_new = s_old.clone(); + /// s_new = s_new.modify(&vec![], &vec![], &Proof::default()).unwrap().0; + /// + /// // A reorg happened + /// s_new.undo(s_old); + /// ``` + pub fn undo(&mut self, old_state: Stump) { + self.leaves = old_state.leaves; + self.roots = old_state.roots; + } +} + +impl Stump { + /// Verifies the proof against the Stump. The proof is a list of hashes that are used to + /// recompute the root of the accumulator. The del_hashes are the hashes that are being + /// deleted from the accumulator. + /// // TODO: Add example + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { proof.verify(del_hashes, &self.roots, self.leaves) } + + /// Creates a new Stump with a custom hash type + /// + /// If you need to use a hash type that's not the [BitcoinNodeHash], you can use this + /// function to create a new Stump with the desired hash type. Use [BitcoinNodeHash::new] + /// to create a new Stump with the default hash type. + pub fn new_with_hash() -> Stump { + Stump { + leaves: 0, + roots: Vec::new(), + } + } + /// Modify is the external API to change the accumulator state. Since order /// matters, you can only modify, providing a list of utxos to be added, /// and txos to be removed, along with it's proof. Either may be @@ -83,12 +168,12 @@ impl Stump { /// ``` /// use std::str::FromStr; /// - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// /// let s = Stump::new(); - /// let utxos = vec![NodeHash::from_str( + /// let utxos = vec![BitcoinNodeHash::from_str( /// "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", /// ) /// .unwrap()]; @@ -99,10 +184,10 @@ impl Stump { /// ``` pub fn modify( &self, - utxos: &[NodeHash], - del_hashes: &[NodeHash], - proof: &Proof, - ) -> Result<(Stump, UpdateData), String> { + utxos: &[Hash], + del_hashes: &[Hash], + proof: &Proof, + ) -> Result<(Stump, UpdateData), String> { let (intermediate, mut computed_roots) = self.remove(del_hashes, proof)?; let mut new_roots = vec![]; @@ -139,101 +224,48 @@ impl Stump { Ok((new_stump, update_data)) } - /// Serialize the Stump into a byte array - /// # Example - /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; - /// use rustreexo::accumulator::proof::Proof; - /// use rustreexo::accumulator::stump::Stump; - /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] - /// .iter() - /// .map(|&el| NodeHash::from([el; 32])) - /// .collect::>(); - /// let (stump, _) = Stump::new() - /// .modify(&hashes, &[], &Proof::default()) - /// .unwrap(); - /// let mut writer = Vec::new(); - /// stump.serialize(&mut writer).unwrap(); - /// assert_eq!( - /// writer, - /// vec![ - /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 150, 124, 244, 241, 98, 69, 217, 222, - /// 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, - /// 63, 113, 246, 74, - /// ] - /// ); - /// ``` - pub fn serialize(&self, writer: &mut Dst) -> std::io::Result { - let mut len = 8; - writer.write_all(&self.leaves.to_le_bytes())?; - writer.write_all(&self.roots.len().to_le_bytes())?; - for root in self.roots.iter() { - len += 32; - writer.write_all(&**root)?; - } - Ok(len) - } + /// Deserialize the Stump from a Reader /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let buffer = vec![ - /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 150, 124, 244, 241, 98, 69, 217, 222, 235, - /// 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, 63, 113, - /// 246, 74, + /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 150, 124, 244, 241, 98, 69, 217, 222, + /// 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, 63, + /// 113, 246, 74, /// ]; /// let mut buffer = std::io::Cursor::new(buffer); /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] /// .iter() - /// .map(|&el| NodeHash::from([el; 32])) + /// .map(|&el| BitcoinNodeHash::from([el; 32])) /// .collect::>(); /// let (stump, _) = Stump::new() /// .modify(&hashes, &[], &Proof::default()) /// .unwrap(); - /// assert_eq!(stump, Stump::deserialize(buffer).unwrap()); + /// assert_eq!( + /// stump, + /// Stump::::deserialize(buffer).unwrap() + /// ); /// ``` - pub fn deserialize(data: Source) -> Result { - let mut data = data; + pub fn deserialize(mut data: Source) -> Result { let leaves = util::read_u64(&mut data)?; let roots_len = util::read_u64(&mut data)?; let mut roots = vec![]; for _ in 0..roots_len { - let mut root = [0u8; 32]; - data.read_exact(&mut root).map_err(|e| e.to_string())?; - - roots.push(NodeHash::from(root)); + let root = Hash::read(&mut data).map_err(|e| e.to_string())?; + roots.push(root); } Ok(Stump { leaves, roots }) } - /// Rewinds old tree state, this should be used in case of reorgs. - /// Takes the ownership over `old_state`. - ///# Example - /// ``` - /// use rustreexo::accumulator::proof::Proof; - /// use rustreexo::accumulator::stump::Stump; - /// let s_old = Stump::new(); - /// let mut s_new = Stump::new(); - /// - /// let s_old = s_old.modify(&vec![], &vec![], &Proof::default()).unwrap().0; - /// s_new = s_old.clone(); - /// s_new = s_new.modify(&vec![], &vec![], &Proof::default()).unwrap().0; - /// - /// // A reorg happened - /// s_new.undo(s_old); - /// ``` - pub fn undo(&mut self, old_state: Stump) { - self.leaves = old_state.leaves; - self.roots = old_state.roots; - } fn remove( &self, - del_hashes: &[NodeHash], - proof: &Proof, - ) -> Result { + del_hashes: &[Hash], + proof: &Proof, + ) -> Result, String> { if del_hashes.is_empty() { return Ok(( vec![], @@ -243,19 +275,20 @@ impl Stump { let del_hashes = del_hashes .iter() - .map(|hash| (*hash, NodeHash::empty())) + .map(|hash| (*hash, Hash::empty())) .collect::>(); + proof.calculate_hashes_delete(&del_hashes, self.leaves) } /// Adds new leaves into the root fn add( - mut roots: Vec, - utxos: &[NodeHash], + mut roots: Vec, + utxos: &[Hash], mut leaves: u64, - ) -> (Vec, Vec<(u64, NodeHash)>, Vec) { + ) -> (Vec, Vec<(u64, Hash)>, Vec) { let after_rows = util::tree_rows(leaves + (utxos.len() as u64)); - let mut updated_subtree: Vec<(u64, NodeHash)> = vec![]; + let mut updated_subtree: Vec<(u64, Hash)> = vec![]; let all_deleted = util::roots_to_destroy(utxos.len() as u64, leaves, &roots); for (i, add) in utxos.iter().enumerate() { @@ -306,12 +339,16 @@ impl Stump { #[cfg(test)] mod test { + use std::fmt::Display; + use std::io::Read; + use std::io::Write; use std::str::FromStr; use std::vec; use serde::Deserialize; use super::Stump; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::node_hash::NodeHash; use crate::accumulator::proof::Proof; use crate::accumulator::util::hash_from_u8; @@ -331,6 +368,75 @@ mod test { assert!(s.leaves == 0); assert!(s.roots.is_empty()); } + + #[test] + fn test_custom_hash_type() { + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] + struct CustomHash([u8; 32]); + + impl Display for CustomHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + + impl NodeHash for CustomHash { + fn empty() -> Self { + CustomHash([0; 32]) + } + fn is_empty(&self) -> bool { + self.0.iter().all(|&x| x == 0) + } + fn parent_hash(left: &Self, right: &Self) -> Self { + let mut hash = [0; 32]; + for i in 0..32 { + hash[i] = left.0[i] ^ right.0[i]; + } + CustomHash(hash) + } + fn read(reader: &mut R) -> std::io::Result { + let mut hash = [0; 32]; + reader + .read_exact(&mut hash) + .map_err(|e| e.to_string()) + .unwrap(); + Ok(CustomHash(hash)) + } + fn write(&self, writer: &mut W) -> std::io::Result<()> { + writer + .write_all(&self.0) + .map_err(|e| e.to_string()) + .unwrap(); + Ok(()) + } + fn is_placeholder(&self) -> bool { + false + } + fn placeholder() -> Self { + CustomHash([0; 32]) + } + } + + let s = Stump::::new_with_hash(); + assert!(s.leaves == 0); + assert!(s.roots.is_empty()); + + let hashes = [0, 1, 2, 3, 4, 5, 6, 7] + .iter() + .map(|&el| CustomHash([el; 32])) + .collect::>(); + + let (stump, _) = s + .modify( + &hashes, + &[], + &Proof::::new_with_hash(Vec::new(), Vec::new()), + ) + .unwrap(); + assert_eq!(stump.leaves, 8); + assert_eq!(stump.roots.len(), 1); + } + #[test] fn test_updated_data() { /// This test initializes a Stump, with some utxos. Then, we add a couple more utxos @@ -366,7 +472,7 @@ mod test { let roots = data .roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let stump = Stump { leaves: data.leaves, @@ -381,12 +487,12 @@ mod test { let del_hashes = data .del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let proof_hashes = data .proof_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let proof = Proof::new(data.proof_targets, proof_hashes); let (_, updated) = stump.modify(&utxos, &del_hashes, &proof).unwrap(); @@ -394,7 +500,7 @@ mod test { let new_add_hash: Vec<_> = data .new_add_hash .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let new_add: Vec<_> = data .new_add_pos @@ -405,7 +511,7 @@ mod test { let new_del_hash: Vec<_> = data .new_del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let new_del: Vec<_> = data .new_del_pos @@ -421,6 +527,7 @@ mod test { } } } + #[test] fn test_update_data_add() { let preimages = vec![0, 1, 2, 3]; @@ -442,7 +549,7 @@ mod test { "df46b17be5f66f0750a4b3efa26d4679db170a72d41eb56c3e4ff75a58c65386", ] .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let positions: Vec<_> = positions.into_iter().zip(hashes).collect(); @@ -462,6 +569,7 @@ mod test { serde_json::from_str(&serialized).expect("Deserialization failed"); assert_eq!(stump, deserialized); } + fn run_case_with_deletion(case: TestCase) { let leaf_hashes = case .leaf_preimages @@ -481,7 +589,9 @@ mod test { .proofhashes .unwrap_or_default() .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid")) + .map(|hash| { + BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hashes are valid") + }) .collect::>(); let proof = Proof::new(case.target_values.unwrap(), proof_hashes); @@ -489,8 +599,10 @@ mod test { let roots = case .expected_roots .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid")) - .collect::>(); + .map(|hash| { + BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hashes are valid") + }) + .collect::>(); let (stump, _) = Stump::new() .modify(&leaf_hashes, &[], &Proof::default()) @@ -519,6 +631,7 @@ mod test { assert_eq!(root, s.roots[i].to_string()); } } + #[test] fn test_undo() { let mut hashes = vec![]; @@ -555,7 +668,7 @@ mod test { fn test_serialize() { let hashes = [0, 1, 2, 3, 4, 5, 6, 7] .iter() - .map(|&el| NodeHash::from([el; 32])) + .map(|&el| BitcoinNodeHash::from([el; 32])) .collect::>(); let (stump, _) = Stump::new() .modify(&hashes, &[], &Proof::default()) @@ -567,6 +680,7 @@ mod test { let stump2 = Stump::deserialize(&mut reader).unwrap(); assert_eq!(stump, stump2); } + #[test] fn run_test_cases() { #[derive(Deserialize)] @@ -620,6 +734,7 @@ mod bench { let _ = Stump::new().modify(&hash, &[], &proof); }); } + #[bench] fn bench_add_del(bencher: &mut Bencher) { let leaf_preimages = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; diff --git a/src/accumulator/util.rs b/src/accumulator/util.rs index b9a90b9..8ae6e4c 100644 --- a/src/accumulator/util.rs +++ b/src/accumulator/util.rs @@ -89,7 +89,11 @@ pub fn left_sibling(position: u64) -> u64 { } // roots_to_destroy returns the empty roots that get written over after num_adds // amount of leaves have been added. -pub fn roots_to_destroy(num_adds: u64, mut num_leaves: u64, orig_roots: &[NodeHash]) -> Vec { +pub fn roots_to_destroy( + num_adds: u64, + mut num_leaves: u64, + orig_roots: &[Hash], +) -> Vec { let mut roots = orig_roots.to_vec(); let mut deleted = vec![]; let mut h = 0; @@ -233,7 +237,7 @@ pub fn parent_many(pos: u64, rise: u8, forest_rows: u8) -> Result { rise, forest_rows )); } - let mask = ((2_u64 << forest_rows) - 1) as u64; + let mask: u64 = (2_u64 << forest_rows) - 1_u64; Ok((pos >> rise | (mask << (forest_rows - (rise - 1)) as u64)) & mask) } @@ -304,24 +308,27 @@ pub fn get_proof_positions(targets: &[u64], num_leaves: u64, forest_rows: u8) -> proof_positions } + #[cfg(any(test, bench))] -pub fn hash_from_u8(value: u8) -> NodeHash { +pub fn hash_from_u8(value: u8) -> super::node_hash::BitcoinNodeHash { use bitcoin_hashes::sha256; use bitcoin_hashes::Hash; use bitcoin_hashes::HashEngine; + let mut engine = bitcoin_hashes::sha256::Hash::engine(); engine.input(&[value]); sha256::Hash::from_engine(engine).into() } + #[cfg(test)] mod tests { use std::str::FromStr; use std::vec; use super::roots_to_destroy; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::util::children; use crate::accumulator::util::tree_rows; @@ -338,6 +345,7 @@ mod tests { super::get_proof_positions(&sorted, num_leaves, num_rows) ); } + #[test] fn test_is_sibling() { assert!(super::is_sibling(0, 1)); @@ -345,6 +353,7 @@ mod tests { assert!(!super::is_sibling(1, 2)); assert!(!super::is_sibling(2, 1)); } + #[test] fn test_root_position() { let pos = super::root_position(5, 2, 3); @@ -353,10 +362,12 @@ mod tests { let pos = super::root_position(5, 0, 3); assert_eq!(pos, 4); } + #[test] fn test_is_right_sibling() { assert!(super::is_right_sibling(0, 1)); } + #[test] fn test_roots_to_destroy() { let roots = [ @@ -367,13 +378,14 @@ mod tests { ]; let roots = roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let deleted = roots_to_destroy(1, 15, &roots); assert_eq!(deleted, vec![22, 28]) } + #[test] fn test_remove_bit() { // This should remove just one bit from the final number @@ -388,6 +400,7 @@ mod tests { let res = super::remove_bit(14, 1); assert_eq!(res, 6); } + #[test] fn test_detwin() { // 14 @@ -413,12 +426,14 @@ mod tests { assert_eq!(super::tree_rows(12), 4); assert_eq!(super::tree_rows(255), 8); } + fn row_offset(row: u8, forest_rows: u8) -> u64 { // 2 << forestRows is 2 more than the max position // to get the correct offset for a given row, // subtract (2 << `row complement of forestRows`) from (2 << forestRows) (2 << forest_rows) - (2 << (forest_rows - row)) } + #[test] fn test_detect_row() { for forest_rows in 1..63 { @@ -437,8 +452,8 @@ mod tests { } } } - #[test] + #[test] fn test_get_proof_positions() { let targets: Vec = vec![4, 5, 7, 8]; let num_leaves = 8; @@ -447,11 +462,13 @@ mod tests { assert_eq!(vec![6, 9], targets); } + #[test] fn test_is_root_position() { let h = super::is_root_position(14, 8, 3); assert!(h); } + #[test] fn test_children_pos() { assert_eq!(children(4, 2), 0); @@ -459,6 +476,7 @@ mod tests { assert_eq!(children(50, 5), 36); assert_eq!(children(44, 5), 24); } + #[test] fn test_calc_next_pos() { let res = super::calc_next_pos(0, 1, 3);