Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize merkle #90

Open
wants to merge 20 commits into
base: stylus
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 56 additions & 17 deletions arbitrator/prover/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use sha3::Keccak256;
use smallvec::SmallVec;
use std::{
borrow::Cow,
cmp::min,
convert::{TryFrom, TryInto},
fmt::{self, Display},
fs::File,
Expand Down Expand Up @@ -73,10 +74,46 @@ pub struct Function {
code: Vec<Instruction>,
ty: FunctionType,
#[serde(skip)]
code_merkle: Merkle,
opcode_merkle: Merkle,
#[serde(skip)]
argument_data_merkle: Merkle,
local_types: Vec<ArbValueType>,
}

/// Packs the opcodes of the 16-instruction segment containing `opcode_idx`
/// into a single 32-byte value.
///
/// The instruction at position `i` within its segment contributes the two
/// big-endian bytes of `opcode.repr()` at offset `2 * i`. Slots past the end
/// of `code` (a partial final segment) stay zero. No hash function is applied
/// here; the packed bytes are returned as-is.
///
/// # Panics
/// Panics if the segment containing `opcode_idx` starts beyond `code.len()`.
fn code_to_opcode_hash(code: &[Instruction], opcode_idx: usize) -> Bytes32 {
    // Round the index down to the start of its 16-instruction segment.
    let seg_start = opcode_idx / 16 * 16;
    let seg_end = min(seg_start + 16, code.len());
    let mut packed = [0u8; 32];
    // One bounds check on the slice range, then pair each 2-byte slot with its instruction.
    for (slot, inst) in packed.chunks_exact_mut(2).zip(&code[seg_start..seg_end]) {
        slot.copy_from_slice(&inst.opcode.repr().to_be_bytes());
    }
    Bytes32(packed)
}

/// Produces one packed opcode value per 16-instruction segment of `code`.
///
/// The final segment may be partial. Each entry comes from
/// `code_to_opcode_hash` called with the first instruction index of the
/// segment. With the `native` feature the segments go through
/// `into_par_iter`; otherwise a plain range is used.
fn code_to_opcode_hashes(code: &Vec<Instruction>) -> Vec<Bytes32> {
    // Number of segments, rounding up so a partial trailing segment is included.
    let segment_count = (code.len() + 15) / 16;

    #[cfg(feature = "native")]
    let segments = (0..segment_count).into_par_iter();

    #[cfg(not(feature = "native"))]
    let segments = 0..segment_count;

    segments
        .map(|segment| code_to_opcode_hash(code, segment * 16))
        .collect()
}

#[cfg(feature = "native")]
fn code_to_argdata_hashes(code: &Vec<Instruction>) -> Vec<Bytes32> {
tsahee marked this conversation as resolved.
Show resolved Hide resolved
code.par_iter()
.map(|i| i.get_proving_argument_data())
.collect()
}

#[cfg(not(feature = "native"))]
/// Sequential variant, compiled when the `native` feature is disabled.
///
/// Returns `get_proving_argument_data()` for every instruction, in the same
/// order as `code`.
fn code_to_argdata_hashes(code: &Vec<Instruction>) -> Vec<Bytes32> {
    let mut hashes = Vec::with_capacity(code.len());
    for inst in code {
        hashes.push(inst.get_proving_argument_data());
    }
    hashes
}

impl Function {
pub fn new<F: FnOnce(&mut Vec<Instruction>) -> Result<()>>(
locals: &[Local],
Expand Down Expand Up @@ -133,24 +170,23 @@ impl Function {
"Function instruction count doesn't fit in a u32",
);

#[cfg(feature = "native")]
let code_hashes = code.par_iter().map(|i| i.hash()).collect();

#[cfg(not(feature = "native"))]
let code_hashes = code.iter().map(|i| i.hash()).collect();
let argument_data_hashes = code_to_argdata_hashes(&code);
let opcode_hashes = code_to_opcode_hashes(&code);

Function {
code,
ty,
code_merkle: Merkle::new(MerkleType::Instruction, code_hashes),
opcode_merkle: Merkle::new(MerkleType::Opcode, opcode_hashes),
argument_data_merkle: Merkle::new(MerkleType::ArgumentData, argument_data_hashes),
local_types,
}
}

fn hash(&self) -> Bytes32 {
let mut h = Keccak256::new();
h.update("Function:");
h.update(self.code_merkle.root());
h.update(self.opcode_merkle.root());
h.update(self.argument_data_merkle.root());
h.finalize().into()
}
}
Expand Down Expand Up @@ -1427,13 +1463,11 @@ impl Machine {
let funcs =
Arc::get_mut(&mut module.funcs).expect("Multiple copies of module functions");
for func in funcs.iter_mut() {
#[cfg(feature = "native")]
let code_hashes = func.code.par_iter().map(|i| i.hash()).collect();
let opcode_hashes = code_to_opcode_hashes(&func.code);
let argdata_hashes = code_to_argdata_hashes(&func.code);

#[cfg(not(feature = "native"))]
let code_hashes = func.code.iter().map(|i| i.hash()).collect();

func.code_merkle = Merkle::new(MerkleType::Instruction, code_hashes);
func.opcode_merkle = Merkle::new(MerkleType::Opcode, opcode_hashes);
func.argument_data_merkle = Merkle::new(MerkleType::ArgumentData, argdata_hashes);
}
module.funcs_merkle = Arc::new(Merkle::new(
MerkleType::Function,
Expand Down Expand Up @@ -2640,11 +2674,16 @@ impl Machine {
// Begin next instruction proof

let func = &module.funcs[self.pc.func()];
out!(func.code[self.pc.inst()].serialize_for_proof());
out!(code_to_opcode_hash(&func.code, self.pc.inst()));
out!(func
.code_merkle
.prove(self.pc.inst())
.opcode_merkle
.prove(self.pc.inst() / 16)
.expect("Failed to prove against code merkle"));
out!(func.code[self.pc.inst()].get_proving_argument_data());
out!(func
.argument_data_merkle
.prove(self.pc.inst())
.expect("Failed to prove against argument data merkle"));
out!(module
.funcs_merkle
.prove(self.pc.func())
Expand Down
16 changes: 8 additions & 8 deletions arbitrator/prover/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
use arbutil::Bytes32;
use digest::Digest;
use eyre::{bail, ErrReport, Result};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use sha3::Keccak256;
use std::{borrow::Cow, convert::TryFrom};
Expand Down Expand Up @@ -59,6 +60,10 @@ fn hash_leaf(bytes: [u8; Memory::LEAF_SIZE]) -> Bytes32 {
h.finalize().into()
}

lazy_static! {
pub static ref EMPTY_MEM_HASH: Bytes32 = hash_leaf([0u8; 32]);
}

fn round_up_to_power_of_two(mut input: usize) -> usize {
if input == 0 {
return 1;
Expand All @@ -84,7 +89,7 @@ impl Memory {
pub const PAGE_SIZE: u64 = 65536;
/// The number of layers in the memory merkle tree
/// 1 + log2(2^32 / LEAF_SIZE) = 1 + log2(2^(32 - log2(LEAF_SIZE))) = 1 + 32 - 5
const MEMORY_LAYERS: usize = 1 + 32 - 5;
pub const MEMORY_LAYERS: usize = 1 + 32 - 5;

pub fn new(size: usize, max_size: u64) -> Memory {
Memory {
Expand Down Expand Up @@ -119,15 +124,10 @@ impl Memory {
})
.collect();
if leaf_hashes.len() < leaves {
let empty_hash = hash_leaf([0u8; 32]);
let empty_hash = *EMPTY_MEM_HASH;
leaf_hashes.resize(leaves, empty_hash);
}
Cow::Owned(Merkle::new_advanced(
MerkleType::Memory,
leaf_hashes,
hash_leaf([0u8; 32]),
Self::MEMORY_LAYERS,
))
Cow::Owned(Merkle::new(MerkleType::Memory, leaf_hashes))
}

pub fn get_leaf_data(&self, leaf_idx: usize) -> [u8; Self::LEAF_SIZE] {
Expand Down
96 changes: 71 additions & 25 deletions arbitrator/prover/src/merkle.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
// Copyright 2021-2023, Offchain Labs, Inc.
// For license information, see https://github.com/nitro/blob/master/LICENSE

use crate::memory::{Memory, EMPTY_MEM_HASH};
use arbutil::Bytes32;
use digest::Digest;
use lazy_static::lazy_static;
use sha3::Keccak256;
use std::convert::TryFrom;
use std::collections::HashMap;
use std::sync::RwLock;

#[cfg(feature = "native")]
use rayon::prelude::*;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum MerkleType {
Empty,
Value,
Function,
Instruction,
Opcode,
ArgumentData,
Memory,
Table,
TableElement,
Expand All @@ -33,7 +37,8 @@ impl MerkleType {
MerkleType::Empty => panic!("Attempted to get prefix of empty merkle type"),
MerkleType::Value => "Value merkle tree:",
MerkleType::Function => "Function merkle tree:",
MerkleType::Instruction => "Instruction merkle tree:",
MerkleType::Opcode => "Opcode merkle tree:",
MerkleType::ArgumentData => "Argument data merkle tree:",
MerkleType::Memory => "Memory merkle tree:",
MerkleType::Table => "Table merkle tree:",
MerkleType::TableElement => "Table element merkle tree:",
Expand All @@ -46,7 +51,6 @@ impl MerkleType {
pub struct Merkle {
ty: MerkleType,
layers: Vec<Vec<Bytes32>>,
empty_layers: Vec<Bytes32>,
min_depth: usize,
}

Expand All @@ -58,24 +62,63 @@ fn hash_node(ty: MerkleType, a: Bytes32, b: Bytes32) -> Bytes32 {
h.finalize().into()
}

lazy_static! {
static ref EMPTY_LAYERS: RwLock<HashMap<MerkleType, RwLock<Vec<Bytes32>>>> = Default::default();
}

impl Merkle {
pub fn new(ty: MerkleType, hashes: Vec<Bytes32>) -> Merkle {
Self::new_advanced(ty, hashes, Bytes32::default(), 0)
/// Reads the cached value for `ty` at `layer` from `EMPTY_LAYERS`, taking
/// read locks only.
///
/// Returns `None` when there is no entry for `ty` or when its vector does not
/// yet reach `layer`.
///
/// # Panics
/// Panics if either lock is poisoned.
fn get_empty_readonly(ty: MerkleType, layer: usize) -> Option<Bytes32> {
    let by_type = EMPTY_LAYERS.read().unwrap();
    let cached = by_type.get(&ty)?;
    let hashes = cached.read().unwrap();
    hashes.get(layer).copied()
}

pub fn new_advanced(
ty: MerkleType,
hashes: Vec<Bytes32>,
empty_hash: Bytes32,
min_depth: usize,
) -> Merkle {
pub fn get_empty(ty: MerkleType, layer: usize) -> Bytes32 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: moving from each merkle storing its own empty_layers vec to all of them sharing the same one didn't save any time when generating proofs. I still like it, but it might be worth considering whether the added complexity is worth it.

let exists = Self::get_empty_readonly(ty, layer);
if let Some(val_exists) = exists {
return val_exists;
}
let new_val: Bytes32;
if layer == 0 {
new_val = match ty {
MerkleType::Empty => {
panic!("attempted to fetch empty-layer value from empty merkle")
}
MerkleType::Memory => *EMPTY_MEM_HASH,
_ => Bytes32::default(),
}
} else {
let prev_val = Self::get_empty(ty, layer - 1);
new_val = hash_node(ty, prev_val, prev_val);
}
let mut layers = EMPTY_LAYERS.write().unwrap();
let mut typed = layers.entry(ty).or_default().write().unwrap();
if typed.len() > layer {
assert_eq!(typed[layer], new_val);
} else if typed.len() == layer {
typed.push(new_val);
} else {
panic!("trying to compute empty merkle entries out of order")
}
return typed[layer];
}

pub fn new(ty: MerkleType, hashes: Vec<Bytes32>) -> Merkle {
if hashes.is_empty() {
return Merkle::default();
}
let min_depth = match ty {
MerkleType::Empty => panic!("attempted to fetch empty-layer value from empty merkle"),
MerkleType::Memory => Memory::MEMORY_LAYERS,
MerkleType::Opcode => 2,
MerkleType::ArgumentData => 2,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: not sure if it's necessary. I gave opcode and argument data min_depth 2 to avoid short functions producing a "root" that's not hashed at all.

_ => 0,
};
let mut layers = vec![hashes];
let mut empty_layers = vec![empty_hash];
while layers.last().unwrap().len() > 1 || layers.len() < min_depth {
let empty_layer = *empty_layers.last().unwrap();
let empty_layer = Self::get_empty(ty, layers.len() - 1);
let next_empty_layer = Self::get_empty(ty, layers.len());

#[cfg(feature = "native")]
let new_layer = layers.last().unwrap().par_chunks(2);
Expand All @@ -84,15 +127,21 @@ impl Merkle {
let new_layer = layers.last().unwrap().chunks(2);

let new_layer = new_layer
.map(|chunk| hash_node(ty, chunk[0], chunk.get(1).cloned().unwrap_or(empty_layer)))
.map(|chunk| {
let left = chunk[0];
let right = chunk.get(1).cloned().unwrap_or(empty_layer);
if left == empty_layer && right == empty_layer {
next_empty_layer
} else {
hash_node(ty, left, right)
}
})
.collect();
empty_layers.push(hash_node(ty, empty_layer, empty_layer));
layers.push(new_layer);
}
Merkle {
ty,
layers,
empty_layers,
min_depth,
}
}
Expand Down Expand Up @@ -135,7 +184,7 @@ impl Merkle {
layer
.get(counterpart)
.cloned()
.unwrap_or_else(|| self.empty_layers[layer_i]),
.unwrap_or_else(|| Self::get_empty(self.ty, layer_i)),
);
idx >>= 1;
}
Expand All @@ -147,25 +196,22 @@ impl Merkle {
pub fn push_leaf(&mut self, leaf: Bytes32) {
let mut leaves = self.layers.swap_remove(0);
leaves.push(leaf);
let empty = self.empty_layers[0];
*self = Self::new_advanced(self.ty, leaves, empty, self.min_depth);
*self = Self::new(self.ty, leaves);
}

/// Removes the rightmost leaf from the merkle
/// Currently O(n) in the number of leaves (could be log(n))
pub fn pop_leaf(&mut self) {
let mut leaves = self.layers.swap_remove(0);
leaves.pop();
let empty = self.empty_layers[0];
*self = Self::new_advanced(self.ty, leaves, empty, self.min_depth);
*self = Self::new(self.ty, leaves);
}

pub fn set(&mut self, mut idx: usize, hash: Bytes32) {
if self.layers[0][idx] == hash {
return;
}
let mut next_hash = hash;
let empty_layers = &self.empty_layers;
let layers_len = self.layers.len();
for (layer_i, layer) in self.layers.iter_mut().enumerate() {
layer[idx] = next_hash;
Expand All @@ -176,7 +222,7 @@ impl Merkle {
let counterpart = layer
.get(idx ^ 1)
.cloned()
.unwrap_or_else(|| empty_layers[layer_i]);
.unwrap_or_else(|| Self::get_empty(self.ty, layer_i));
if idx % 2 == 0 {
next_hash = hash_node(self.ty, next_hash, counterpart);
} else {
Expand Down
Loading