Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize merkle #90

Open
wants to merge 20 commits into
base: stylus
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions arbitrator/prover/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ pub fn get_impl(module: &str, name: &str) -> Result<(Function, bool)> {
};

let debug = module == "console";
Function::new(&[], append, hostio.ty(), &[]).map(|x| (x, debug))
Function::new(&[], append, hostio.ty()).map(|x| (x, debug))
}

/// Adds internal functions to a module.
Expand Down Expand Up @@ -458,7 +458,6 @@ lazy_static! {
0, // impls don't use other internals
),
ty.clone(),
&[] // impls don't make calls
);
func.expect("failed to create bulk memory func")
})
Expand Down
152 changes: 104 additions & 48 deletions arbitrator/prover/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use sha3::Keccak256;
use smallvec::SmallVec;
use std::{
borrow::Cow,
cmp::min,
convert::{TryFrom, TryInto},
fmt::{self, Display},
fs::File,
Expand All @@ -46,14 +47,6 @@ use wasmparser::{DataKind, ElementItem, ElementKind, Operator, TableType};
#[cfg(feature = "native")]
use rayon::prelude::*;

fn hash_call_indirect_data(table: u32, ty: &FunctionType) -> Bytes32 {
let mut h = Keccak256::new();
h.update("Call indirect:");
h.update((table as u64).to_be_bytes());
h.update(ty.hash());
h.finalize().into()
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InboxIdentifier {
Sequencer = 0,
Expand All @@ -73,53 +66,83 @@ pub struct Function {
code: Vec<Instruction>,
ty: FunctionType,
#[serde(skip)]
code_merkle: Merkle,
opcode_merkle: Merkle,
#[serde(skip)]
argument_data_merkle: Merkle,
local_types: Vec<ArbValueType>,
#[serde(skip)]
empty_locals_hash: Bytes32,
}

fn code_to_opcode_hash(code: &Vec<Instruction>, opcode_idx: usize) -> Bytes32 {
let seg = opcode_idx / 16;
let seg_start = seg * 16;
let seg_end = min(seg_start + 16, code.len());
let mut b = [0u8; 32];
for i in 0..(seg_end - seg_start) {
b[i * 2..i * 2 + 2]
.copy_from_slice(code[seg_start + i].opcode.repr().to_be_bytes().as_slice())
}
Bytes32(b)
}

fn code_to_opcode_hashes(code: &Vec<Instruction>) -> Vec<Bytes32> {
#[cfg(feature = "native")]
let iter = (0..(code.len() + 15) / 16).into_par_iter();

#[cfg(not(feature = "native"))]
let iter = (0..(code.len() + 15) / 16).into_iter();

iter.map(|i| code_to_opcode_hash(code, i * 16)).collect()
}

fn code_to_argdata_hash(code: &Vec<Instruction>, opcode_idx: usize) -> Bytes32 {
let seg = opcode_idx / 4;
let seg_start = seg * 4;
let seg_end = min(seg_start + 4, code.len());
let mut b = [0u8; 32];
for i in 0..(seg_end - seg_start) {
b[i * 8..i * 8 + 8]
.copy_from_slice(code[seg_start + i].argument_data.to_be_bytes().as_slice())
}
Bytes32(b)
}

fn code_to_argdata_hashes(code: &Vec<Instruction>) -> Vec<Bytes32> {
#[cfg(feature = "native")]
let iter = (0..(code.len() + 3) / 4).into_par_iter();

#[cfg(not(feature = "native"))]
let iter = (0..(code.len() + 3) / 4).into_iter();

iter.map(|i| code_to_argdata_hash(code, i * 4)).collect()
}

impl Function {
pub fn new<F: FnOnce(&mut Vec<Instruction>) -> Result<()>>(
locals: &[Local],
add_body: F,
func_ty: FunctionType,
module_types: &[FunctionType],
) -> Result<Function> {
let mut locals_with_params = func_ty.inputs.clone();
locals_with_params.extend(locals.iter().map(|x| x.value));

let mut insts = Vec::new();
let empty_local_hashes = locals_with_params
.iter()
.cloned()
.map(Value::default_of_type)
.map(Value::hash)
.collect::<Vec<_>>();
insts.push(Instruction {
opcode: Opcode::InitFrame,
argument_data: 0,
proving_argument_data: Some(Merkle::new(MerkleType::Value, empty_local_hashes).root()),
});
// Fill in parameters
for i in (0..func_ty.inputs.len()).rev() {
insts.push(Instruction {
opcode: Opcode::LocalSet,
argument_data: i as u64,
proving_argument_data: None,
});
}

add_body(&mut insts)?;
insts.push(Instruction::simple(Opcode::Return));

// Insert missing proving argument data
for inst in insts.iter_mut() {
if inst.opcode == Opcode::CallIndirect {
let (table, ty) = crate::wavm::unpack_call_indirect(inst.argument_data);
let ty = &module_types[usize::try_from(ty).unwrap()];
inst.proving_argument_data = Some(hash_call_indirect_data(table, ty));
}
}

Ok(Function::new_from_wavm(insts, func_ty, locals_with_params))
}

Expand All @@ -133,24 +156,35 @@ impl Function {
"Function instruction count doesn't fit in a u32",
);

#[cfg(feature = "native")]
let code_hashes = code.par_iter().map(|i| i.hash()).collect();

#[cfg(not(feature = "native"))]
let code_hashes = code.iter().map(|i| i.hash()).collect();
let argument_data_hashes = code_to_argdata_hashes(&code);
let opcode_hashes = code_to_opcode_hashes(&code);

Function {
code,
empty_locals_hash: Self::calc_empty_locals_hash(&local_types),
ty,
code_merkle: Merkle::new(MerkleType::Instruction, code_hashes),
opcode_merkle: Merkle::new(MerkleType::Opcode, opcode_hashes),
argument_data_merkle: Merkle::new(MerkleType::ArgumentData, argument_data_hashes),
local_types,
}
}

fn calc_empty_locals_hash(locals_with_params: &Vec<ArbValueType>) -> Bytes32 {
let empty_local_hashes = locals_with_params
.iter()
.cloned()
.map(Value::default_of_type)
.map(Value::hash)
.collect::<Vec<_>>();
Merkle::new(MerkleType::Value, empty_local_hashes).root()
}

fn hash(&self) -> Bytes32 {
let mut h = Keccak256::new();
h.update("Function:");
h.update(self.code_merkle.root());
h.update(self.opcode_merkle.root());
h.update(self.argument_data_merkle.root());
h.update(self.empty_locals_hash);
h.finalize().into()
}
}
Expand Down Expand Up @@ -273,6 +307,8 @@ struct Module {
#[serde(skip)]
funcs_merkle: Arc<Merkle>,
types: Arc<Vec<FunctionType>>,
#[serde(skip)]
types_merkle: Arc<Merkle>,
internals_offset: u32,
names: Arc<NameCustomSection>,
host_call_hooks: Arc<Vec<Option<(String, String)>>>,
Expand Down Expand Up @@ -387,7 +423,6 @@ impl Module {
)
},
func_ty.clone(),
&types,
)?);
}
code.extend(internals);
Expand Down Expand Up @@ -531,6 +566,10 @@ impl Module {
code.iter().map(|f| f.hash()).collect(),
)),
funcs: Arc::new(code),
types_merkle: Arc::new(Merkle::new(
MerkleType::FunctionType,
types.iter().map(FunctionType::hash).collect(),
)),
types: Arc::new(types.to_owned()),
internals_offset,
names: Arc::new(bin.names.to_owned()),
Expand Down Expand Up @@ -566,6 +605,7 @@ impl Module {
h.update(self.memory.hash());
h.update(self.tables_merkle.root());
h.update(self.funcs_merkle.root());
h.update(self.types_merkle.root());
h.update(self.internals_offset.to_be_bytes());
h.finalize().into()
}
Expand All @@ -587,6 +627,7 @@ impl Module {

data.extend(self.tables_merkle.root());
data.extend(self.funcs_merkle.root());
data.extend(self.types_merkle.root());

data.extend(self.internals_offset.to_be_bytes());

Expand Down Expand Up @@ -1325,7 +1366,6 @@ impl Machine {
Ok(())
},
FunctionType::default(),
&entrypoint_types,
)?];
let entrypoint = Module {
globals: Vec::new(),
Expand All @@ -1337,6 +1377,10 @@ impl Machine {
entrypoint_funcs.iter().map(Function::hash).collect(),
)),
funcs: Arc::new(entrypoint_funcs),
types_merkle: Arc::new(Merkle::new(
MerkleType::FunctionType,
entrypoint_types.iter().map(FunctionType::hash).collect(),
)),
types: Arc::new(entrypoint_types),
names: Arc::new(entrypoint_names),
internals_offset: 0,
Expand Down Expand Up @@ -1427,18 +1471,21 @@ impl Machine {
let funcs =
Arc::get_mut(&mut module.funcs).expect("Multiple copies of module functions");
for func in funcs.iter_mut() {
#[cfg(feature = "native")]
let code_hashes = func.code.par_iter().map(|i| i.hash()).collect();

#[cfg(not(feature = "native"))]
let code_hashes = func.code.iter().map(|i| i.hash()).collect();
let opcode_hashes = code_to_opcode_hashes(&func.code);
let argdata_hashes = code_to_argdata_hashes(&func.code);

func.code_merkle = Merkle::new(MerkleType::Instruction, code_hashes);
func.opcode_merkle = Merkle::new(MerkleType::Opcode, opcode_hashes);
func.argument_data_merkle = Merkle::new(MerkleType::ArgumentData, argdata_hashes);
func.empty_locals_hash = Function::calc_empty_locals_hash(&func.local_types)
}
module.funcs_merkle = Arc::new(Merkle::new(
MerkleType::Function,
module.funcs.iter().map(Function::hash).collect(),
));
module.types_merkle = Arc::new(Merkle::new(
MerkleType::FunctionType,
module.types.iter().map(FunctionType::hash).collect(),
))
}
let mut mach = Machine {
status: MachineStatus::Running,
Expand Down Expand Up @@ -2640,11 +2687,17 @@ impl Machine {
// Begin next instruction proof

let func = &module.funcs[self.pc.func()];
out!(func.code[self.pc.inst()].serialize_for_proof());
out!(code_to_opcode_hash(&func.code, self.pc.inst()));
out!(func
.code_merkle
.prove(self.pc.inst())
.opcode_merkle
.prove(self.pc.inst() / 16)
.expect("Failed to prove against code merkle"));
out!(code_to_argdata_hash(&func.code, self.pc.inst()));
out!(func
.argument_data_merkle
.prove(self.pc.inst() / 4)
.expect("Failed to prove against argument data merkle"));
out!(func.empty_locals_hash);
out!(module
.funcs_merkle
.prove(self.pc.func())
Expand Down Expand Up @@ -2729,11 +2782,14 @@ impl Machine {
Some(Value::I32(i)) => *i,
x => fail!("top of stack before call_indirect is {x:?}"),
};
let ty = &module.types[usize::try_from(ty).unwrap()];
out!((table as u64).to_be_bytes());
out!(ty.hash());
let table_usize = usize::try_from(table).unwrap();
let type_usize = usize::try_from(ty).unwrap();
let table = &module.tables[table_usize];
out!(module.types[type_usize].hash());
out!(module
.types_merkle
.prove(type_usize)
.expect("failed to prove types merkle"));
out!(table
.serialize_for_proof()
.expect("failed to serialize table"));
Expand Down
16 changes: 8 additions & 8 deletions arbitrator/prover/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
use arbutil::Bytes32;
use digest::Digest;
use eyre::{bail, ErrReport, Result};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use sha3::Keccak256;
use std::{borrow::Cow, convert::TryFrom};
Expand Down Expand Up @@ -59,6 +60,10 @@ fn hash_leaf(bytes: [u8; Memory::LEAF_SIZE]) -> Bytes32 {
h.finalize().into()
}

lazy_static! {
pub static ref EMPTY_MEM_HASH: Bytes32 = hash_leaf([0u8; 32]);
}

fn round_up_to_power_of_two(mut input: usize) -> usize {
if input == 0 {
return 1;
Expand All @@ -84,7 +89,7 @@ impl Memory {
pub const PAGE_SIZE: u64 = 65536;
/// The number of layers in the memory merkle tree
/// 1 + log2(2^32 / LEAF_SIZE) = 1 + log2(2^(32 - log2(LEAF_SIZE))) = 1 + 32 - 5
const MEMORY_LAYERS: usize = 1 + 32 - 5;
pub const MEMORY_LAYERS: usize = 1 + 32 - 5;

pub fn new(size: usize, max_size: u64) -> Memory {
Memory {
Expand Down Expand Up @@ -119,15 +124,10 @@ impl Memory {
})
.collect();
if leaf_hashes.len() < leaves {
let empty_hash = hash_leaf([0u8; 32]);
let empty_hash = *EMPTY_MEM_HASH;
leaf_hashes.resize(leaves, empty_hash);
}
Cow::Owned(Merkle::new_advanced(
MerkleType::Memory,
leaf_hashes,
hash_leaf([0u8; 32]),
Self::MEMORY_LAYERS,
))
Cow::Owned(Merkle::new(MerkleType::Memory, leaf_hashes))
}

pub fn get_leaf_data(&self, leaf_idx: usize) -> [u8; Self::LEAF_SIZE] {
Expand Down
Loading