Skip to content

Commit

Permalink
feat: make Merkle backends stateless (#707)
Browse files Browse the repository at this point in the history
Change the associated methods for `IsMerkleBackend` to associated
functions. Make the `Data` and `Node` associated types `Sync + Send`.
This makes it easier to later add parallelism to Merkle tree
construction.
  • Loading branch information
Oppen authored Dec 5, 2023
1 parent 96f4600 commit 3724c1c
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 53 deletions.
10 changes: 5 additions & 5 deletions crypto/src/merkle_tree/backends/field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ impl<F, D: Digest, const NUM_BYTES: usize> IsMerkleTreeBackend
for FieldElementBackend<F, D, NUM_BYTES>
where
F: IsField,
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
[u8; NUM_BYTES]: From<GenericArray<u8, <D as OutputSizeUser>::OutputSize>>,
{
type Node = [u8; NUM_BYTES];
type Data = FieldElement<F>;

fn hash_data(&self, input: &FieldElement<F>) -> [u8; NUM_BYTES] {
fn hash_data(input: &FieldElement<F>) -> [u8; NUM_BYTES] {
let mut hasher = D::new();
hasher.update(input.serialize());
hasher.finalize().into()
}

fn hash_new_parent(&self, left: &[u8; NUM_BYTES], right: &[u8; NUM_BYTES]) -> [u8; NUM_BYTES] {
fn hash_new_parent(left: &[u8; NUM_BYTES], right: &[u8; NUM_BYTES]) -> [u8; NUM_BYTES] {
let mut hasher = D::new();
hasher.update(left);
hasher.update(right);
Expand All @@ -57,16 +57,16 @@ pub struct TreePoseidon<P: Poseidon + Default> {
impl<P> IsMerkleTreeBackend for TreePoseidon<P>
where
P: Poseidon + Default,
FieldElement<P::F>: Sync + Send,
{
type Node = FieldElement<P::F>;
type Data = FieldElement<P::F>;

fn hash_data(&self, input: &FieldElement<P::F>) -> FieldElement<P::F> {
fn hash_data(input: &FieldElement<P::F>) -> FieldElement<P::F> {
P::hash_single(input)
}

fn hash_new_parent(
&self,
left: &FieldElement<P::F>,
right: &FieldElement<P::F>,
) -> FieldElement<P::F> {
Expand Down
10 changes: 6 additions & 4 deletions crypto/src/merkle_tree/backends/field_element_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ where
F: IsField,
FieldElement<F>: Serializable,
[u8; NUM_BYTES]: From<GenericArray<u8, <D as OutputSizeUser>::OutputSize>>,
Vec<FieldElement<F>>: Sync + Send,
{
type Node = [u8; NUM_BYTES];
type Data = Vec<FieldElement<F>>;

fn hash_data(&self, input: &Vec<FieldElement<F>>) -> [u8; NUM_BYTES] {
fn hash_data(input: &Vec<FieldElement<F>>) -> [u8; NUM_BYTES] {
let mut hasher = D::new();
for element in input.iter() {
hasher.update(element.serialize());
Expand All @@ -46,7 +47,7 @@ where
result_hash
}

fn hash_new_parent(&self, left: &[u8; NUM_BYTES], right: &[u8; NUM_BYTES]) -> [u8; NUM_BYTES] {
fn hash_new_parent(left: &[u8; NUM_BYTES], right: &[u8; NUM_BYTES]) -> [u8; NUM_BYTES] {
let mut hasher = D::new();
hasher.update(left);
hasher.update(right);
Expand All @@ -64,16 +65,17 @@ pub struct BatchPoseidonTree<P: Poseidon + Default> {
impl<P> IsMerkleTreeBackend for BatchPoseidonTree<P>
where
P: Poseidon + Default,
Vec<FieldElement<P::F>>: Sync + Send,
FieldElement<P::F>: Sync + Send,
{
type Node = FieldElement<P::F>;
type Data = Vec<FieldElement<P::F>>;

fn hash_data(&self, input: &Vec<FieldElement<P::F>>) -> FieldElement<P::F> {
fn hash_data(input: &Vec<FieldElement<P::F>>) -> FieldElement<P::F> {
P::hash_many(input)
}

fn hash_new_parent(
&self,
left: &FieldElement<P::F>,
right: &FieldElement<P::F>,
) -> FieldElement<P::F> {
Expand Down
5 changes: 2 additions & 3 deletions crypto/src/merkle_tree/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ where
B: IsMerkleTreeBackend,
{
pub fn build(unhashed_leaves: &[B::Data]) -> Self {
let hasher = B::default();
let mut hashed_leaves: Vec<B::Node> = hasher.hash_leaves(unhashed_leaves);
let mut hashed_leaves: Vec<B::Node> = B::hash_leaves(unhashed_leaves);

//The leaf must be a power of 2 set
hashed_leaves = complete_until_power_of_two(&mut hashed_leaves);
Expand All @@ -26,7 +25,7 @@ where
inner_nodes.extend(hashed_leaves);

//Build the inner nodes of the tree
build(&mut inner_nodes, ROOT, &hasher);
build::<B>(&mut inner_nodes, ROOT);

MerkleTree {
root: inner_nodes[ROOT].clone(),
Expand Down
7 changes: 3 additions & 4 deletions crypto/src/merkle_tree/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@ impl<T: PartialEq + Eq> Proof<T> {
where
B: IsMerkleTreeBackend<Node = T>,
{
let hasher = B::default();
let mut hashed_value = hasher.hash_data(value);
let mut hashed_value = B::hash_data(value);

for sibling_node in self.merkle_path.iter() {
if index % 2 == 0 {
hashed_value = hasher.hash_new_parent(&hashed_value, sibling_node);
hashed_value = B::hash_new_parent(&hashed_value, sibling_node);
} else {
hashed_value = hasher.hash_new_parent(sibling_node, &hashed_value);
hashed_value = B::hash_new_parent(sibling_node, &hashed_value);
}

index >>= 1;
Expand Down
9 changes: 6 additions & 3 deletions crypto/src/merkle_tree/test_merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ impl<F: IsField> Default for TestBackend<F> {
}
}

impl<F: IsField> IsMerkleTreeBackend for TestBackend<F> {
impl<F: IsField> IsMerkleTreeBackend for TestBackend<F>
where
FieldElement<F>: Sync + Send,
{
type Node = FieldElement<F>;
type Data = FieldElement<F>;

fn hash_data(&self, input: &Self::Data) -> Self::Node {
fn hash_data(input: &Self::Data) -> Self::Node {
input + input
}

fn hash_new_parent(&self, left: &Self::Node, right: &Self::Node) -> Self::Node {
fn hash_new_parent(left: &Self::Node, right: &Self::Node) -> Self::Node {
left + right
}
}
12 changes: 6 additions & 6 deletions crypto/src/merkle_tree/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
/// tree is built from. It also defines the `Node` type and the hash function
/// used to build parent nodes from children nodes.
pub trait IsMerkleTreeBackend: Default {
type Node: PartialEq + Eq + Clone;
type Data;
type Node: PartialEq + Eq + Clone + Sync + Send;
type Data: Sync + Send;

/// This function takes a single variable `Data` and converts it to a node.
fn hash_data(&self, leaf: &Self::Data) -> Self::Node;
fn hash_data(leaf: &Self::Data) -> Self::Node;

/// This function takes the list of data from which the Merkle
/// tree will be built from and converts it to a list of leaf nodes.
fn hash_leaves(&self, unhashed_leaves: &[Self::Data]) -> Vec<Self::Node> {
fn hash_leaves(unhashed_leaves: &[Self::Data]) -> Vec<Self::Node> {
unhashed_leaves
.iter()
.map(|leaf| self.hash_data(leaf))
.map(|leaf| Self::hash_data(leaf))
.collect()
}

/// This function takes to children nodes and builds a new parent node.
/// It will be used in the construction of the Merkle tree.
fn hash_new_parent(&self, child_1: &Self::Node, child_2: &Self::Node) -> Self::Node;
fn hash_new_parent(child_1: &Self::Node, child_2: &Self::Node) -> Self::Node;
}
15 changes: 6 additions & 9 deletions crypto/src/merkle_tree/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub fn is_power_of_two(x: usize) -> bool {
(x != 0) && ((x & (x - 1)) == 0)
}

pub fn build<B: IsMerkleTreeBackend>(nodes: &mut Vec<B::Node>, parent_index: usize, hasher: &B)
pub fn build<B: IsMerkleTreeBackend>(nodes: &mut Vec<B::Node>, parent_index: usize)
where
B::Node: Clone,
{
Expand All @@ -39,11 +39,10 @@ where
let left_child_index = left_child_index(parent_index);
let right_child_index = right_child_index(parent_index);

build(nodes, left_child_index, hasher);
build(nodes, right_child_index, hasher);
build::<B>(nodes, left_child_index);
build::<B>(nodes, right_child_index);

nodes[parent_index] =
hasher.hash_new_parent(&nodes[left_child_index], &nodes[right_child_index]);
nodes[parent_index] = B::hash_new_parent(&nodes[left_child_index], &nodes[right_child_index]);
}

pub fn is_leaf(lenght: usize, node_index: usize) -> bool {
Expand Down Expand Up @@ -73,8 +72,7 @@ mod tests {
// expected |2|4|6|8|
fn hash_leaves_from_a_list_of_field_elemnts() {
let values: Vec<FE> = (1..5).map(FE::new).collect();
let hasher = TestBackend::default();
let hashed_leaves = hasher.hash_leaves(&values);
let hashed_leaves = TestBackend::hash_leaves(&values);
let list_of_nodes = &[FE::new(2), FE::new(4), FE::new(6), FE::new(8)];
for (leaf, expected_leaf) in hashed_leaves.iter().zip(list_of_nodes) {
assert_eq!(leaf, expected_leaf);
Expand Down Expand Up @@ -105,8 +103,7 @@ mod tests {
let mut nodes = vec![FE::zero(); leaves.len() - 1];
nodes.extend(leaves);

let hasher = TestBackend::default();
build::<TestBackend<U64PF>>(&mut nodes, ROOT, &hasher);
build::<TestBackend<U64PF>>(&mut nodes, ROOT);
assert_eq!(nodes[ROOT], FE::new(10));
}
}
6 changes: 3 additions & 3 deletions provers/stark/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub fn commit_phase<F: IsFFTField>(
Vec<FriLayer<F, BatchedMerkleTreeBackend<F>>>,
)
where
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
{
let mut domain_size = domain_size;

Expand Down Expand Up @@ -77,7 +77,7 @@ pub fn query_phase<F: IsFFTField>(
iotas: &[usize],
) -> Vec<FriDecommitment<F>>
where
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
{
if !fri_layers.is_empty() {
let query_list = iotas
Expand Down Expand Up @@ -117,7 +117,7 @@ pub fn new_fri_layer<F: IsFFTField>(
) -> crate::fri::fri_commitment::FriLayer<F, BatchedMerkleTreeBackend<F>>
where
F: IsFFTField,
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
{
let mut evaluation = poly
.evaluate_offset_fft(1, Some(domain_size), coset_offset)
Expand Down
16 changes: 8 additions & 8 deletions provers/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub struct Round1<F, A>
where
F: IsFFTField,
A: AIR<Field = F>,
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
{
pub(crate) trace_polys: Vec<Polynomial<FieldElement<F>>>,
pub(crate) lde_trace: TraceTable<F>,
Expand All @@ -59,7 +59,7 @@ where
pub struct Round2<F>
where
F: IsFFTField,
FieldElement<F>: Serializable,
FieldElement<F>: Serializable + Sync + Send,
{
pub(crate) composition_poly_parts: Vec<Polynomial<FieldElement<F>>>,
pub(crate) lde_composition_poly_evaluations: Vec<Vec<FieldElement<F>>>,
Expand Down Expand Up @@ -105,7 +105,7 @@ pub trait IsStarkProver {
vectors: &[Vec<FieldElement<Self::Field>>],
) -> (BatchedMerkleTree<Self::Field>, Commitment)
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let tree = BatchedMerkleTree::<Self::Field>::build(vectors);
let commitment = tree.root;
Expand Down Expand Up @@ -222,7 +222,7 @@ pub trait IsStarkProver {
lde_composition_poly_parts_evaluations: &[Vec<FieldElement<Self::Field>>],
) -> (BatchedMerkleTree<Self::Field>, Commitment)
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
// TODO: Remove clones
let mut lde_composition_poly_evaluations = Vec::new();
Expand Down Expand Up @@ -310,7 +310,7 @@ pub trait IsStarkProver {
z: &FieldElement<Self::Field>,
) -> Round3<Self::Field>
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let z_power = z.pow(round_2_result.composition_poly_parts.len());

Expand Down Expand Up @@ -565,7 +565,7 @@ pub trait IsStarkProver {
index: usize,
) -> (Proof<Commitment>, Vec<FieldElement<Self::Field>>)
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let proof = composition_poly_merkle_tree
.get_proof_by_pos(index)
Expand All @@ -591,7 +591,7 @@ pub trait IsStarkProver {
index: usize,
) -> (Vec<Proof<Commitment>>, Vec<FieldElement<Self::Field>>)
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let domain_size = domain.lde_roots_of_unity_coset.len();
let lde_trace_evaluations = lde_trace
Expand Down Expand Up @@ -623,7 +623,7 @@ pub trait IsStarkProver {
DeepPolynomialOpenings<Self::Field>,
)
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let mut openings = Vec::new();
let mut openings_symmetric = Vec::new();
Expand Down
16 changes: 8 additions & 8 deletions provers/stark/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ pub trait IsStarkVerifier {
challenges: &Challenges<Self::Field, A>,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
A: AIR<Field = Self::Field>,
{
let (deep_poly_evaluations, deep_poly_evaluations_sym) =
Expand Down Expand Up @@ -367,7 +367,7 @@ pub trait IsStarkVerifier {
value: &[FieldElement<Self::Field>],
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
proof.verify::<BatchedMerkleTreeBackend<Self::Field>>(root, index, &value.to_owned())
}
Expand All @@ -382,7 +382,7 @@ pub trait IsStarkVerifier {
iota: usize,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let lde_trace_evaluations = vec![
deep_poly_openings.lde_trace_evaluations[..num_main_columns].to_vec(),
Expand Down Expand Up @@ -425,7 +425,7 @@ pub trait IsStarkVerifier {
iota: &usize,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let mut value = deep_poly_openings
.lde_composition_poly_parts_evaluation
Expand All @@ -447,7 +447,7 @@ pub trait IsStarkVerifier {
challenges: &Challenges<F, A>,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
challenges
.iotas
Expand Down Expand Up @@ -486,7 +486,7 @@ pub trait IsStarkVerifier {
iota: usize,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let evaluations = if iota % 2 == 1 {
vec![evaluation_sym.clone(), evaluation.clone()]
Expand Down Expand Up @@ -519,7 +519,7 @@ pub trait IsStarkVerifier {
deep_composition_evaluation_sym: &FieldElement<Self::Field>,
) -> bool
where
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
let fri_layers_merkle_roots = &proof.fri_layers_merkle_roots;
let evaluation_point_vec: Vec<FieldElement<Self::Field>> =
Expand Down Expand Up @@ -670,7 +670,7 @@ pub trait IsStarkVerifier {
) -> bool
where
A: AIR<Field = Self::Field>,
FieldElement<Self::Field>: Serializable,
FieldElement<Self::Field>: Serializable + Sync + Send,
{
// Verify there are enough queries
if proof.query_list.len() < proof_options.fri_number_of_queries {
Expand Down

0 comments on commit 3724c1c

Please sign in to comment.