diff --git a/Cargo.toml b/Cargo.toml index b7f2637..95c8044 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.81" +starknet-crypto = "0.7.2" [features] with-serde = ["serde"] @@ -27,3 +28,6 @@ name = "simple-stump-update" [[example]] name = "proof-update" + +[[example]] +name = "custom-hash-type" diff --git a/examples/custom-hash-type.rs b/examples/custom-hash-type.rs new file mode 100644 index 0000000..ceefaaa --- /dev/null +++ b/examples/custom-hash-type.rs @@ -0,0 +1,143 @@ +//! All data structures in this library are generic over the hash type used, defaulting to +//! [BitcoinNodeHash](crate::accumulator::node_hash::BitcoinNodeHash), the one used by Bitcoin +//! as defined by the utreexo spec. However, if you need to use a different hash type, you can +//! implement the [NodeHash](crate::accumulator::node_hash::NodeHash) trait for it, and use it +//! with the accumulator data structures. +//! +//! This example shows how to use a custom hash type based on the Poseidon hash function. The +//! [Poseidon Hash](https://eprint.iacr.org/2019/458.pdf) is a hash function that is optmized +//! for zero-knowledge proofs, and is used in projects like ZCash and StarkNet. +//! If you want to work with utreexo proofs in zero-knowledge you may want to use this instead +//! of our usual sha512-256 that we use by default, since that will give you smaller circuits. +//! This example shows how to use both the [Pollard](crate::accumulator::pollard::Pollard) and +//! proofs with a custom hash type. The code here should be pretty much all you need to do to +//! use your custom hashes, just tweak the implementation of +//! [NodeHash](crate::accumulator::node_hash::NodeHash) for your hash type. + +use rustreexo::accumulator::node_hash::AccumulatorHash; +use rustreexo::accumulator::pollard::Pollard; +use starknet_crypto::poseidon_hash_many; +use starknet_crypto::Felt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +/// We need a stateful wrapper around the actual hash, this is because we use those different +/// values inside our accumulator. Here we use an enum to represent the different states, you +/// may want to use a struct with more data, depending on your needs. +enum PoseidonHash { + /// This means this holds an actual value + /// + /// It usually represents a node in the accumulator that haven't been deleted. + Hash(Felt), + /// Placeholder is a value that haven't been deleted, but we don't have the actual value. + /// The only thing that matters about it is that it's not empty. You can implement this + /// the way you want, just make sure that [NodeHash::is_placeholder] and [NodeHash::placeholder] + /// returns sane values (that is, if we call [NodeHash::placeholder] calling [NodeHash::is_placeholder] + /// on the result should return true). + Placeholder, + /// This is an empty value, it represents a node that was deleted from the accumulator. + /// + /// Same as the placeholder, you can implement this the way you want, just make sure that + /// [NodeHash::is_empty] and [NodeHash::empty] returns sane values. + Empty, +} + +// you'll need to implement Display for your hash type, so you can print it. +impl std::fmt::Display for PoseidonHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PoseidonHash::Hash(h) => write!(f, "Hash({})", h), + PoseidonHash::Placeholder => write!(f, "Placeholder"), + PoseidonHash::Empty => write!(f, "Empty"), + } + } +} + +// this is the implementation of the NodeHash trait for our custom hash type. And it's the only +// thing you need to do to use your custom hash type with the accumulator data structures. +impl AccumulatorHash for PoseidonHash { + // returns a new placeholder type such that is_placeholder returns true + fn placeholder() -> Self { + PoseidonHash::Placeholder + } + + // returns an empty hash such that is_empty returns true + fn empty() -> Self { + PoseidonHash::Empty + } + + // returns true if this is a placeholder. This should be true iff this type was created by + // calling placeholder. + fn is_placeholder(&self) -> bool { + matches!(self, PoseidonHash::Placeholder) + } + + // returns true if this is an empty hash. This should be true iff this type was created by + // calling empty. + fn is_empty(&self) -> bool { + matches!(self, PoseidonHash::Empty) + } + + // used for serialization, writes the hash to the writer + // + // if you don't want to use serialization, you can just return an error here. + fn write(&self, writer: &mut W) -> std::io::Result<()> + where + W: std::io::Write, + { + match self { + PoseidonHash::Hash(h) => writer.write_all(&h.to_bytes_be()), + PoseidonHash::Placeholder => writer.write_all(&[0u8; 32]), + PoseidonHash::Empty => writer.write_all(&[0u8; 32]), + } + } + + // used for deserialization, reads the hash from the reader + // + // if you don't want to use serialization, you can just return an error here. + fn read(reader: &mut R) -> std::io::Result + where + R: std::io::Read, + { + let mut buf = [0u8; 32]; + reader.read_exact(&mut buf)?; + if buf.iter().all(|&b| b == 0) { + Ok(PoseidonHash::Empty) + } else { + Ok(PoseidonHash::Hash(Felt::from_bytes_be(&buf))) + } + } + + // the main thing about the hash type, it returns the next node's hash, given it's children. + // The implementation of this method is highly consensus critical, so everywhere should use the + // exact same algorithm to calculate the next hash. Rustreexo won't call this method, unless + // **both** children are not empty. + fn parent_hash(left: &Self, right: &Self) -> Self { + if let (PoseidonHash::Hash(left), PoseidonHash::Hash(right)) = (left, right) { + return PoseidonHash::Hash(poseidon_hash_many(&[*left, *right])); + } + + // This should never happen, since rustreexo won't call this method unless both children + // are not empty. + unreachable!() + } +} + +fn main() { + // Create a vector with two utxos that will be added to the Pollard + let elements = vec![ + PoseidonHash::Hash(Felt::from(1)), + PoseidonHash::Hash(Felt::from(2)), + ]; + + // Create a new Pollard, and add the utxos to it + let mut p = Pollard::::new_with_hash(); + p.modify(&elements, &[]).unwrap(); + + // Create a proof that the first utxo is in the Pollard + let proof = p.prove(&[elements[0]]).unwrap(); + + // check that the proof has exactly one target + assert_eq!(proof.n_targets(), 1); + // check that the proof is what we expect + assert!(p.verify(&proof, &[elements[0]]).unwrap()); +} diff --git a/examples/full-accumulator.rs b/examples/full-accumulator.rs index 7e67c1f..2d895d2 100644 --- a/examples/full-accumulator.rs +++ b/examples/full-accumulator.rs @@ -4,17 +4,21 @@ use std::str::FromStr; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::pollard::Pollard; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; fn main() { let elements = vec![ - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(), - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(), + BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(), + BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(), ]; // Create a new Pollard, and add the utxos to it let mut p = Pollard::new(); @@ -31,9 +35,10 @@ fn main() { // Now we want to update the Pollard, by removing the first utxo, and adding a new one. // This would be in case we received a new block with a transaction spending the first utxo, // and creating a new one. - let new_utxo = - NodeHash::from_str("cac74661f4944e6e1fed35df40da951c6e151e7b0c8d65c3ee37d6dfd3bc3ef7") - .unwrap(); + let new_utxo = BitcoinNodeHash::from_str( + "cac74661f4944e6e1fed35df40da951c6e151e7b0c8d65c3ee37d6dfd3bc3ef7", + ) + .unwrap(); p.modify(&[new_utxo], &[elements[0]]).unwrap(); // Now we can prove that the new utxo is in the Pollard. diff --git a/examples/proof-update.rs b/examples/proof-update.rs index da4fa6d..cb02cfd 100644 --- a/examples/proof-update.rs +++ b/examples/proof-update.rs @@ -12,7 +12,7 @@ use std::str::FromStr; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; @@ -36,7 +36,7 @@ fn main() { .update(vec![], utxos.clone(), vec![], vec![0, 1], update_data) .unwrap(); // This should be a valid proof over 0 and 1. - assert_eq!(p.targets(), 2); + assert_eq!(p.n_targets(), 2); assert_eq!(s.verify(&p, &cached_hashes), Ok(true)); // Get a subset of the proof, for the first UTXO only @@ -65,7 +65,7 @@ fn main() { /// Returns the hashes for UTXOs in the first block in this fictitious example, there's nothing /// special about them, they are just the first 8 integers hashed as u8s. -fn get_utxo_hashes1() -> Vec { +fn get_utxo_hashes1() -> Vec { let hashes = [ "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d", "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", @@ -78,11 +78,11 @@ fn get_utxo_hashes1() -> Vec { ]; hashes .iter() - .map(|h| NodeHash::from_str(h).unwrap()) + .map(|h| BitcoinNodeHash::from_str(h).unwrap()) .collect() } /// Returns the hashes for UTXOs in the second block. -fn get_utxo_hashes2() -> Vec { +fn get_utxo_hashes2() -> Vec { let utxo_hashes = [ "bf4aff60ee0f3b2d82b47b94f6eff3018d1a47d1b0bc5dfbf8d3a95a2836bf5b", "2e6adf10ab3174629fc388772373848bbe277ffee1f72568e6d06e823b39d2dd", @@ -91,6 +91,6 @@ fn get_utxo_hashes2() -> Vec { ]; utxo_hashes .iter() - .map(|h| NodeHash::from_str(h).unwrap()) + .map(|h| BitcoinNodeHash::from_str(h).unwrap()) .collect() } diff --git a/examples/simple-stump-update.rs b/examples/simple-stump-update.rs index 9017724..139b1b9 100644 --- a/examples/simple-stump-update.rs +++ b/examples/simple-stump-update.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use std::vec; -use rustreexo::accumulator::node_hash::NodeHash; +use rustreexo::accumulator::node_hash::BitcoinNodeHash; use rustreexo::accumulator::proof::Proof; use rustreexo::accumulator::stump::Stump; @@ -15,10 +15,14 @@ fn main() { // If we assume this is the very first block, then the Stump is empty, and we can just add // the utxos to it. Assuming a coinbase with two outputs, we would have the following utxos: let utxos = vec![ - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(), - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(), + BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(), + BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(), ]; // Create a new Stump, and add the utxos to it. Notice how we don't use the full return here, // but only the Stump. To understand what is the second return value, see the documentation @@ -34,9 +38,10 @@ fn main() { // Now we want to update the Stump, by removing the first utxo, and adding a new one. // This would be in case we received a new block with a transaction spending the first utxo, // and creating a new one. - let new_utxo = - NodeHash::from_str("d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa") - .unwrap(); + let new_utxo = BitcoinNodeHash::from_str( + "d3bd63d53c5a70050a28612a2f4b2019f40951a653ae70736d93745efb1124fa", + ) + .unwrap(); let s = s.modify(&[new_utxo], &[utxos[0]], &proof).unwrap().0; // Now we can verify that the new utxo is in the Stump, and the old one is not. let new_proof = Proof::new(vec![2], vec![new_utxo]); diff --git a/src/accumulator/node_hash.rs b/src/accumulator/node_hash.rs index edb5988..bb5b440 100644 --- a/src/accumulator/node_hash.rs +++ b/src/accumulator/node_hash.rs @@ -1,14 +1,15 @@ -//! [NodeHash] is an internal type for representing Hashes in an utreexo accumulator. It's +//! [AccumulatorHash] is an internal type for representing Hashes in an utreexo accumulator. It's //! just a wrapper around [[u8; 32]] but with some useful methods. //! # Examples //! Building from a str //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; -//! let hash = -//! NodeHash::from_str("0000000000000000000000000000000000000000000000000000000000000000") -//! .unwrap(); +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! let hash = BitcoinNodeHash::from_str( +//! "0000000000000000000000000000000000000000000000000000000000000000", +//! ) +//! .unwrap(); //! assert_eq!( //! hash.to_string().as_str(), //! "0000000000000000000000000000000000000000000000000000000000000000" @@ -18,10 +19,10 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; -//! let hash1 = NodeHash::new([0; 32]); +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! let hash1 = BitcoinNodeHash::new([0; 32]); //! // ... or ... -//! let hash2 = NodeHash::from([0; 32]); +//! let hash2 = BitcoinNodeHash::from([0; 32]); //! assert_eq!(hash1, hash2); //! assert_eq!( //! hash1.to_string().as_str(), @@ -33,13 +34,15 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; -//! let left = NodeHash::new([0; 32]); -//! let right = NodeHash::new([1; 32]); -//! let parent = NodeHash::parent_hash(&left, &right); -//! let expected_parent = -//! NodeHash::from_str("34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12") -//! .unwrap(); +//! use rustreexo::accumulator::node_hash::AccumulatorHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; +//! let left = BitcoinNodeHash::new([0; 32]); +//! let right = BitcoinNodeHash::new([1; 32]); +//! let parent = BitcoinNodeHash::parent_hash(&left, &right); +//! let expected_parent = BitcoinNodeHash::from_str( +//! "34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12", +//! ) +//! .unwrap(); //! assert_eq!(parent, expected_parent); //! ``` use std::convert::TryFrom; @@ -58,39 +61,59 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; +pub trait AccumulatorHash: + Copy + Clone + Ord + Debug + Display + std::hash::Hash + 'static +{ + fn is_empty(&self) -> bool; + fn empty() -> Self; + fn is_placeholder(&self) -> bool; + fn placeholder() -> Self; + fn parent_hash(left: &Self, right: &Self) -> Self; + fn write(&self, writer: &mut W) -> std::io::Result<()> + where + W: std::io::Write; + fn read(reader: &mut R) -> std::io::Result + where + R: std::io::Read; +} + #[derive(Eq, PartialEq, Copy, Clone, Hash, PartialOrd, Ord)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] -/// NodeHash is a wrapper around a 32 byte array that represents a hash of a node in the tree. +/// AccumulatorHash is a wrapper around a 32 byte array that represents a hash of a node in the tree. /// # Example /// ``` -/// use rustreexo::accumulator::node_hash::NodeHash; -/// let hash = NodeHash::new([0; 32]); +/// use rustreexo::accumulator::node_hash::BitcoinNodeHash; +/// let hash = BitcoinNodeHash::new([0; 32]); /// assert_eq!( /// hash.to_string().as_str(), /// "0000000000000000000000000000000000000000000000000000000000000000" /// ); /// ``` #[derive(Default)] -pub enum NodeHash { +pub enum BitcoinNodeHash { #[default] Empty, Placeholder, Some([u8; 32]), } -impl Deref for NodeHash { +#[deprecated(since = "0.4.0", note = "Please use BitcoinNodeHash instead.")] +pub type NodeHash = BitcoinNodeHash; + +impl Deref for BitcoinNodeHash { type Target = [u8; 32]; fn deref(&self) -> &Self::Target { match self { - NodeHash::Some(ref inner) => inner, + BitcoinNodeHash::Some(ref inner) => inner, _ => &[0; 32], } } } -impl Display for NodeHash { + +impl Display for BitcoinNodeHash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - if let NodeHash::Some(ref inner) = self { + if let BitcoinNodeHash::Some(ref inner) = self { let mut s = String::new(); for byte in inner.iter() { s.push_str(&format!("{:02x}", byte)); @@ -101,137 +124,158 @@ impl Display for NodeHash { } } } -impl Debug for NodeHash { + +impl Debug for BitcoinNodeHash { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { - if let NodeHash::Some(ref inner) = self { - let mut s = String::new(); - for byte in inner.iter() { - s.push_str(&format!("{:02x}", byte)); + match self { + BitcoinNodeHash::Empty => write!(f, "empty"), + BitcoinNodeHash::Placeholder => write!(f, "placeholder"), + BitcoinNodeHash::Some(ref inner) => { + let mut s = String::new(); + for byte in inner.iter() { + s.push_str(&format!("{:02x}", byte)); + } + write!(f, "{}", s) } - write!(f, "{}", s) - } else { - write!(f, "empty") } } } -impl From for NodeHash { + +impl From for BitcoinNodeHash { fn from(hash: sha512_256::Hash) -> Self { - NodeHash::Some(hash.to_byte_array()) + BitcoinNodeHash::Some(hash.to_byte_array()) } } -impl From<[u8; 32]> for NodeHash { + +impl From<[u8; 32]> for BitcoinNodeHash { fn from(hash: [u8; 32]) -> Self { - NodeHash::Some(hash) + BitcoinNodeHash::Some(hash) } } -impl From<&[u8; 32]> for NodeHash { + +impl From<&[u8; 32]> for BitcoinNodeHash { fn from(hash: &[u8; 32]) -> Self { - NodeHash::Some(*hash) + BitcoinNodeHash::Some(*hash) } } + #[cfg(test)] -impl TryFrom<&str> for NodeHash { +impl TryFrom<&str> for BitcoinNodeHash { type Error = hex::HexToArrayError; fn try_from(hash: &str) -> Result { // This implementation is useful for testing, as it allows to create empty hashes // from the string of 64 zeros. Without this, it would be impossible to express this // hash in the test vectors. if hash == "0000000000000000000000000000000000000000000000000000000000000000" { - return Ok(NodeHash::Empty); + return Ok(BitcoinNodeHash::Empty); } + let hash = hex::FromHex::from_hex(hash)?; - Ok(NodeHash::Some(hash)) + Ok(BitcoinNodeHash::Some(hash)) } } #[cfg(not(test))] -impl TryFrom<&str> for NodeHash { +impl TryFrom<&str> for BitcoinNodeHash { type Error = hex::HexToArrayError; fn try_from(hash: &str) -> Result { let inner = hex::FromHex::from_hex(hash)?; - Ok(NodeHash::Some(inner)) + Ok(BitcoinNodeHash::Some(inner)) } } -impl From<&[u8]> for NodeHash { + +impl From<&[u8]> for BitcoinNodeHash { fn from(hash: &[u8]) -> Self { let mut inner = [0; 32]; inner.copy_from_slice(hash); - NodeHash::Some(inner) + BitcoinNodeHash::Some(inner) } } -impl From for NodeHash { +impl From for BitcoinNodeHash { fn from(hash: sha256::Hash) -> Self { - NodeHash::Some(hash.to_byte_array()) + BitcoinNodeHash::Some(hash.to_byte_array()) } } -impl FromStr for NodeHash { + +impl FromStr for BitcoinNodeHash { fn from_str(s: &str) -> Result { - NodeHash::try_from(s) + BitcoinNodeHash::try_from(s) } + type Err = hex::HexToArrayError; } -impl NodeHash { - /// Tells whether this hash is empty. We use empty hashes throughout the code to represent - /// leaves we want to delete. - pub fn is_empty(&self) -> bool { - if let NodeHash::Empty = self { - return true; - } - false - } - /// Creates a new NodeHash from a 32 byte array. + +impl BitcoinNodeHash { + /// Creates a new AccumulatorHash from a 32 byte array. /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; - /// let hash = NodeHash::new([0; 32]); + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let hash = BitcoinNodeHash::new([0; 32]); /// assert_eq!( /// hash.to_string().as_str(), /// "0000000000000000000000000000000000000000000000000000000000000000" /// ); /// ``` pub fn new(inner: [u8; 32]) -> Self { - NodeHash::Some(inner) + BitcoinNodeHash::Some(inner) } +} + +impl AccumulatorHash for BitcoinNodeHash { + /// Tells whether this hash is empty. We use empty hashes throughout the code to represent + /// leaves we want to delete. + fn is_empty(&self) -> bool { + matches!(self, BitcoinNodeHash::Empty) + } + /// Creates an empty hash. This is used to represent leaves we want to delete. /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; - /// let hash = NodeHash::empty(); + /// use rustreexo::accumulator::node_hash::AccumulatorHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let hash = BitcoinNodeHash::empty(); /// assert!(hash.is_empty()); /// ``` - pub fn empty() -> Self { - NodeHash::Empty + fn empty() -> Self { + BitcoinNodeHash::Empty } + /// parent_hash return the merkle parent of the two passed in nodes. /// # Example /// ``` /// use std::str::FromStr; /// - /// use rustreexo::accumulator::node_hash::NodeHash; - /// let left = NodeHash::new([0; 32]); - /// let right = NodeHash::new([1; 32]); - /// let parent = NodeHash::parent_hash(&left, &right); - /// let expected_parent = - /// NodeHash::from_str("34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12") - /// .unwrap(); + /// use rustreexo::accumulator::node_hash::AccumulatorHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// let left = BitcoinNodeHash::new([0; 32]); + /// let right = BitcoinNodeHash::new([1; 32]); + /// let parent = BitcoinNodeHash::parent_hash(&left, &right); + /// let expected_parent = BitcoinNodeHash::from_str( + /// "34e33ca0c40b7bd33d28932ca9e35170def7309a3bf91ecda5e1ceb067548a12", + /// ) + /// .unwrap(); /// assert_eq!(parent, expected_parent); /// ``` - pub fn parent_hash(left: &NodeHash, right: &NodeHash) -> NodeHash { + fn parent_hash(left: &Self, right: &Self) -> Self { let mut hash = sha512_256::Hash::engine(); hash.input(&**left); hash.input(&**right); sha512_256::Hash::from_engine(hash).into() } + fn is_placeholder(&self) -> bool { + matches!(self, BitcoinNodeHash::Placeholder) + } + /// Returns a arbitrary placeholder hash that is unlikely to collide with any other hash. /// We use this while computing roots to destroy. Don't confuse this with an empty hash. - pub const fn placeholder() -> Self { - NodeHash::Placeholder + fn placeholder() -> Self { + BitcoinNodeHash::Placeholder } /// write to buffer - pub(super) fn write(&self, writer: &mut W) -> std::io::Result<()> + fn write(&self, writer: &mut W) -> std::io::Result<()> where W: std::io::Write, { @@ -246,7 +290,7 @@ impl NodeHash { } /// Read from buffer - pub(super) fn read(reader: &mut R) -> std::io::Result + fn read(reader: &mut R) -> std::io::Result where R: std::io::Read, { @@ -263,7 +307,7 @@ impl NodeHash { [_] => { let err = std::io::Error::new( std::io::ErrorKind::InvalidData, - "unexpected tag for NodeHash", + "unexpected tag for AccumulatorHash", ); Err(err) } @@ -275,7 +319,8 @@ impl NodeHash { mod test { use std::str::FromStr; - use super::NodeHash; + use super::AccumulatorHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::util::hash_from_u8; #[test] @@ -283,7 +328,7 @@ mod test { let hash1 = hash_from_u8(0); let hash2 = hash_from_u8(1); - let parent_hash = NodeHash::parent_hash(&hash1, &hash2); + let parent_hash = BitcoinNodeHash::parent_hash(&hash1, &hash2); assert_eq!( parent_hash.to_string().as_str(), "02242b37d8e851f1e86f46790298c7097df06893d6226b7c1453c213e91717de" @@ -291,17 +336,19 @@ mod test { } #[test] fn test_hash_from_str() { - let hash = - NodeHash::from_str("6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d") - .unwrap(); + let hash = BitcoinNodeHash::from_str( + "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d", + ) + .unwrap(); assert_eq!(hash, hash_from_u8(0)); } #[test] fn test_empty_hash() { // Only relevant for tests - let hash = - NodeHash::from_str("0000000000000000000000000000000000000000000000000000000000000000") - .unwrap(); - assert_eq!(hash, NodeHash::empty()); + let hash = BitcoinNodeHash::from_str( + "0000000000000000000000000000000000000000000000000000000000000000", + ) + .unwrap(); + assert_eq!(hash, AccumulatorHash::empty()); } } diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 6bc2af0..d7a522f 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -4,15 +4,17 @@ //! //! # Example //! ``` -//! use rustreexo::accumulator::node_hash::NodeHash; +//! use rustreexo::accumulator::node_hash::AccumulatorHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::pollard::Pollard; +//! //! let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; -//! let hashes: Vec = values +//! let hashes: Vec = values //! .into_iter() -//! .map(|i| NodeHash::from([i; 32])) +//! .map(|i| BitcoinNodeHash::from([i; 32])) //! .collect(); //! -//! let mut p = Pollard::new(); +//! let mut p = Pollard::::new(); //! //! p.modify(&hashes, &[]).expect("Pollard should not fail"); //! assert_eq!(p.get_roots().len(), 1); @@ -20,10 +22,9 @@ //! p.modify(&[], &hashes).expect("Still should not fail"); // Remove leaves from the accumulator //! //! assert_eq!(p.get_roots().len(), 1); -//! assert_eq!(p.get_roots()[0].get_data(), NodeHash::default()); +//! assert_eq!(p.get_roots()[0].get_data(), BitcoinNodeHash::default()); //! ``` -use core::fmt; use std::cell::Cell; use std::cell::RefCell; use std::collections::HashMap; @@ -35,7 +36,8 @@ use std::io::Write; use std::rc::Rc; use std::rc::Weak; -use super::node_hash::NodeHash; +use super::node_hash::AccumulatorHash; +use super::node_hash::BitcoinNodeHash; use super::proof::Proof; use super::util::detect_offset; use super::util::get_proof_positions; @@ -46,42 +48,39 @@ use super::util::max_position_at_row; use super::util::right_child; use super::util::root_position; use super::util::tree_rows; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum NodeType { Branch, Leaf, } +// A few type aliases to improve readability. + +/// A weak reference to a node in the forest. We use this when referencing to a node's +/// parent, as we don't want to create a reference cycle. +type Parent = RefCell>>; +/// A reference to a node's children. We use this to store the left and right children +/// of a node. This will be the only long-lived strong reference to a node. If this gets +/// dropped, the node will be dropped as well. +type Children = RefCell>>; + /// A forest node that can either be a leaf or a branch. #[derive(Clone)] -pub struct Node { +pub struct Node { /// The type of this node. ty: NodeType, /// The hash of the stored in this node. - data: Cell, + data: Cell, /// The parent of this node, if any. - parent: RefCell>>, + parent: Parent>, /// The left and right children of this node, if any. - left: RefCell>>, + left: Children>, /// The left and right children of this node, if any. - right: RefCell>>, + right: Children>, } -impl Node { - /// Recomputes the hash of all nodes, up to the root. - fn recompute_hashes(&self) { - let left = self.left.borrow(); - let right = self.right.borrow(); - if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { - self.data - .replace(NodeHash::parent_hash(&left.data.get(), &right.data.get())); - } - if let Some(ref parent) = *self.parent.borrow() { - if let Some(p) = parent.upgrade() { - p.recompute_hashes(); - } - } - } +impl Node { /// Writes one node to the writer, this method will recursively write all children. /// The primary use of this method is to serialize the accumulator. In this case, /// you should call this method on each root in the forest. @@ -104,6 +103,22 @@ impl Node { .transpose()?; Ok(()) } + /// Recomputes the hash of all nodes, up to the root. + fn recompute_hashes(&self) { + let left = self.left.borrow(); + let right = self.right.borrow(); + + if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { + self.data + .replace(Hash::parent_hash(&left.data.get(), &right.data.get())); + } + if let Some(ref parent) = *self.parent.borrow() { + if let Some(p) = parent.upgrade() { + p.recompute_hashes(); + } + } + } + /// Reads one node from the reader, this method will recursively read all children. /// The primary use of this method is to deserialize the accumulator. In this case, /// you should call this method on each root in the forest, assuming you know how @@ -111,15 +126,15 @@ impl Node { #[allow(clippy::type_complexity)] pub fn read_one( reader: &mut R, - ) -> std::io::Result<(Rc, HashMap>)> { - fn _read_one( - ancestor: Option>, + ) -> std::io::Result<(Rc>, HashMap>>)> { + fn _read_one( + ancestor: Option>>, reader: &mut R, - index: &mut HashMap>, - ) -> std::io::Result> { + index: &mut HashMap>>, + ) -> std::io::Result>> { let mut ty = [0u8; 8]; reader.read_exact(&mut ty)?; - let data = NodeHash::read(reader)?; + let data = Hash::read(reader)?; let ty = match u64::from_le_bytes(ty) { 0 => NodeType::Branch, @@ -165,33 +180,33 @@ impl Node { let root = _read_one(None, reader, &mut index)?; Ok((root, index)) } + /// Returns the data associated with this node. - pub fn get_data(&self) -> NodeHash { + pub fn get_data(&self) -> Hash { self.data.get() } } -impl Debug for Node { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{:02x}{:02x}", self.data.get()[0], self.data.get()[1]) - } -} /// The actual Pollard accumulator, it implements all methods required to update the forest /// and to prove/verify membership. #[derive(Default, Clone)] -pub struct Pollard { +pub struct Pollard { /// The roots of the forest, all leaves are children of these roots, and therefore /// owned by them. - roots: Vec>, + roots: Vec>>, /// The number of leaves in the forest. Actually, this is the number of leaves we ever /// added to the forest. pub leaves: u64, /// A map of all nodes in the forest, indexed by their hash, this is used to lookup /// leaves when proving membership. - map: HashMap>, + map: HashMap>>, } + impl Pollard { - /// Creates a new empty [Pollard]. + /// Creates a new empty [Pollard] with the default hash function. + /// + /// This will create an empty Pollard, using [BitcoinNodeHash] as the hash function. If you + /// want to use a different hash function, you can use [Pollard::new_with_hash]. /// # Example /// ``` /// use rustreexo::accumulator::pollard::Pollard; @@ -204,13 +219,32 @@ impl Pollard { leaves: 0, } } +} + +impl Pollard { + /// Creates a new empty [Pollard] with a custom hash function. + /// # Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::pollard::Pollard; + /// let mut pollard = Pollard::::new(); + /// ``` + pub fn new_with_hash() -> Pollard { + Pollard { + map: HashMap::new(), + roots: Vec::new(), + leaves: 0, + } + } + /// Writes the Pollard to a writer. Used to send the accumulator over the wire /// or to disk. /// # Example /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// - /// let mut pollard = Pollard::new(); + /// let mut pollard = Pollard::::new(); /// let mut serialized = Vec::new(); /// pollard.serialize(&mut serialized).unwrap(); /// @@ -229,18 +263,20 @@ impl Pollard { Ok(()) } + /// Deserializes a pollard from a reader. /// # Example /// ``` /// use std::io::Cursor; /// + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// let mut serialized = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - /// let pollard = Pollard::deserialize(&mut serialized).unwrap(); + /// let pollard = Pollard::::deserialize(&mut serialized).unwrap(); /// assert_eq!(pollard.leaves, 0); /// assert_eq!(pollard.get_roots().len(), 0); /// ``` - pub fn deserialize(mut reader: R) -> std::io::Result { + pub fn deserialize(mut reader: R) -> std::io::Result> { fn read_u64(reader: &mut R) -> std::io::Result { let mut buf = [0u8; 8]; reader.read_exact(&mut buf)?; @@ -257,28 +293,30 @@ impl Pollard { } Ok(Pollard { roots, leaves, map }) } + /// Returns the hash of a given position in the tree. - fn get_hash(&self, pos: u64) -> Result { + fn get_hash(&self, pos: u64) -> Result { let (node, _, _) = self.grab_node(pos)?; Ok(node.data.get()) } + /// Proves that a given set of hashes is in the accumulator. It returns a proof /// and the hashes that we what to prove, but sorted by position in the tree. /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; - /// let mut pollard = Pollard::new(); + /// let mut pollard = Pollard::::new(); /// let hashes = vec![0, 1, 2, 3, 4, 5, 6, 7] /// .iter() - /// .map(|n| NodeHash::from([*n; 32])) + /// .map(|n| BitcoinNodeHash::from([*n; 32])) /// .collect::>(); /// pollard.modify(&hashes, &[]).unwrap(); /// // We want to prove that the first two hashes are in the accumulator. /// let proof = pollard.prove(&[hashes[1], hashes[0]]).unwrap(); /// //TODO: Verify the proof /// ``` - pub fn prove(&self, targets: &[NodeHash]) -> Result { + pub fn prove(&self, targets: &[Hash]) -> Result, String> { let mut positions = Vec::new(); for target in targets { let node = self.map.get(target).ok_or("Could not find node")?; @@ -290,12 +328,15 @@ impl Pollard { .iter() .map(|pos| self.get_hash(*pos).unwrap()) .collect::>(); - Ok(Proof::new(positions, proof)) + + Ok(Proof::new_with_hash(positions, proof)) } + /// Returns a reference to the roots in this Pollard. - pub fn get_roots(&self) -> &[Rc] { + pub fn get_roots(&self) -> &[Rc>] { &self.roots } + /// Modify is the main API to a [Pollard]. Because order matters, you can only `modify` /// a [Pollard], and internally it'll add and delete, in the correct order. /// @@ -308,7 +349,7 @@ impl Pollard { /// use bitcoin_hashes::sha256::Hash as Data; /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::pollard::Pollard; /// let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; /// let hashes = values @@ -316,11 +357,11 @@ impl Pollard { /// .map(|val| { /// let mut engine = Data::engine(); /// engine.input(&[val]); - /// NodeHash::from(Data::from_engine(engine).as_byte_array()) + /// BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) /// }) /// .collect::>(); /// // Add 8 leaves to the pollard - /// let mut p = Pollard::new(); + /// let mut p = Pollard::::new(); /// p.modify(&hashes, &[]).expect("Pollard should not fail"); /// /// assert_eq!( @@ -328,13 +369,17 @@ impl Pollard { /// String::from("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") /// ); /// ``` - pub fn modify(&mut self, add: &[NodeHash], del: &[NodeHash]) -> Result<(), String> { + pub fn modify(&mut self, add: &[Hash], del: &[Hash]) -> Result<(), String> { self.del(del)?; self.add(add); Ok(()) } + #[allow(clippy::type_complexity)] - pub fn grab_node(&self, pos: u64) -> Result<(Rc, Rc, Rc), String> { + pub fn grab_node( + &self, + pos: u64, + ) -> Result<(Rc>, Rc>, Rc>), String> { let (tree, branch_len, bits) = detect_offset(pos, self.leaves); let mut n = Some(self.roots[tree as usize].clone()); let mut sibling = Some(self.roots[tree as usize].clone()); @@ -366,7 +411,8 @@ impl Pollard { } Err(format!("node {} not found", pos)) } - fn del(&mut self, targets: &[NodeHash]) -> Result<(), String> { + + fn del(&mut self, targets: &[Hash]) -> Result<(), String> { let mut pos = targets .iter() .flat_map(|target| self.map.get(target)) @@ -380,7 +426,7 @@ impl Pollard { .collect::>(); pos.sort(); - let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); + let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); for target in targets { match self.map.remove(&target) { Some(target) => { @@ -393,7 +439,8 @@ impl Pollard { } Ok(()) } - pub fn verify(&self, proof: &Proof, del_hashes: &[NodeHash]) -> Result { + + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { let roots = self .roots .iter() @@ -401,7 +448,8 @@ impl Pollard { .collect::>(); proof.verify(del_hashes, &roots, self.leaves) } - fn get_pos(&self, node: &Weak) -> u64 { + + fn get_pos(&self, node: &Weak>) -> u64 { // This indicates whether the node is a left or right child at each level // When we go down the tree, we can use the indicator to know which // child to take. @@ -458,7 +506,8 @@ impl Pollard { } pos } - fn del_single(&mut self, node: &Node) -> Option<()> { + + fn del_single(&mut self, node: &Node) -> Option<()> { let parent = node.parent.borrow(); // Deleting a root let parent = match *parent { @@ -468,7 +517,7 @@ impl Pollard { self.roots[pos] = Rc::new(Node { ty: NodeType::Branch, parent: RefCell::new(None), - data: Cell::new(NodeHash::default()), + data: Cell::new(Hash::empty()), left: RefCell::new(None), right: RefCell::new(None), }); @@ -506,8 +555,9 @@ impl Pollard { Some(()) } - fn add_single(&mut self, value: NodeHash) { - let mut node: Rc = Rc::new(Node { + + fn add_single(&mut self, value: Hash) { + let mut node: Rc> = Rc::new(Node { ty: NodeType::Leaf, parent: RefCell::new(None), data: Cell::new(value), @@ -518,14 +568,17 @@ impl Pollard { let mut leaves = self.leaves; while leaves & 1 != 0 { let root = self.roots.pop().unwrap(); - if root.get_data() == NodeHash::empty() { + if root.get_data() == AccumulatorHash::empty() { leaves >>= 1; continue; } let new_node = Rc::new(Node { ty: NodeType::Branch, parent: RefCell::new(None), - data: Cell::new(NodeHash::parent_hash(&root.data.get(), &node.data.get())), + data: Cell::new(AccumulatorHash::parent_hash( + &root.data.get(), + &node.data.get(), + )), left: RefCell::new(Some(root.clone())), right: RefCell::new(Some(node.clone())), }); @@ -538,11 +591,13 @@ impl Pollard { self.roots.push(node); self.leaves += 1; } - fn add(&mut self, values: &[NodeHash]) { + + fn add(&mut self, values: &[Hash]) { for value in values { self.add_single(*value); } } + /// to_string returns the full pollard in a string for all forests less than 6 rows. fn string(&self) -> String { if self.leaves == 0 { @@ -615,11 +670,13 @@ impl Debug for Pollard { write!(f, "{}", self.string()) } } + impl Display for Pollard { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { write!(f, "{}", self.string()) } } + #[cfg(test)] mod test { use std::convert::TryFrom; @@ -633,17 +690,19 @@ mod test { use serde::Deserialize; use super::Pollard; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::AccumulatorHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::pollard::Node; use crate::accumulator::proof::Proof; - fn hash_from_u8(value: u8) -> NodeHash { + fn hash_from_u8(value: u8) -> BitcoinNodeHash { let mut engine = Data::engine(); engine.input(&[value]); - NodeHash::from(Data::from_engine(engine).as_byte_array()) + BitcoinNodeHash::from(Data::from_engine(engine).as_byte_array()) } + #[test] fn test_grab_node() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; @@ -652,12 +711,14 @@ mod test { let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); let (found_target, found_sibling, _) = p.grab_node(4).unwrap(); - let target = - NodeHash::try_from("e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71") - .unwrap(); - let sibling = - NodeHash::try_from("e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db") - .unwrap(); + let target = BitcoinNodeHash::try_from( + "e52d9c508c502347344d8c07ad91cbd6068afc75ff6292f062a09ca381c89e71", + ) + .unwrap(); + let sibling = BitcoinNodeHash::try_from( + "e77b9a9ae9e30b0dbdb6f510a264ef9de781501d7b6b92ae89eb059c5ab743db", + ) + .unwrap(); assert_eq!(target, found_target.data.get()); assert_eq!(sibling, found_sibling.data.get()); @@ -678,6 +739,7 @@ mod test { node.data.get().to_string() ); } + #[test] fn test_proof_verify() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; @@ -688,6 +750,7 @@ mod test { let proof = p.prove(&[hashes[0], hashes[1]]).unwrap(); assert!(p.verify(&proof, &[hashes[0], hashes[1]]).unwrap()); } + #[test] fn test_add() { let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; @@ -713,6 +776,7 @@ mod test { acc.roots[3].data.get().to_string().as_str(), ); } + #[test] fn test_delete_roots_child() { // Assuming the following tree: @@ -722,7 +786,7 @@ mod test { // 00 01 // If I delete `01`, then `00` will become a root, moving it's hash to `02` let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); @@ -743,15 +807,16 @@ mod test { // If I delete `02`, then `02` will become an empty root, it'll point to nothing // and its data will be Data::default() let values = vec![0, 1]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); p.del_single(&p.grab_node(2).unwrap().0); assert_eq!(p.get_roots().len(), 1); let root = p.get_roots()[0].clone(); - assert_eq!(root.data.get(), NodeHash::default()); + assert_eq!(root.data.get(), BitcoinNodeHash::default()); } + #[test] fn test_delete_non_root() { // Assuming this tree, if we delete `01`, 00 will move up to 08's position @@ -774,7 +839,7 @@ mod test { // Where 08's data is just 00's let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; - let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); + let hashes: Vec = values.into_iter().map(hash_from_u8).collect(); let mut p = Pollard::new(); p.modify(&hashes, &[]).expect("Pollard should not fail"); @@ -784,12 +849,14 @@ mod test { let (node, _, _) = p.grab_node(8).expect("This tree should have pos 8"); assert_eq!(node.data.get(), hashes[0]); } + #[derive(Debug, Deserialize)] struct TestCase { leaf_preimages: Vec, target_values: Option>, expected_roots: Vec, } + fn run_single_addition_case(case: TestCase) { let hashes = case .leaf_preimages @@ -802,7 +869,7 @@ mod test { let expected_roots = case .expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); let roots = p .get_roots() @@ -811,6 +878,7 @@ mod test { .collect::>(); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } + fn run_case_with_deletion(case: TestCase) { let hashes = case .leaf_preimages @@ -832,7 +900,7 @@ mod test { let expected_roots = case .expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect::>(); let roots = p .get_roots() @@ -841,6 +909,7 @@ mod test { .collect::>(); assert_eq!(expected_roots, roots, "Test case failed {:?}", case); } + #[test] fn run_tests_from_cases() { #[derive(Deserialize)] @@ -862,6 +931,7 @@ mod test { run_case_with_deletion(i); } } + #[test] fn test_to_string() { let hashes = get_hash_vec_of(&(0..255).collect::>()); @@ -872,6 +942,7 @@ mod test { p.to_string().get(0..30) ); } + #[test] fn test_get_pos() { macro_rules! test_get_pos { @@ -913,6 +984,7 @@ mod test { 25 ); } + #[test] fn test_serialize_one() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); @@ -922,9 +994,11 @@ mod test { let mut writer = std::io::Cursor::new(Vec::new()); p.get_roots()[0].write_one(&mut writer).unwrap(); let (deserialized, _) = - Node::read_one(&mut std::io::Cursor::new(writer.into_inner())).unwrap(); + Node::::read_one(&mut std::io::Cursor::new(writer.into_inner())) + .unwrap(); assert_eq!(deserialized.get_data(), p.get_roots()[0].get_data()); } + #[test] fn test_serialization() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); @@ -934,7 +1008,8 @@ mod test { let mut writer = std::io::Cursor::new(Vec::new()); p.serialize(&mut writer).unwrap(); let deserialized = - Pollard::deserialize(&mut std::io::Cursor::new(writer.into_inner())).unwrap(); + Pollard::::deserialize(&mut std::io::Cursor::new(writer.into_inner())) + .unwrap(); assert_eq!( deserialized.get_roots()[0].get_data(), p.get_roots()[0].get_data() @@ -942,6 +1017,7 @@ mod test { assert_eq!(deserialized.leaves, p.leaves); assert_eq!(deserialized.map.len(), p.map.len()); } + #[test] fn test_proof() { let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7]); @@ -972,7 +1048,8 @@ mod test { assert_eq!(proof, expected_proof); assert!(p.verify(&proof, &del_hashes).unwrap()); } - fn get_hash_vec_of(elements: &[u8]) -> Vec { + + fn get_hash_vec_of(elements: &[u8]) -> Vec { elements.iter().map(|el| hash_from_u8(*el)).collect() } @@ -984,11 +1061,11 @@ mod test { #[test] fn test_serialization_roundtrip() { - let mut p = Pollard::new(); + let mut p = Pollard::::new(); let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let hashes: Vec = values + let hashes: Vec = values .into_iter() - .map(|i| NodeHash::from([i; 32])) + .map(|i| BitcoinNodeHash::from([i; 32])) .collect(); p.modify(&hashes, &[]).expect("modify should work"); assert_eq!(p.get_roots().len(), 1); @@ -1000,7 +1077,8 @@ mod test { assert_eq!(p.leaves, 16); let mut serialized = Vec::::new(); p.serialize(&mut serialized).expect("serialize should work"); - let deserialized = Pollard::deserialize(&*serialized).expect("deserialize should work"); + let deserialized = + Pollard::::deserialize(&*serialized).expect("deserialize should work"); assert_eq!(deserialized.get_roots().len(), 1); assert!(deserialized.get_roots()[0].get_data().is_empty()); assert_eq!(deserialized.leaves, 16); diff --git a/src/accumulator/proof.rs b/src/accumulator/proof.rs index 86771c3..498dfd1 100644 --- a/src/accumulator/proof.rs +++ b/src/accumulator/proof.rs @@ -9,7 +9,7 @@ //! use bitcoin_hashes::sha256; //! use bitcoin_hashes::Hash; //! use bitcoin_hashes::HashEngine; -//! use rustreexo::accumulator::node_hash::NodeHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::proof::Proof; //! use rustreexo::accumulator::stump::Stump; //! let s = Stump::new(); @@ -22,16 +22,22 @@ //! let mut proof_hashes = Vec::new(); //! //! proof_hashes.push( -//! NodeHash::from_str("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", +//! ) +//! .unwrap(), //! ); //! proof_hashes.push( -//! NodeHash::from_str("9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", +//! ) +//! .unwrap(), //! ); //! proof_hashes.push( -//! NodeHash::from_str("29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7") -//! .unwrap(), +//! BitcoinNodeHash::from_str( +//! "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", +//! ) +//! .unwrap(), //! ); //! //! // Hashes of the leaves UTXOs we'll add to the accumulator @@ -57,19 +63,21 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; -use super::node_hash::NodeHash; +use super::node_hash::AccumulatorHash; +use super::node_hash::BitcoinNodeHash; use super::stump::UpdateData; +use super::util; use super::util::get_proof_positions; use super::util::read_u64; use super::util::tree_rows; -use super::util::{self}; -#[derive(Clone, Debug, Default, Eq, PartialEq)] + +#[derive(Clone, Debug, Eq, PartialEq)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] /// A proof is a collection of hashes and positions. Each target position /// points to a leaf to be proven. Hashes are all /// hashes that can't be calculated from the data itself. /// Proofs are generated elsewhere. -pub struct Proof { +pub struct Proof { /// Targets are the i'th of leaf locations to delete and they are the bottommost leaves. /// With the tree below, the Targets can only consist of one of these: 02, 03, 04. ///```! @@ -92,7 +100,17 @@ pub struct Proof { /// // |---\ |---\ /// // 00 01 02 03 /// ``` - pub hashes: Vec, + pub hashes: Vec, +} + +// the default hash type for a proof is BitcoinNodeHash +impl Default for Proof { + fn default() -> Self { + Proof { + targets: Vec::new(), + hashes: Vec::new(), + } + } } // We often need to return the targets paired with hashes, and the proof position. @@ -101,17 +119,17 @@ pub struct Proof { /// This alias is used when we need to return the nodes and roots for a proof /// if we are not concerned with deleting those elements. -pub(crate) type NodesAndRootsCurrent = (Vec<(u64, NodeHash)>, Vec); +pub(crate) type NodesAndRootsCurrent = (Vec<(u64, Hash)>, Vec); /// This is used when we need to return the nodes and roots for a proof /// if we are concerned with deleting those elements. The difference is that /// we need to retun the old and updatated roots in the accumulator. -pub(crate) type NodesAndRootsOldNew = (Vec<(u64, NodeHash)>, Vec<(NodeHash, NodeHash)>); +pub(crate) type NodesAndRootsOldNew = (Vec<(u64, Hash)>, Vec<(Hash, Hash)>); impl Proof { /// Creates a proof from a vector of target and hashes. /// `targets` are u64s and indicates the position of the leaves we are /// trying to prove. - /// `hashes` are of type `NodeHash` and are all hashes we need for computing the roots. + /// `hashes` are of type `AccumulatorHash` and are all hashes we need for computing the roots. /// /// Assuming a tree with leaf values [0, 1, 2, 3, 4, 5, 6, 7], we should see something like this: ///```! @@ -132,7 +150,6 @@ impl Proof { /// ``` /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; /// use rustreexo::accumulator::proof::Proof; /// let targets = vec![0]; /// @@ -142,9 +159,53 @@ impl Proof { /// // Fill `proof_hashes` up with all hashes /// Proof::new(targets, proof_hashes); /// ``` - pub fn new(targets: Vec, hashes: Vec) -> Self { + pub fn new(targets: Vec, hashes: Vec) -> Proof { + Proof { targets, hashes } + } +} + +impl Proof { + /// Creates a proof from a vector of target and hashes, using a different hash type. + /// `targets` are u64s and indicates the position of the leaves we are + /// trying to prove. + /// `hashes` are of type `AccumulatorHash` and are all hashes we need for computing the roots. + /// + /// Different from `new`, this function allows for the proof to be created with a different + /// hash type, as long as it implements `AccumulatorHash`. + /// + /// Assuming a tree with leaf values [0, 1, 2, 3, 4, 5, 6, 7], we should see something like this: + ///```! + /// // 14 + /// // |-----------------\ + /// // 12 13 + /// // |---------\ |--------\ + /// // 08 09 10 11 + /// // |----\ |----\ |----\ |----\ + /// // 00 01 02 03 04 05 06 07 + /// ``` + /// If we are proving `00` (i.e. 00 is our target), then we need 01, + /// 09 and 13's hashes, so we can compute 14 by hashing both siblings + /// in each level (00 and 01, 08 and 09 and 12 and 13). Note that + /// some hashes we can compute by ourselves, and are not present in the + /// proof, in this case 00, 08, 12 and 14. + /// # Example + /// ``` + /// use bitcoin_hashes::Hash; + /// use bitcoin_hashes::HashEngine; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// let targets = vec![0]; + /// + /// let mut proof_hashes = Vec::new(); + /// let targets = vec![0]; + /// // For proving 0, we need 01, 09 and 13's hashes. 00, 08, 12 and 14 can be calculated + /// // Fill `proof_hashes` up with all hashes + /// Proof::::new_with_hash(targets, proof_hashes); + /// ``` + pub fn new_with_hash(targets: Vec, hashes: Vec) -> Self { Proof { targets, hashes } } + /// Public interface for verifying proofs. Returns a result with a bool or an Error /// True means the proof is true given the current stump, false means the proof is /// not valid given the current stump. @@ -155,7 +216,7 @@ impl Proof { /// use bitcoin_hashes::sha256; /// use bitcoin_hashes::Hash; /// use bitcoin_hashes::HashEngine; - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let s = Stump::new(); @@ -175,16 +236,22 @@ impl Proof { /// // 00 01 02 03 04 05 06 07 /// // For proving 0, we need 01, 09 and 13's hashes. 00, 08, 12 and 14 can be calculated /// proof_hashes.push( - /// NodeHash::from_str("4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + /// ) + /// .unwrap(), /// ); /// proof_hashes.push( - /// NodeHash::from_str("9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", + /// ) + /// .unwrap(), /// ); /// proof_hashes.push( - /// NodeHash::from_str("29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7") - /// .unwrap(), + /// BitcoinNodeHash::from_str( + /// "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", + /// ) + /// .unwrap(), /// ); /// /// let mut hashes = Vec::new(); @@ -199,15 +266,15 @@ impl Proof { /// ``` pub fn verify( &self, - del_hashes: &[NodeHash], - roots: &[NodeHash], + del_hashes: &[Hash], + roots: &[Hash], num_leaves: u64, ) -> Result { if self.targets.is_empty() { return Ok(true); } - let mut calculated_roots: std::iter::Peekable> = self + let mut calculated_roots: std::iter::Peekable> = self .calculate_hashes(del_hashes, num_leaves)? .1 .into_iter() @@ -227,8 +294,10 @@ impl Proof { if calculated_roots.len() != number_matched_roots && calculated_roots.len() != 0 { return Ok(false); } + Ok(true) } + /// Returns the elements needed to prove a subset of targets. For example, a tree with /// 8 leaves, if we cache `[0, 2, 6, 7]`, and we need to prove `[2, 7]` only, we have to remove /// elements for 0 and 7. The original proof is `[1, 3, 10]`, and we can compute `[8, 9, 11, 12, 13, 14]`. @@ -244,10 +313,10 @@ impl Proof { /// ``` pub fn get_proof_subset( &self, - del_hashes: &[NodeHash], + del_hashes: &[Hash], new_targets: &[u64], num_leaves: u64, - ) -> Result { + ) -> Result, String> { let forest_rows = tree_rows(num_leaves); let old_proof_positions = get_proof_positions(&self.targets, num_leaves, forest_rows); let needed_positions = get_proof_positions(new_targets, num_leaves, forest_rows); @@ -257,7 +326,7 @@ impl Proof { .iter() .copied() .zip(self.hashes.iter().copied()) - .collect::>(); + .collect::>(); old_proof.extend(intermediate_positions); @@ -271,7 +340,7 @@ impl Proof { } } new_proof.sort(); - let (_, new_proof): (Vec, Vec) = new_proof.into_iter().unzip(); + let (_, new_proof): (Vec, Vec) = new_proof.into_iter().unzip(); Ok(Proof { targets: new_targets.to_vec(), hashes: new_proof, @@ -286,7 +355,7 @@ impl Proof { /// - hashes (32 bytes) /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// @@ -309,20 +378,21 @@ impl Proof { writer.write_all(&self.hashes.len().to_le_bytes())?; for hash in &self.hashes { len += 32; - writer.write_all(&**hash)?; + hash.write(&mut writer)?; } Ok(len) } + /// Deserializes a proof from a byte array. /// # Example /// ``` /// use std::io::Cursor; /// - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let proof = Cursor::new(vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - /// let deserialized_proof = Proof::deserialize(proof).unwrap(); + /// let deserialized_proof = Proof::::deserialize(proof).unwrap(); /// // An empty proof is only 16 bytes of zeros, meaning no targets and no hashes /// assert_eq!(Proof::default(), deserialized_proof); /// ``` @@ -336,15 +406,14 @@ impl Proof { let hashes_len = read_u64(&mut buf)? as usize; let mut hashes = Vec::with_capacity(hashes_len); for _ in 0..hashes_len { - let mut hash = [0u8; 32]; - buf.read_exact(&mut hash) - .map_err(|_| "Failed to read hash")?; - hashes.push(hash.into()); + let hash = Hash::read(&mut buf).map_err(|_| "Failed to parse hash")?; + hashes.push(hash); } Ok(Proof { targets, hashes }) } + /// Returns how many targets this proof has - pub fn targets(&self) -> usize { + pub fn n_targets(&self) -> usize { self.targets.len() } @@ -359,15 +428,15 @@ impl Proof { /// If at least one returned element doesn't exist in the accumulator, the proof is invalid. pub(crate) fn calculate_hashes_delete( &self, - del_hashes: &[(NodeHash, NodeHash)], + del_hashes: &[(Hash, Hash)], num_leaves: u64, - ) -> Result { + ) -> Result, String> { // Where all the root hashes that we've calculated will go to. let total_rows = util::tree_rows(num_leaves); // Where all the parent hashes we've calculated in a given row will go to. let mut calculated_root_hashes = - Vec::<(NodeHash, NodeHash)>::with_capacity(util::num_roots(num_leaves)); + Vec::<(Hash, Hash)>::with_capacity(util::num_roots(num_leaves)); // the positions that should be passed as a proof let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); @@ -415,14 +484,14 @@ impl Proof { } let parent_hash = match (next_hash_new.is_empty(), sibling_hash_new.is_empty()) { - (true, true) => NodeHash::empty(), + (true, true) => AccumulatorHash::empty(), (true, false) => sibling_hash_new, (false, true) => next_hash_new, - (false, false) => NodeHash::parent_hash(&next_hash_new, &sibling_hash_new), + (false, false) => AccumulatorHash::parent_hash(&next_hash_new, &sibling_hash_new), }; let parent = util::parent(next_pos, total_rows); - let old_parent_hash = NodeHash::parent_hash(&next_hash_old, &sibling_hash_old); + let old_parent_hash = AccumulatorHash::parent_hash(&next_hash_old, &sibling_hash_old); computed.push((parent, (old_parent_hash, parent_hash))); } @@ -445,15 +514,14 @@ impl Proof { /// needed for verification (i.e. the current accumulator). pub(crate) fn calculate_hashes( &self, - del_hashes: &[NodeHash], + del_hashes: &[Hash], num_leaves: u64, - ) -> Result { + ) -> Result, String> { // Where all the root hashes that we've calculated will go to. let total_rows = util::tree_rows(num_leaves); // Where all the parent hashes we've calculated in a given row will go to. - let mut calculated_root_hashes = - Vec::::with_capacity(util::num_roots(num_leaves)); + let mut calculated_root_hashes = Vec::::with_capacity(util::num_roots(num_leaves)); // the positions that should be passed as a proof let proof_positions = get_proof_positions(&self.targets, num_leaves, total_rows); @@ -500,7 +568,7 @@ impl Proof { return Err(format!("Missing sibling for {}", next_pos)); } - let parent_hash = NodeHash::parent_hash(&next_hash, &sibling_hash); + let parent_hash = AccumulatorHash::parent_hash(&next_hash, &sibling_hash); let parent = util::parent(next_pos, total_rows); computed.push((parent, parent_hash)); } @@ -541,18 +609,19 @@ impl Proof { (None, None) => None, } } + /// Uses the data passed in to update a proof, creating a valid proof for a given /// set of targets, after an update. This is useful for caching UTXOs. You grab a proof /// for it once and then keep updating it every block, yielding an always valid proof /// over those UTXOs. pub fn update( self, - cached_hashes: Vec, - add_hashes: Vec, + cached_hashes: Vec, + add_hashes: Vec, block_targets: Vec, remembers: Vec, - update_data: UpdateData, - ) -> Result<(Proof, Vec), String> { + update_data: UpdateData, + ) -> Result<(Proof, Vec), String> { let (proof_after_deletion, cached_hashes) = self.update_proof_remove( block_targets, cached_hashes, @@ -571,17 +640,18 @@ impl Proof { Ok(data_after_addition) } + fn update_proof_add( self, - adds: Vec, - cached_del_hashes: Vec, + adds: Vec, + cached_del_hashes: Vec, remembers: Vec, - new_nodes: Vec<(u64, NodeHash)>, + new_nodes: Vec<(u64, Hash)>, before_num_leaves: u64, to_destroy: Vec, - ) -> Result<(Proof, Vec), String> { + ) -> Result<(Proof, Vec), String> { // Combine the hashes with the targets. - let orig_targets_with_hash: Vec<(u64, NodeHash)> = self + let orig_targets_with_hash: Vec<(u64, Hash)> = self .targets .iter() .copied() @@ -616,7 +686,7 @@ impl Proof { // remembers is an index telling what newly created UTXO should be cached for remember in remembers { - let remember_target: Option<&NodeHash> = adds.get(remember as usize); + let remember_target: Option<&Hash> = adds.get(remember as usize); if let Some(remember_target) = remember_target { let node = new_nodes_iter.find(|(_, hash)| *hash == *remember_target); if let Some((pos, hash)) = node { @@ -655,22 +725,23 @@ impl Proof { } new_proof.sort(); - let (_, hashes): (Vec, Vec) = new_proof.into_iter().unzip(); + let (_, hashes): (Vec, Vec) = new_proof.into_iter().unzip(); Ok(( - Proof { + Proof:: { hashes, targets: new_target_pos, }, target_hashes, )) } + /// maybe_remap remaps the passed in hash and pos if the tree_rows increase after /// adding the new nodes. fn maybe_remap( num_leaves: u64, num_adds: u64, - positions: Vec<(u64, NodeHash)>, - ) -> Vec<(u64, NodeHash)> { + positions: Vec<(u64, Hash)>, + ) -> Vec<(u64, Hash)> { let new_forest_rows = util::tree_rows(num_leaves + num_adds); let old_forest_rows = util::tree_rows(num_leaves); let tree_rows = util::tree_rows(num_leaves); @@ -697,13 +768,13 @@ impl Proof { fn update_proof_remove( self, block_targets: Vec, - cached_hashes: Vec, - updated: Vec<(u64, NodeHash)>, + cached_hashes: Vec, + updated: Vec<(u64, Hash)>, num_leaves: u64, - ) -> Result<(Proof, Vec), String> { + ) -> Result<(Proof, Vec), String> { let total_rows = util::tree_rows(num_leaves); - let targets_with_hash: Vec<(u64, NodeHash)> = self + let targets_with_hash: Vec<(u64, Hash)> = self .targets .iter() .cloned() @@ -767,7 +838,7 @@ impl Proof { proof_elements.sort(); // Grab the hashes for the proof - let (_, hashes): (Vec, Vec) = proof_elements.into_iter().unzip(); + let (_, hashes): (Vec, Vec) = proof_elements.into_iter().unzip(); // Gets all proof targets, but with their new positions after delete let (targets, target_hashes) = Proof::calc_next_positions(&block_targets, &targets_with_hash, num_leaves, true)? @@ -779,10 +850,10 @@ impl Proof { fn calc_next_positions( block_targets: &Vec, - old_positions: &Vec<(u64, NodeHash)>, + old_positions: &Vec<(u64, Hash)>, num_leaves: u64, append_roots: bool, - ) -> Result, String> { + ) -> Result, String> { let total_rows = util::tree_rows(num_leaves); let mut new_positions = vec![]; @@ -825,9 +896,11 @@ mod tests { use serde::Deserialize; use super::Proof; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::AccumulatorHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::stump::Stump; use crate::accumulator::util::hash_from_u8; + #[derive(Deserialize)] struct TestCase { numleaves: usize, @@ -837,12 +910,13 @@ mod tests { proofhashes: Vec, expected: bool, } + /// This test checks whether our update proof works for different scenarios. We start /// with a (valid) cached proof, then we receive `blocks`, like we would in normal Bitcoin /// but for this test, block is just random data. For each block we update our Stump and /// our proof as well, after that, our proof **must** still be valid for the latest Stump. /// - /// Fix-me: Using derive for deserialize, when also using NodeHash leads to an odd + /// Fix-me: Using derive for deserialize, when also using AccumulatorHash leads to an odd /// error that can't be easily fixed. Even bumping version doesn't appear to help. /// Deriving hashes directly reduces the amount of boilerplate code used, and makes everything /// more clearer, hence, it's preferable. @@ -893,19 +967,19 @@ mod tests { .cached_proof .hashes .iter() - .map(|val| NodeHash::from_str(val).unwrap()) + .map(|val| BitcoinNodeHash::from_str(val).unwrap()) .collect(); let cached_hashes: Vec<_> = case_values .cached_hashes .iter() - .map(|val| NodeHash::from_str(val).unwrap()) + .map(|val| BitcoinNodeHash::from_str(val).unwrap()) .collect(); let cached_proof = Proof::new(case_values.cached_proof.targets, proof_hashes); let roots = case_values .initial_roots .into_iter() - .map(|hash| NodeHash::from_str(&hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(&hash).unwrap()) .collect(); let stump = Stump { @@ -923,14 +997,14 @@ mod tests { .update .del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let block_proof_hashes = case_values .update .proof .hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let block_proof = @@ -951,13 +1025,13 @@ mod tests { let expected_roots: Vec<_> = case_values .expected_roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let expected_cached_hashes: Vec<_> = case_values .expected_cached_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); assert_eq!(res, Ok(true)); assert_eq!(cached_proof.targets, case_values.expected_targets); @@ -969,29 +1043,54 @@ mod tests { #[test] fn test_get_next() { use super::Proof; - let computed = vec![(1, NodeHash::empty()), (3, NodeHash::empty())]; - let provided = vec![(2, NodeHash::empty()), (4, NodeHash::empty())]; + let computed = vec![(1, BitcoinNodeHash::empty()), (3, BitcoinNodeHash::empty())]; + let provided = vec![(2, BitcoinNodeHash::empty()), (4, BitcoinNodeHash::empty())]; let mut computed_pos = 0; let mut provided_pos = 0; assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), - Some((1, NodeHash::empty())) + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), + Some((1, AccumulatorHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), - Some((2, NodeHash::empty())) + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), + Some((2, AccumulatorHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), - Some((3, NodeHash::empty())) + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), + Some((3, AccumulatorHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), - Some((4, NodeHash::empty())) + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), + Some((4, AccumulatorHash::empty())) ); assert_eq!( - Proof::get_next(&computed, &provided, &mut computed_pos, &mut provided_pos), + Proof::::get_next( + &computed, + &provided, + &mut computed_pos, + &mut provided_pos + ), None ); } @@ -1004,11 +1103,11 @@ mod tests { struct Test { name: &'static str, block_targets: Vec, - old_positions: Vec<(u64, NodeHash)>, + old_positions: Vec<(u64, BitcoinNodeHash)>, num_leaves: u64, num_adds: u64, append_roots: bool, - expected: Vec<(u64, NodeHash)>, + expected: Vec<(u64, BitcoinNodeHash)>, } let tests = vec![Test { @@ -1017,28 +1116,28 @@ mod tests { old_positions: vec![ ( 1, - NodeHash::from_str( + BitcoinNodeHash::from_str( "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", ) .unwrap(), ), ( 13, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9d1e0e2d9459d06523ad13e28a4093c2316baafe7aec5b25f30eba2e113599c4", ) .unwrap(), ), ( 17, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", ) .unwrap(), ), ( 25, - NodeHash::from_str( + BitcoinNodeHash::from_str( "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", ) .unwrap(), @@ -1050,28 +1149,28 @@ mod tests { expected: (vec![ ( 1, - NodeHash::from_str( + BitcoinNodeHash::from_str( "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", ) .unwrap(), ), ( 17, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9576f4ade6e9bc3a6458b506ce3e4e890df29cb14cb5d3d887672aef55647a2b", ) .unwrap(), ), ( 21, - NodeHash::from_str( + BitcoinNodeHash::from_str( "9d1e0e2d9459d06523ad13e28a4093c2316baafe7aec5b25f30eba2e113599c4", ) .unwrap(), ), ( 25, - NodeHash::from_str( + BitcoinNodeHash::from_str( "29590a14c1b09384b94a2c0e94bf821ca75b62eacebc47893397ca88e3bbcbd7", ) .unwrap(), @@ -1091,6 +1190,7 @@ mod tests { assert_eq!(res, test.expected, "testcase: \"{}\" fail", test.name); } } + #[test] fn test_update_proof_delete() { let preimages = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; @@ -1107,7 +1207,7 @@ mod tests { ]; let proof_hashes = proof_hashes .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let cached_proof_hashes = [ @@ -1117,7 +1217,7 @@ mod tests { ]; let cached_proof_hashes = cached_proof_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let cached_proof = Proof::new(vec![0, 1, 7], cached_proof_hashes); @@ -1142,6 +1242,7 @@ mod tests { let res = stump.verify(&new_proof, &[hash_from_u8(0), hash_from_u8(7)]); assert_eq!(res, Ok(true)); } + #[test] fn test_calculate_hashes() { // Tests if the calculated roots and nodes are correct. @@ -1164,7 +1265,7 @@ mod tests { ]; let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0, 2, 4, 6], proof_hashes); @@ -1191,12 +1292,12 @@ mod tests { let expected_roots: Vec<_> = expected_roots .iter() - .map(|root| NodeHash::from_str(root).unwrap()) + .map(|root| BitcoinNodeHash::from_str(root).unwrap()) .collect(); let mut expected_computed = expected_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .zip(&expected_pos); let calculated = p.calculate_hashes(&del_hashes, s.leaves); @@ -1235,22 +1336,24 @@ mod tests { let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0], proof_hashes); let del_hashes = del_hashes .into_iter() - .map(|hash| (hash, NodeHash::empty())) + .map(|hash| (hash, BitcoinNodeHash::empty())) .collect::>(); let (computed, roots) = p.calculate_hashes_delete(&del_hashes, 8).unwrap(); - let expected_root_old = - NodeHash::from_str("b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42") - .unwrap(); - let expected_root_new = - NodeHash::from_str("726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed") - .unwrap(); + let expected_root_old = BitcoinNodeHash::from_str( + "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", + ) + .unwrap(); + let expected_root_new = BitcoinNodeHash::from_str( + "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed", + ) + .unwrap(); let computed_positions = [0_u64, 1, 9, 13, 8, 12, 14].to_vec(); let computed_hashes = [ @@ -1263,7 +1366,7 @@ mod tests { "726fdd3b432cc59e68487d126e70f0db74a236267f8daeae30b31839a4e7ebed", ] .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let expected_computed: Vec<_> = computed_positions .into_iter() @@ -1282,6 +1385,7 @@ mod tests { let deserialized = Proof::deserialize(&mut serialized.as_slice()).unwrap(); assert_eq!(p, deserialized); } + #[test] fn test_get_proof_subset() { // Tests if the calculated roots and nodes are correct. @@ -1304,7 +1408,7 @@ mod tests { ]; let proof_hashes = proof .into_iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let p = Proof::new(vec![0, 2, 4, 6], proof_hashes); @@ -1314,6 +1418,7 @@ mod tests { assert_eq!(s.verify(&subset, &[del_hashes[0]]), Ok(true)); assert_eq!(s.verify(&subset, &[del_hashes[2]]), Ok(false)); } + #[test] #[cfg(feature = "with-serde")] fn test_serde_rtt() { @@ -1324,12 +1429,13 @@ mod tests { serde_json::from_str(&serialized).expect("Deserialization failed"); assert_eq!(proof, deserialized); } + fn run_single_case(case: &serde_json::Value) { let case = serde_json::from_value::(case.clone()).expect("Invalid test case"); let roots = case .roots .into_iter() - .map(|root| NodeHash::from_str(root.as_str()).expect("Test case hash is valid")) + .map(|root| BitcoinNodeHash::from_str(root.as_str()).expect("Test case hash is valid")) .collect(); let s = Stump { @@ -1347,7 +1453,7 @@ mod tests { let proof_hashes = case .proofhashes .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hash is valid")) + .map(|hash| BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hash is valid")) .collect(); let p = Proof::new(targets, proof_hashes); @@ -1357,16 +1463,17 @@ mod tests { assert!(Ok(expected) == res); // Test getting proof subset (only if the original proof is valid) if expected { - let (subset, _) = p.targets.split_at(p.targets() / 2); + let (subset, _) = p.targets.split_at(p.n_targets() / 2); let proof = p.get_proof_subset(&del_hashes, subset, s.leaves).unwrap(); let set_hashes = subset .iter() .map(|preimage| hash_from_u8(*preimage as u8)) - .collect::>(); + .collect::>(); assert_eq!(s.verify(&proof, &set_hashes), Ok(true)); } } + #[test] fn test_proof_verify() { let contents = std::fs::read_to_string("test_values/test_cases.json") @@ -1411,6 +1518,7 @@ mod bench { bencher.iter(|| proof.calculate_hashes(&cached_hashes, stump.leaves)) } + #[bench] fn bench_proof_update(bencher: &mut Bencher) { let preimages = [0_u8, 1, 2, 3, 4, 5]; diff --git a/src/accumulator/stump.rs b/src/accumulator/stump.rs index 46613cf..211a2db 100644 --- a/src/accumulator/stump.rs +++ b/src/accumulator/stump.rs @@ -6,13 +6,13 @@ //! ``` //! use std::str::FromStr; //! -//! use rustreexo::accumulator::node_hash::NodeHash; +//! use rustreexo::accumulator::node_hash::BitcoinNodeHash; //! use rustreexo::accumulator::proof::Proof; //! use rustreexo::accumulator::stump::Stump; //! // Create a new empty Stump //! let s = Stump::new(); //! // The newly create outputs -//! let utxos = vec![NodeHash::from_str( +//! let utxos = vec![BitcoinNodeHash::from_str( //! "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", //! ) //! .unwrap()]; @@ -21,7 +21,7 @@ //! // Modify the Stump, adding the new outputs and removing the spent ones, notice how //! // it returns a new Stump, instead of modifying the old one. This is due to the fact //! // that modify is a pure function that doesn't modify the old Stump. -//! let s = s.modify(&utxos, &stxos, &Proof::default()); +//! let s = s.modify(&utxos, &stxos, &Proof::::default()); //! assert!(s.is_ok()); //! assert_eq!(s.unwrap().0.roots, utxos); //! ``` @@ -35,28 +35,35 @@ use serde::Deserialize; #[cfg(feature = "with-serde")] use serde::Serialize; -use super::node_hash::NodeHash; +use super::node_hash::AccumulatorHash; +use super::node_hash::BitcoinNodeHash; use super::proof::NodesAndRootsOldNew; use super::proof::Proof; use super::util; -#[derive(Debug, Clone, PartialEq, Default)] +#[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "with-serde", derive(Serialize, Deserialize))] -pub struct Stump { +pub struct Stump { pub leaves: u64, - pub roots: Vec, + pub roots: Vec, +} + +impl Default for Stump { + fn default() -> Self { + Stump::new() + } } #[derive(Debug, Clone, Default)] -pub struct UpdateData { +pub struct UpdateData { /// to_destroy is the positions of the empty roots removed after the add. pub(crate) to_destroy: Vec, /// pre_num_leaves is the numLeaves of the stump before the add. pub(crate) prev_num_leaves: u64, /// new_add are the new hashes for the newly created roots after the addition. - pub(crate) new_add: Vec<(u64, NodeHash)>, + pub(crate) new_add: Vec<(u64, Hash)>, /// new_del are the new hashes after the deletion. - pub(crate) new_del: Vec<(u64, NodeHash)>, + pub(crate) new_del: Vec<(u64, Hash)>, } impl Stump { @@ -72,9 +79,87 @@ impl Stump { roots: Vec::new(), } } - pub fn verify(&self, proof: &Proof, del_hashes: &[NodeHash]) -> Result { + + /// Serialize the Stump into a byte array + /// # Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// use rustreexo::accumulator::stump::Stump; + /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] + /// .iter() + /// .map(|&el| BitcoinNodeHash::from([el; 32])) + /// .collect::>(); + /// let (stump, _) = Stump::new() + /// .modify(&hashes, &[], &Proof::default()) + /// .unwrap(); + /// let mut writer = Vec::new(); + /// stump.serialize(&mut writer).unwrap(); + /// assert_eq!( + /// writer, + /// vec![ + /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 150, 124, 244, 241, 98, 69, 217, + /// 222, 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, + /// 66, 30, 63, 113, 246, 74, + /// ] + /// ); + /// ``` + pub fn serialize(&self, mut writer: &mut Dst) -> std::io::Result { + let mut len = 8; + writer.write_all(&self.leaves.to_le_bytes())?; + writer.write_all(&self.roots.len().to_le_bytes())?; + for root in self.roots.iter() { + len += 32; + root.write(&mut writer)?; + } + Ok(len) + } + + /// Rewinds old tree state, this should be used in case of reorgs. + /// Takes the ownership over `old_state`. + ///# Example + /// ``` + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; + /// use rustreexo::accumulator::proof::Proof; + /// use rustreexo::accumulator::stump::Stump; + /// + /// let s_old = Stump::new(); + /// let mut s_new = Stump::new(); + /// + /// let s_old = s_old.modify(&vec![], &vec![], &Proof::default()).unwrap().0; + /// s_new = s_old.clone(); + /// s_new = s_new.modify(&vec![], &vec![], &Proof::default()).unwrap().0; + /// + /// // A reorg happened + /// s_new.undo(s_old); + /// ``` + pub fn undo(&mut self, old_state: Stump) { + self.leaves = old_state.leaves; + self.roots = old_state.roots; + } +} + +impl Stump { + /// Verifies the proof against the Stump. The proof is a list of hashes that are used to + /// recompute the root of the accumulator. The del_hashes are the hashes that are being + /// deleted from the accumulator. + /// // TODO: Add example + pub fn verify(&self, proof: &Proof, del_hashes: &[Hash]) -> Result { proof.verify(del_hashes, &self.roots, self.leaves) } + + /// Creates a new Stump with a custom hash type + /// + /// If you need to use a hash type that's not the [BitcoinNodeHash], you can use this + /// function to create a new Stump with the desired hash type. Use [BitcoinNodeHash::new] + /// to create a new Stump with the default hash type. + pub fn new_with_hash() -> Stump { + Stump { + leaves: 0, + roots: Vec::new(), + } + } + /// Modify is the external API to change the accumulator state. Since order /// matters, you can only modify, providing a list of utxos to be added, /// and txos to be removed, along with it's proof. Either may be @@ -83,12 +168,12 @@ impl Stump { /// ``` /// use std::str::FromStr; /// - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// /// let s = Stump::new(); - /// let utxos = vec![NodeHash::from_str( + /// let utxos = vec![BitcoinNodeHash::from_str( /// "b151a956139bb821d4effa34ea95c17560e0135d1e4661fc23cedc3af49dac42", /// ) /// .unwrap()]; @@ -99,10 +184,10 @@ impl Stump { /// ``` pub fn modify( &self, - utxos: &[NodeHash], - del_hashes: &[NodeHash], - proof: &Proof, - ) -> Result<(Stump, UpdateData), String> { + utxos: &[Hash], + del_hashes: &[Hash], + proof: &Proof, + ) -> Result<(Stump, UpdateData), String> { let (intermediate, mut computed_roots) = self.remove(del_hashes, proof)?; let mut new_roots = vec![]; @@ -139,101 +224,48 @@ impl Stump { Ok((new_stump, update_data)) } - /// Serialize the Stump into a byte array - /// # Example - /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; - /// use rustreexo::accumulator::proof::Proof; - /// use rustreexo::accumulator::stump::Stump; - /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] - /// .iter() - /// .map(|&el| NodeHash::from([el; 32])) - /// .collect::>(); - /// let (stump, _) = Stump::new() - /// .modify(&hashes, &[], &Proof::default()) - /// .unwrap(); - /// let mut writer = Vec::new(); - /// stump.serialize(&mut writer).unwrap(); - /// assert_eq!( - /// writer, - /// vec![ - /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 150, 124, 244, 241, 98, 69, 217, 222, - /// 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, - /// 63, 113, 246, 74, - /// ] - /// ); - /// ``` - pub fn serialize(&self, writer: &mut Dst) -> std::io::Result { - let mut len = 8; - writer.write_all(&self.leaves.to_le_bytes())?; - writer.write_all(&self.roots.len().to_le_bytes())?; - for root in self.roots.iter() { - len += 32; - writer.write_all(&**root)?; - } - Ok(len) - } + /// Deserialize the Stump from a Reader /// # Example /// ``` - /// use rustreexo::accumulator::node_hash::NodeHash; + /// use rustreexo::accumulator::node_hash::BitcoinNodeHash; /// use rustreexo::accumulator::proof::Proof; /// use rustreexo::accumulator::stump::Stump; /// let buffer = vec![ - /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 150, 124, 244, 241, 98, 69, 217, 222, 235, - /// 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, 63, 113, - /// 246, 74, + /// 8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 150, 124, 244, 241, 98, 69, 217, 222, + /// 235, 97, 61, 137, 135, 76, 197, 134, 232, 173, 253, 8, 28, 17, 124, 123, 16, 4, 66, 30, 63, + /// 113, 246, 74, /// ]; /// let mut buffer = std::io::Cursor::new(buffer); /// let hashes = [0, 1, 2, 3, 4, 5, 6, 7] /// .iter() - /// .map(|&el| NodeHash::from([el; 32])) + /// .map(|&el| BitcoinNodeHash::from([el; 32])) /// .collect::>(); /// let (stump, _) = Stump::new() /// .modify(&hashes, &[], &Proof::default()) /// .unwrap(); - /// assert_eq!(stump, Stump::deserialize(buffer).unwrap()); + /// assert_eq!( + /// stump, + /// Stump::::deserialize(buffer).unwrap() + /// ); /// ``` - pub fn deserialize(data: Source) -> Result { - let mut data = data; + pub fn deserialize(mut data: Source) -> Result { let leaves = util::read_u64(&mut data)?; let roots_len = util::read_u64(&mut data)?; let mut roots = vec![]; for _ in 0..roots_len { - let mut root = [0u8; 32]; - data.read_exact(&mut root).map_err(|e| e.to_string())?; - - roots.push(NodeHash::from(root)); + let root = Hash::read(&mut data).map_err(|e| e.to_string())?; + roots.push(root); } Ok(Stump { leaves, roots }) } - /// Rewinds old tree state, this should be used in case of reorgs. - /// Takes the ownership over `old_state`. - ///# Example - /// ``` - /// use rustreexo::accumulator::proof::Proof; - /// use rustreexo::accumulator::stump::Stump; - /// let s_old = Stump::new(); - /// let mut s_new = Stump::new(); - /// - /// let s_old = s_old.modify(&vec![], &vec![], &Proof::default()).unwrap().0; - /// s_new = s_old.clone(); - /// s_new = s_new.modify(&vec![], &vec![], &Proof::default()).unwrap().0; - /// - /// // A reorg happened - /// s_new.undo(s_old); - /// ``` - pub fn undo(&mut self, old_state: Stump) { - self.leaves = old_state.leaves; - self.roots = old_state.roots; - } fn remove( &self, - del_hashes: &[NodeHash], - proof: &Proof, - ) -> Result { + del_hashes: &[Hash], + proof: &Proof, + ) -> Result, String> { if del_hashes.is_empty() { return Ok(( vec![], @@ -243,19 +275,20 @@ impl Stump { let del_hashes = del_hashes .iter() - .map(|hash| (*hash, NodeHash::empty())) + .map(|hash| (*hash, Hash::empty())) .collect::>(); + proof.calculate_hashes_delete(&del_hashes, self.leaves) } /// Adds new leaves into the root fn add( - mut roots: Vec, - utxos: &[NodeHash], + mut roots: Vec, + utxos: &[Hash], mut leaves: u64, - ) -> (Vec, Vec<(u64, NodeHash)>, Vec) { + ) -> (Vec, Vec<(u64, Hash)>, Vec) { let after_rows = util::tree_rows(leaves + (utxos.len() as u64)); - let mut updated_subtree: Vec<(u64, NodeHash)> = vec![]; + let mut updated_subtree: Vec<(u64, Hash)> = vec![]; let all_deleted = util::roots_to_destroy(utxos.len() as u64, leaves, &roots); for (i, add) in utxos.iter().enumerate() { @@ -287,7 +320,7 @@ impl Stump { updated_subtree.push((pos, to_add)); pos = util::parent(pos, after_rows); - to_add = NodeHash::parent_hash(&root, &to_add); + to_add = AccumulatorHash::parent_hash(&root, &to_add); } } h += 1; @@ -306,13 +339,17 @@ impl Stump { #[cfg(test)] mod test { + use std::fmt::Display; + use std::io::Read; + use std::io::Write; use std::str::FromStr; use std::vec; use serde::Deserialize; use super::Stump; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::AccumulatorHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::proof::Proof; use crate::accumulator::util::hash_from_u8; @@ -331,6 +368,75 @@ mod test { assert!(s.leaves == 0); assert!(s.roots.is_empty()); } + + #[test] + fn test_custom_hash_type() { + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)] + struct CustomHash([u8; 32]); + + impl Display for CustomHash { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + + impl AccumulatorHash for CustomHash { + fn empty() -> Self { + CustomHash([0; 32]) + } + fn is_empty(&self) -> bool { + self.0.iter().all(|&x| x == 0) + } + fn parent_hash(left: &Self, right: &Self) -> Self { + let mut hash = [0; 32]; + for i in 0..32 { + hash[i] = left.0[i] ^ right.0[i]; + } + CustomHash(hash) + } + fn read(reader: &mut R) -> std::io::Result { + let mut hash = [0; 32]; + reader + .read_exact(&mut hash) + .map_err(|e| e.to_string()) + .unwrap(); + Ok(CustomHash(hash)) + } + fn write(&self, writer: &mut W) -> std::io::Result<()> { + writer + .write_all(&self.0) + .map_err(|e| e.to_string()) + .unwrap(); + Ok(()) + } + fn is_placeholder(&self) -> bool { + false + } + fn placeholder() -> Self { + CustomHash([0; 32]) + } + } + + let s = Stump::::new_with_hash(); + assert!(s.leaves == 0); + assert!(s.roots.is_empty()); + + let hashes = [0, 1, 2, 3, 4, 5, 6, 7] + .iter() + .map(|&el| CustomHash([el; 32])) + .collect::>(); + + let (stump, _) = s + .modify( + &hashes, + &[], + &Proof::::new_with_hash(Vec::new(), Vec::new()), + ) + .unwrap(); + assert_eq!(stump.leaves, 8); + assert_eq!(stump.roots.len(), 1); + } + #[test] fn test_updated_data() { /// This test initializes a Stump, with some utxos. Then, we add a couple more utxos @@ -366,7 +472,7 @@ mod test { let roots = data .roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let stump = Stump { leaves: data.leaves, @@ -381,12 +487,12 @@ mod test { let del_hashes = data .del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let proof_hashes = data .proof_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let proof = Proof::new(data.proof_targets, proof_hashes); let (_, updated) = stump.modify(&utxos, &del_hashes, &proof).unwrap(); @@ -394,7 +500,7 @@ mod test { let new_add_hash: Vec<_> = data .new_add_hash .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let new_add: Vec<_> = data .new_add_pos @@ -405,7 +511,7 @@ mod test { let new_del_hash: Vec<_> = data .new_del_hashes .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let new_del: Vec<_> = data .new_del_pos @@ -421,6 +527,7 @@ mod test { } } } + #[test] fn test_update_data_add() { let preimages = vec![0, 1, 2, 3]; @@ -442,7 +549,7 @@ mod test { "df46b17be5f66f0750a4b3efa26d4679db170a72d41eb56c3e4ff75a58c65386", ] .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect(); let positions: Vec<_> = positions.into_iter().zip(hashes).collect(); @@ -462,6 +569,7 @@ mod test { serde_json::from_str(&serialized).expect("Deserialization failed"); assert_eq!(stump, deserialized); } + fn run_case_with_deletion(case: TestCase) { let leaf_hashes = case .leaf_preimages @@ -481,7 +589,9 @@ mod test { .proofhashes .unwrap_or_default() .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid")) + .map(|hash| { + BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hashes are valid") + }) .collect::>(); let proof = Proof::new(case.target_values.unwrap(), proof_hashes); @@ -489,8 +599,10 @@ mod test { let roots = case .expected_roots .into_iter() - .map(|hash| NodeHash::from_str(hash.as_str()).expect("Test case hashes are valid")) - .collect::>(); + .map(|hash| { + BitcoinNodeHash::from_str(hash.as_str()).expect("Test case hashes are valid") + }) + .collect::>(); let (stump, _) = Stump::new() .modify(&leaf_hashes, &[], &Proof::default()) @@ -519,6 +631,7 @@ mod test { assert_eq!(root, s.roots[i].to_string()); } } + #[test] fn test_undo() { let mut hashes = vec![]; @@ -555,7 +668,7 @@ mod test { fn test_serialize() { let hashes = [0, 1, 2, 3, 4, 5, 6, 7] .iter() - .map(|&el| NodeHash::from([el; 32])) + .map(|&el| BitcoinNodeHash::from([el; 32])) .collect::>(); let (stump, _) = Stump::new() .modify(&hashes, &[], &Proof::default()) @@ -567,6 +680,7 @@ mod test { let stump2 = Stump::deserialize(&mut reader).unwrap(); assert_eq!(stump, stump2); } + #[test] fn run_test_cases() { #[derive(Deserialize)] @@ -598,7 +712,7 @@ mod bench { use test::Bencher; use super::Stump; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::AccumulatorHash; use crate::accumulator::proof::Proof; use crate::accumulator::util::hash_from_u8; @@ -613,13 +727,14 @@ mod bench { "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459f", ] .iter() - .map(|&hash| NodeHash::try_from(hash).unwrap()) + .map(|&hash| AccumulatorHash::try_from(hash).unwrap()) .collect::>(); let proof = &Proof::default(); bencher.iter(move || { let _ = Stump::new().modify(&hash, &[], &proof); }); } + #[bench] fn bench_add_del(bencher: &mut Bencher) { let leaf_preimages = [0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]; @@ -639,7 +754,7 @@ mod bench { "c413035120e8c9b0ca3e40c93d06fe60a0d056866138300bb1f1dd172b4923c3", ] .iter() - .map(|&value| NodeHash::try_from(value).unwrap()) + .map(|&value| AccumulatorHash::try_from(value).unwrap()) .collect::>(); let acc = Stump::new() .modify(&leaves, &vec![], &Proof::default()) diff --git a/src/accumulator/util.rs b/src/accumulator/util.rs index b9a90b9..5a62809 100644 --- a/src/accumulator/util.rs +++ b/src/accumulator/util.rs @@ -1,7 +1,7 @@ use std::io::Read; // Rustreexo -use super::node_hash::NodeHash; +use super::node_hash::AccumulatorHash; // isRootPosition checks if the current position is a root given the number of // leaves and the entire rows of the forest. @@ -89,7 +89,11 @@ pub fn left_sibling(position: u64) -> u64 { } // roots_to_destroy returns the empty roots that get written over after num_adds // amount of leaves have been added. -pub fn roots_to_destroy(num_adds: u64, mut num_leaves: u64, orig_roots: &[NodeHash]) -> Vec { +pub fn roots_to_destroy( + num_adds: u64, + mut num_leaves: u64, + orig_roots: &[Hash], +) -> Vec { let mut roots = orig_roots.to_vec(); let mut deleted = vec![]; let mut h = 0; @@ -106,7 +110,7 @@ pub fn roots_to_destroy(num_adds: u64, mut num_leaves: u64, orig_roots: &[NodeHa h += 1; } // Just adding a non-zero value to the slice. - roots.push(NodeHash::placeholder()); + roots.push(AccumulatorHash::placeholder()); num_leaves += 1; } @@ -233,7 +237,7 @@ pub fn parent_many(pos: u64, rise: u8, forest_rows: u8) -> Result { rise, forest_rows )); } - let mask = ((2_u64 << forest_rows) - 1) as u64; + let mask: u64 = (2_u64 << forest_rows) - 1_u64; Ok((pos >> rise | (mask << (forest_rows - (rise - 1)) as u64)) & mask) } @@ -304,24 +308,27 @@ pub fn get_proof_positions(targets: &[u64], num_leaves: u64, forest_rows: u8) -> proof_positions } + #[cfg(any(test, bench))] -pub fn hash_from_u8(value: u8) -> NodeHash { +pub fn hash_from_u8(value: u8) -> super::node_hash::BitcoinNodeHash { use bitcoin_hashes::sha256; use bitcoin_hashes::Hash; use bitcoin_hashes::HashEngine; + let mut engine = bitcoin_hashes::sha256::Hash::engine(); engine.input(&[value]); sha256::Hash::from_engine(engine).into() } + #[cfg(test)] mod tests { use std::str::FromStr; use std::vec; use super::roots_to_destroy; - use crate::accumulator::node_hash::NodeHash; + use crate::accumulator::node_hash::BitcoinNodeHash; use crate::accumulator::util::children; use crate::accumulator::util::tree_rows; @@ -338,6 +345,7 @@ mod tests { super::get_proof_positions(&sorted, num_leaves, num_rows) ); } + #[test] fn test_is_sibling() { assert!(super::is_sibling(0, 1)); @@ -345,6 +353,7 @@ mod tests { assert!(!super::is_sibling(1, 2)); assert!(!super::is_sibling(2, 1)); } + #[test] fn test_root_position() { let pos = super::root_position(5, 2, 3); @@ -353,10 +362,12 @@ mod tests { let pos = super::root_position(5, 0, 3); assert_eq!(pos, 4); } + #[test] fn test_is_right_sibling() { assert!(super::is_right_sibling(0, 1)); } + #[test] fn test_roots_to_destroy() { let roots = [ @@ -367,13 +378,14 @@ mod tests { ]; let roots = roots .iter() - .map(|hash| NodeHash::from_str(hash).unwrap()) + .map(|hash| BitcoinNodeHash::from_str(hash).unwrap()) .collect::>(); let deleted = roots_to_destroy(1, 15, &roots); assert_eq!(deleted, vec![22, 28]) } + #[test] fn test_remove_bit() { // This should remove just one bit from the final number @@ -388,6 +400,7 @@ mod tests { let res = super::remove_bit(14, 1); assert_eq!(res, 6); } + #[test] fn test_detwin() { // 14 @@ -413,12 +426,14 @@ mod tests { assert_eq!(super::tree_rows(12), 4); assert_eq!(super::tree_rows(255), 8); } + fn row_offset(row: u8, forest_rows: u8) -> u64 { // 2 << forestRows is 2 more than the max position // to get the correct offset for a given row, // subtract (2 << `row complement of forestRows`) from (2 << forestRows) (2 << forest_rows) - (2 << (forest_rows - row)) } + #[test] fn test_detect_row() { for forest_rows in 1..63 { @@ -437,8 +452,8 @@ mod tests { } } } - #[test] + #[test] fn test_get_proof_positions() { let targets: Vec = vec![4, 5, 7, 8]; let num_leaves = 8; @@ -447,11 +462,13 @@ mod tests { assert_eq!(vec![6, 9], targets); } + #[test] fn test_is_root_position() { let h = super::is_root_position(14, 8, 3); assert!(h); } + #[test] fn test_children_pos() { assert_eq!(children(4, 2), 0); @@ -459,6 +476,7 @@ mod tests { assert_eq!(children(50, 5), 36); assert_eq!(children(44, 5), 24); } + #[test] fn test_calc_next_pos() { let res = super::calc_next_pos(0, 1, 3);