diff --git a/Cargo.toml b/Cargo.toml index 053af461d..8b9522144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ scarb = ["build-cli", "dep:scarb-ui", "dep:scarb-metadata", "dep:serde_json"] with-debug-utils = [] with-runtime = ["dep:cairo-native-runtime"] with-serde = ["dep:serde"] +with-cheatcode = [] [dependencies] bumpalo = "3.16.0" diff --git a/examples/starknet.rs b/examples/starknet.rs index 33e67aafa..b47742b89 100644 --- a/examples/starknet.rs +++ b/examples/starknet.rs @@ -10,11 +10,50 @@ use cairo_native::{ utils::find_entry_point_by_idx, }; use starknet_types_core::felt::Felt; -use std::path::Path; +use std::{ + collections::{HashMap, VecDeque}, + path::Path, +}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; -#[derive(Debug)] -struct SyscallHandler; +type Log = (Vec, Vec); +type L2ToL1Message = (Felt, Vec); + +#[derive(Debug, Default)] +struct ContractLogs { + events: VecDeque, + l2_to_l1_messages: VecDeque, +} + +#[derive(Debug, Default)] +struct TestingState { + sequencer_address: Felt, + caller_address: Felt, + contract_address: Felt, + account_contract_address: Felt, + transaction_hash: Felt, + nonce: Felt, + chain_id: Felt, + version: Felt, + max_fee: u64, + block_number: u64, + block_timestamp: u64, + signature: Vec, + logs: HashMap, +} + +#[derive(Debug, Default)] +struct SyscallHandler { + testing_state: TestingState, +} + +impl SyscallHandler { + pub fn new() -> Self { + Self { + testing_state: TestingState::default(), + } + } +} impl StarknetSyscallHandler for SyscallHandler { fn get_block_hash(&mut self, block_number: u64, _gas: &mut u128) -> SyscallResult { @@ -265,6 +304,98 @@ impl StarknetSyscallHandler for SyscallHandler { ) -> SyscallResult<(U256, U256)> { unimplemented!() } + + #[cfg(feature = "with-cheatcode")] + fn cheatcode(&mut self, selector: Felt, input: &[Felt]) -> Vec { + let selector_bytes = selector.to_bytes_be(); + + let selector = match std::str::from_utf8(&selector_bytes) { + Ok(selector) => selector.trim_start_matches('\0'), + Err(_) => return Vec::new(), + }; + + match selector { + "set_sequencer_address" => { + self.testing_state.sequencer_address = input[0]; + vec![] + } + "set_caller_address" => { + self.testing_state.caller_address = input[0]; + vec![] + } + "set_contract_address" => { + self.testing_state.contract_address = input[0]; + vec![] + } + "set_account_contract_address" => { + self.testing_state.account_contract_address = input[0]; + vec![] + } + "set_transaction_hash" => { + self.testing_state.transaction_hash = input[0]; + vec![] + } + "set_nonce" => { + self.testing_state.nonce = input[0]; + vec![] + } + "set_version" => { + self.testing_state.version = input[0]; + vec![] + } + "set_chain_id" => { + self.testing_state.chain_id = input[0]; + vec![] + } + "set_max_fee" => { + let max_fee = input[0].to_biguint().try_into().unwrap(); + self.testing_state.max_fee = max_fee; + vec![] + } + "set_block_number" => { + let block_number = input[0].to_biguint().try_into().unwrap(); + self.testing_state.block_number = block_number; + vec![] + } + "set_block_timestamp" => { + let block_timestamp = input[0].to_biguint().try_into().unwrap(); + self.testing_state.block_timestamp = block_timestamp; + vec![] + } + "set_signature" => { + self.testing_state.signature = input.to_vec(); + vec![] + } + "pop_log" => self + .testing_state + .logs + .get_mut(&input[0]) + .and_then(|logs| logs.events.pop_front()) + .map(|mut log| { + let mut serialized_log = Vec::new(); + serialized_log.push(log.0.len().into()); + serialized_log.append(&mut log.0); + serialized_log.push(log.1.len().into()); + serialized_log.append(&mut log.1); + serialized_log + }) + .unwrap_or_default(), + "pop_l2_to_l1_message" => self + .testing_state + .logs + .get_mut(&input[0]) + .and_then(|logs| logs.l2_to_l1_messages.pop_front()) + .map(|mut log| { + let mut serialized_log = Vec::new(); + serialized_log.push(log.0); + serialized_log.push(log.1.len().into()); + serialized_log.append(&mut log.1); + serialized_log + }) + .unwrap_or_default(), + _ => vec![], + } + } } fn main() { @@ -305,7 +436,7 @@ fn main() { let native_executor = JitNativeExecutor::from_native_module(native_program, Default::default()); let result = native_executor - .invoke_contract_dynamic(fn_id, &[Felt::ONE], Some(u128::MAX), SyscallHandler) + .invoke_contract_dynamic(fn_id, &[Felt::ONE], Some(u128::MAX), SyscallHandler::new()) .expect("failed to execute the given contract"); println!(); diff --git a/src/executor.rs b/src/executor.rs index bd1691b5a..0fb19883b 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -151,7 +151,6 @@ fn invoke_dynamic( mut syscall_handler: Option, ) -> Result { tracing::info!("Invoking function with signature: {function_signature:?}."); - let arena = Bump::new(); let mut invoke_data = ArgumentMapper::new(&arena, registry); @@ -196,6 +195,22 @@ fn invoke_dynamic( None }; + // The Cairo compiler doesn't specify that the cheatcode syscall needs the syscall handler, + // so we must always allocate it in case it needs it, regardless of whether it's passed + // as an argument to the entry point or not. + let mut syscall_handler = syscall_handler + .as_mut() + .map(|syscall_handler| StarknetSyscallHandlerCallbacks::new(syscall_handler)); + // We only care for the previous syscall handler if we actually modify it + #[cfg(feature = "with-cheatcode")] + let previous_syscall_handler = syscall_handler.as_mut().map(|syscall_handler| { + let previous_syscall_handler = crate::starknet::SYSCALL_HANDLER_VTABLE.get(); + let syscall_handler_ptr = std::ptr::addr_of!(*syscall_handler) as *mut (); + crate::starknet::SYSCALL_HANDLER_VTABLE.set(syscall_handler_ptr); + + previous_syscall_handler + }); + // Generate argument list. let mut iter = args.iter(); for type_id in function_signature.param_types.iter().filter(|id| { @@ -209,19 +224,14 @@ fn invoke_dynamic( &[gas as u64, (gas >> 64) as u64], ), CoreTypeConcrete::StarkNet(StarkNetTypeConcrete::System(_)) => { - match syscall_handler.as_mut() { - Some(syscall_handler) => { - let syscall_handler = - arena.alloc(StarknetSyscallHandlerCallbacks::new(syscall_handler)); - invoke_data.push_aligned( - get_integer_layout(64).align(), - &[syscall_handler as *mut _ as u64], - ) - } - None => { - panic!("Syscall handler is required"); - } - } + let syscall_handler = syscall_handler + .as_mut() + .expect("syscall handler is required"); + + invoke_data.push_aligned( + get_integer_layout(64).align(), + &[syscall_handler as *mut _ as u64], + ); } type_info => invoke_data .push( @@ -252,6 +262,13 @@ fn invoke_dynamic( ); } + // If the syscall handler was changed, then reset the previous one. + // It's only necessary to restore the pointer if it's been modified i.e. if previous_syscall_handler is Some(...) + #[cfg(feature = "with-cheatcode")] + if let Some(previous_syscall_handler) = previous_syscall_handler { + crate::starknet::SYSCALL_HANDLER_VTABLE.set(previous_syscall_handler); + } + // Parse final gas. unsafe fn read_value(ptr: &mut NonNull<()>) -> &T { let align_offset = ptr diff --git a/src/libfuncs/starknet.rs b/src/libfuncs/starknet.rs index 12c758071..7aacb39fc 100644 --- a/src/libfuncs/starknet.rs +++ b/src/libfuncs/starknet.rs @@ -15,7 +15,7 @@ use cairo_lang_sierra::{ consts::SignatureAndConstConcreteLibfunc, core::{CoreLibfunc, CoreType}, lib_func::SignatureOnlyConcreteLibfunc, - starknet::StarkNetConcreteLibfunc, + starknet::{testing::TestingConcreteLibfunc, StarkNetConcreteLibfunc}, ConcreteLibfunc, }, program_registry::ProgramRegistry, @@ -40,6 +40,9 @@ use std::alloc::Layout; mod secp256; +#[cfg(feature = "with-cheatcode")] +mod testing; + /// Select and call the correct libfunc builder function from the selector. pub fn build<'ctx, 'this>( context: &'ctx Context, @@ -138,7 +141,14 @@ pub fn build<'ctx, 'this>( StarkNetConcreteLibfunc::Secp256(selector) => self::secp256::build( context, registry, entry, location, helper, metadata, selector, ), - StarkNetConcreteLibfunc::Testing(_) => todo!("implement starknet testing libfunc"), + #[cfg(feature = "with-cheatcode")] + StarkNetConcreteLibfunc::Testing(TestingConcreteLibfunc::Cheatcode(info)) => { + self::testing::build(context, registry, entry, location, helper, metadata, info) + } + #[cfg(not(feature = "with-cheatcode"))] + StarkNetConcreteLibfunc::Testing(TestingConcreteLibfunc::Cheatcode(_)) => { + unimplemented!("feature 'with-cheatcode' is required to compile with cheatcode syscall") + } } } diff --git a/src/libfuncs/starknet/testing.rs b/src/libfuncs/starknet/testing.rs new file mode 100644 index 000000000..cd06788fc --- /dev/null +++ b/src/libfuncs/starknet/testing.rs @@ -0,0 +1,124 @@ +use cairo_lang_sierra::{ + extensions::{ + core::{CoreLibfunc, CoreType}, + starknet::testing::CheatcodeConcreteLibfunc, + ConcreteLibfunc, + }, + program_registry::ProgramRegistry, +}; +use melior::{ + dialect::llvm::{self, alloca, AllocaOptions, LoadStoreOptions}, + ir::{ + attribute::{IntegerAttribute, TypeAttribute}, + r#type::IntegerType, + Block, Location, + }, + Context, +}; + +use crate::{ + block_ext::BlockExt, + error::Result, + libfuncs::LibfuncHelper, + metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, + utils::ProgramRegistryExt, +}; + +pub fn build<'ctx, 'this>( + context: &'ctx Context, + registry: &ProgramRegistry, + entry: &'this Block<'ctx>, + location: Location<'ctx>, + helper: &LibfuncHelper<'ctx, 'this>, + metadata: &mut MetadataStorage, + info: &CheatcodeConcreteLibfunc, +) -> Result<()> { + // Calculate the result layout and type, based on the branch signature + let (result_type, result_layout) = registry.build_type_with_layout( + context, + helper, + registry, + metadata, + &info.branch_signatures()[0].vars[0].ty, + )?; + + // Allocate the result pointer with calculated layout and type + let result_ptr = helper + .init_block() + .append_operation(alloca( + context, + helper.init_block().const_int(context, location, 1, 64)?, + llvm::r#type::pointer(context, 0), + location, + AllocaOptions::new() + .align(Some(IntegerAttribute::new( + IntegerType::new(context, 64).into(), + result_layout.align().try_into()?, + ))) + .elem_type(Some(TypeAttribute::new(result_type))), + )) + .result(0)? + .into(); + + // Allocate and store selector. The type contains 256 bits as its interpreted as a [u8;32] from the runtime + let selector = helper + .init_block() + .const_int(context, location, info.selector.clone(), 256)?; + let selector_ptr = helper.init_block().alloca1( + context, + location, + IntegerType::new(context, 256).into(), + None, + )?; + + helper + .init_block() + .store(context, location, selector_ptr, selector, None)?; + + // Allocate and store arguments. The cairo type is a Span (the outer struct), + // which contains an Array (the inner struct) + let span_felt252_type = llvm::r#type::r#struct( + context, + &[llvm::r#type::r#struct( + context, + &[ + llvm::r#type::pointer(context, 0), + IntegerType::new(context, 32).into(), + IntegerType::new(context, 32).into(), + IntegerType::new(context, 32).into(), + ], + false, + )], + false, + ); + let args_ptr = helper + .init_block() + .alloca1(context, location, span_felt252_type, None)?; + entry.store(context, location, args_ptr, entry.argument(0)?.into(), None)?; + + // Call runtime cheatcode syscall wrapper + metadata + .get_mut::() + .expect("Runtime library not available.") + .vtable_cheatcode( + context, + helper, + entry, + location, + result_ptr, + selector_ptr, + args_ptr, + )?; + + // Load result from result ptr and branch + let result = entry.append_op_result(llvm::load( + context, + result_ptr, + result_type, + location, + LoadStoreOptions::new(), + ))?; + entry.append_operation(helper.br(0, &[result], location)); + + Ok(()) +} diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 032cacc6f..e58a57a81 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -30,6 +30,8 @@ enum RuntimeBinding { DictGasRefund, DictInsert, DictFree, + #[cfg(feature = "with-cheatcode")] + VtableCheatcode, } /// Runtime library bindings metadata. @@ -650,6 +652,60 @@ impl RuntimeBindingsMeta { location, ))) } + + /// Register if necessary, then invoke the `vtable_cheatcode()` runtime function. + /// + /// Calls the cheatcode syscall with the given arguments. + /// + /// The result is stored in `result_ptr`. + #[allow(clippy::too_many_arguments)] + #[cfg(feature = "with-cheatcode")] + pub fn vtable_cheatcode<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + result_ptr: Value<'c, 'a>, + selector_ptr: Value<'c, 'a>, + args: Value<'c, 'a>, + ) -> Result> + where + 'c: 'a, + { + if self.active_map.insert(RuntimeBinding::VtableCheatcode) { + module.body().append_operation(func::func( + context, + StringAttribute::new(context, "cairo_native__vtable_cheatcode"), + TypeAttribute::new( + FunctionType::new( + context, + &[ + llvm::r#type::pointer(context, 0), + llvm::r#type::pointer(context, 0), + llvm::r#type::pointer(context, 0), + ], + &[], + ) + .into(), + ), + Region::new(), + &[( + Identifier::new(context, "sym_visibility"), + StringAttribute::new(context, "private").into(), + )], + Location::unknown(context), + )); + } + + Ok(block.append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, "cairo_native__vtable_cheatcode"), + &[result_ptr, selector_ptr, args], + &[], + location, + ))) + } } impl Default for RuntimeBindingsMeta { diff --git a/src/starknet.rs b/src/starknet.rs index f199ea84a..ca513229d 100644 --- a/src/starknet.rs +++ b/src/starknet.rs @@ -242,57 +242,9 @@ pub trait StarknetSyscallHandler { remaining_gas: &mut u128, ) -> SyscallResult<(U256, U256)>; - // Testing syscalls. - fn pop_log(&mut self) { - unimplemented!() - } - - fn set_account_contract_address(&mut self, _contract_address: Felt) { - unimplemented!() - } - - fn set_block_number(&mut self, _block_number: u64) { - unimplemented!() - } - - fn set_block_timestamp(&mut self, _block_timestamp: u64) { - unimplemented!() - } - - fn set_caller_address(&mut self, _address: Felt) { - unimplemented!() - } - - fn set_chain_id(&mut self, _chain_id: Felt) { - unimplemented!() - } - - fn set_contract_address(&mut self, _address: Felt) { - unimplemented!() - } - - fn set_max_fee(&mut self, _max_fee: u128) { - unimplemented!() - } - - fn set_nonce(&mut self, _nonce: Felt) { - unimplemented!() - } - - fn set_sequencer_address(&mut self, _address: Felt) { - unimplemented!() - } - - fn set_signature(&mut self, _signature: &[Felt]) { - unimplemented!() - } - - fn set_transaction_hash(&mut self, _transaction_hash: Felt) { - unimplemented!() - } - - fn set_version(&mut self, _version: Felt) { - unimplemented!() + #[cfg(feature = "with-cheatcode")] + fn cheatcode(&mut self, _selector: Felt, _input: &[Felt]) -> Vec { + unimplemented!(); } } @@ -584,6 +536,13 @@ pub(crate) mod handler { nonce: Felt252Abi, } + /// A C ABI Wrapper around the StarknetSyscallHandler + /// + /// It contains pointers to functions which can be called through MLIR based on the field offset. + /// The functions convert C ABI structures to the Rust equivalent and calls the wrapped implementation. + /// + /// Unlike runtime functions, the callback table is generic to the StarknetSyscallHandler, + /// which allows the user to specify the desired implementation to use during the execution. #[repr(C)] #[derive(Debug)] pub struct StarknetSyscallHandlerCallbacks<'a, T> { @@ -741,6 +700,14 @@ pub(crate) mod handler { gas: &mut u128, p: &Secp256r1Point, ), + // testing syscalls + #[cfg(feature = "with-cheatcode")] + pub cheatcode: extern "C" fn( + result_ptr: &mut ArrayAbi, + ptr: &mut T, + selector: &Felt252Abi, + input: &ArrayAbi, + ), } impl<'a, T> StarknetSyscallHandlerCallbacks<'a, T> @@ -805,6 +772,8 @@ pub(crate) mod handler { secp256r1_mul: Self::wrap_secp256r1_mul, secp256r1_get_point_from_x: Self::wrap_secp256r1_get_point_from_x, secp256r1_get_xy: Self::wrap_secp256r1_get_xy, + #[cfg(feature = "with-cheatcode")] + cheatcode: Self::wrap_cheatcode, } } @@ -866,6 +835,34 @@ pub(crate) mod handler { }; } + #[cfg(feature = "with-cheatcode")] + extern "C" fn wrap_cheatcode( + result_ptr: &mut ArrayAbi, + ptr: &mut T, + selector: &Felt252Abi, + input: &ArrayAbi, + ) { + let input: Vec<_> = unsafe { + let since_offset = input.since as usize; + let until_offset = input.until as usize; + debug_assert!(since_offset <= until_offset); + let len = until_offset - since_offset; + std::slice::from_raw_parts(input.ptr.add(since_offset), len) + } + .iter() + .map(|x| Felt::from_bytes_le(&x.0)) + .collect(); + let selector = Felt::from_bytes_le(&selector.0); + + let result = ptr + .cheatcode(selector, &input) + .into_iter() + .map(|x| Felt252Abi(x.to_bytes_le())) + .collect::>(); + + *result_ptr = unsafe { Self::alloc_mlir_array(&result) }; + } + extern "C" fn wrap_get_execution_info( result_ptr: &mut SyscallResultAbi>, ptr: &mut T, @@ -1649,3 +1646,31 @@ pub(crate) mod handler { } } } + +#[cfg(feature = "with-cheatcode")] +thread_local!(pub static SYSCALL_HANDLER_VTABLE: std::cell::Cell<*mut ()> = const { std::cell::Cell::new(std::ptr::null_mut()) }); + +#[allow(non_snake_case)] +#[cfg(feature = "with-cheatcode")] +/// Runtime function that calls the `cheatcode` syscall +/// +/// The Cairo compiler doesn't specify that the cheatcode syscall needs the syscall handler, +/// so a pointer to `StarknetSyscallHandlerCallbacks` is stored as a `thread::LocalKey` and accesed in runtime by this function. +pub extern "C" fn cairo_native__vtable_cheatcode( + result_ptr: &mut ArrayAbi, + selector: &Felt252Abi, + input: &ArrayAbi, +) { + let ptr = SYSCALL_HANDLER_VTABLE.with(|ptr| ptr.get()); + assert!(!ptr.is_null()); + + let callbacks_ptr = ptr as *mut handler::StarknetSyscallHandlerCallbacks; + let callbacks = unsafe { callbacks_ptr.as_mut().expect("should not be null") }; + + // The `StarknetSyscallHandler` is stored as a reference in the first field of `StarknetSyscalLHandlerCallbacks`, + // so we can interpret `ptr` as a double pointer to the handler. + let handler_ptr_ptr = ptr as *mut *mut DummySyscallHandler; + let handler = unsafe { (*handler_ptr_ptr).as_mut().expect("should not be null") }; + + (callbacks.cheatcode)(result_ptr, handler, selector, input); +} diff --git a/src/utils.rs b/src/utils.rs index 790972882..3366b2fcb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -340,6 +340,14 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) { as *const fn(*const std::ffi::c_void, NonNull) -> u64 as *mut (), ); + + #[cfg(feature = "with-cheatcode")] + { + engine.register_symbol( + "cairo_native__vtable_cheatcode", + crate::starknet::cairo_native__vtable_cheatcode as *mut (), + ); + } } } diff --git a/tests/tests/starknet/mod.rs b/tests/tests/starknet/mod.rs index 78fb8af4e..8b00e252c 100644 --- a/tests/tests/starknet/mod.rs +++ b/tests/tests/starknet/mod.rs @@ -1,3 +1,5 @@ mod keccak; mod secp256; + +#[cfg(feature = "with-cheatcode")] mod syscalls; diff --git a/tests/tests/starknet/programs/syscalls.cairo b/tests/tests/starknet/programs/syscalls.cairo index 805dd3754..33284009d 100644 --- a/tests/tests/starknet/programs/syscalls.cairo +++ b/tests/tests/starknet/programs/syscalls.cairo @@ -4,6 +4,7 @@ use core::starknet::{ keccak_syscall, library_call_syscall, replace_class_syscall, send_message_to_l1_syscall, storage_address_try_from_felt252, storage_read_syscall, storage_write_syscall, SyscallResult, + testing::cheatcode, }; use core::starknet::syscalls::get_execution_info_syscall; use core::starknet::syscalls::get_execution_info_v2_syscall; @@ -55,3 +56,59 @@ fn send_message_to_l1() -> SyscallResult<()> { fn keccak() -> SyscallResult { keccak_syscall(array![].span()) } + +fn set_sequencer_address(address: felt252) -> Span { + return cheatcode::<'set_sequencer_address'>(array![address].span()); +} + +fn set_account_contract_address(address: felt252) -> Span { + return cheatcode::<'set_account_contract_address'>(array![address].span()); +} + +fn set_block_number(number: felt252) -> Span { + return cheatcode::<'set_block_number'>(array![number].span()); +} + +fn set_block_timestamp(timestamp: felt252) -> Span { + return cheatcode::<'set_block_timestamp'>(array![timestamp].span()); +} + +fn set_caller_address(address: felt252) -> Span { + return cheatcode::<'set_caller_address'>(array![address].span()); +} + +fn set_chain_id(id: felt252) -> Span { + return cheatcode::<'set_chain_id'>(array![id].span()); +} + +fn set_contract_address(address: felt252) -> Span { + return cheatcode::<'set_contract_address'>(array![address].span()); +} + +fn set_max_fee(fee: felt252) -> Span { + return cheatcode::<'set_max_fee'>(array![fee].span()); +} + +fn set_nonce(nonce: felt252) -> Span { + return cheatcode::<'set_nonce'>(array![nonce].span()); +} + +fn set_signature(signature: Array) -> Span { + return cheatcode::<'set_signature'>(signature.span()); +} + +fn set_transaction_hash(hash: felt252) -> Span { + return cheatcode::<'set_transaction_hash'>(array![hash].span()); +} + +fn set_version(version: felt252) -> Span { + return cheatcode::<'set_version'>(array![version].span()); +} + +fn pop_log(log: felt252) -> Span { + return cheatcode::<'pop_log'>(array![log].span()); +} + +fn pop_l2_to_l1_message(message: felt252) -> Span { + return cheatcode::<'pop_l2_to_l1_message'>(array![message].span()); +} diff --git a/tests/tests/starknet/syscalls.rs b/tests/tests/starknet/syscalls.rs index 8620519dd..91942585c 100644 --- a/tests/tests/starknet/syscalls.rs +++ b/tests/tests/starknet/syscalls.rs @@ -1,3 +1,8 @@ +use std::{ + collections::{HashMap, VecDeque}, + sync::{Arc, Mutex}, +}; + use crate::common::{load_cairo_path, run_native_program}; use cairo_lang_runner::SierraCasmRunner; use cairo_lang_sierra::program::Program; @@ -8,11 +13,55 @@ use cairo_native::{ }, values::JitValue, }; +use itertools::Itertools; use lazy_static::lazy_static; -use pretty_assertions_sorted::assert_eq_sorted; +use pretty_assertions_sorted::{assert_eq, assert_eq_sorted}; use starknet_types_core::felt::Felt; -struct SyscallHandler; +type Log = (Vec, Vec); +type L2ToL1Message = (Felt, Vec); + +#[derive(Debug, Default)] +struct ContractLogs { + events: VecDeque, + l2_to_l1_messages: VecDeque, +} + +#[derive(Debug, Default)] +struct TestingState { + sequencer_address: Felt, + caller_address: Felt, + contract_address: Felt, + account_contract_address: Felt, + transaction_hash: Felt, + nonce: Felt, + chain_id: Felt, + version: Felt, + max_fee: u64, + block_number: u64, + block_timestamp: u64, + signature: Vec, + logs: HashMap, +} + +struct SyscallHandler { + /// Arc Is needed to test that the valures are set correct after the execution + testing_state: Arc>, +} + +impl SyscallHandler { + fn new() -> Self { + Self { + testing_state: Arc::new(Mutex::new(TestingState::default())), + } + } + + fn with(state: Arc>) -> Self { + Self { + testing_state: state, + } + } +} impl StarknetSyscallHandler for SyscallHandler { fn get_block_hash( @@ -350,6 +399,101 @@ impl StarknetSyscallHandler for SyscallHandler { // Tested in `tests/tests/starknet/secp256.rs`. unimplemented!() } + + fn cheatcode(&mut self, selector: Felt, input: &[Felt]) -> Vec { + let selector_bytes = selector.to_bytes_be(); + + let selector = match std::str::from_utf8(&selector_bytes) { + Ok(selector) => selector.trim_start_matches('\0'), + Err(_) => return Vec::new(), + }; + + match selector { + "set_sequencer_address" => { + self.testing_state.lock().unwrap().sequencer_address = input[0]; + vec![] + } + "set_caller_address" => { + self.testing_state.lock().unwrap().caller_address = input[0]; + vec![] + } + "set_contract_address" => { + self.testing_state.lock().unwrap().contract_address = input[0]; + vec![] + } + "set_account_contract_address" => { + self.testing_state.lock().unwrap().account_contract_address = input[0]; + vec![] + } + "set_transaction_hash" => { + self.testing_state.lock().unwrap().transaction_hash = input[0]; + vec![] + } + "set_nonce" => { + self.testing_state.lock().unwrap().nonce = input[0]; + vec![] + } + "set_version" => { + self.testing_state.lock().unwrap().version = input[0]; + vec![] + } + "set_chain_id" => { + self.testing_state.lock().unwrap().chain_id = input[0]; + vec![] + } + "set_max_fee" => { + let max_fee = input[0].to_biguint().try_into().unwrap(); + self.testing_state.lock().unwrap().max_fee = max_fee; + vec![] + } + "set_block_number" => { + let block_number = input[0].to_biguint().try_into().unwrap(); + self.testing_state.lock().unwrap().block_number = block_number; + vec![] + } + "set_block_timestamp" => { + let block_timestamp = input[0].to_biguint().try_into().unwrap(); + self.testing_state.lock().unwrap().block_timestamp = block_timestamp; + vec![] + } + "set_signature" => { + self.testing_state.lock().unwrap().signature = input.to_vec(); + vec![] + } + "pop_log" => self + .testing_state + .lock() + .unwrap() + .logs + .get_mut(&input[0]) + .and_then(|logs| logs.events.pop_front()) + .map(|mut log| { + let mut serialized_log = Vec::new(); + serialized_log.push(log.0.len().into()); + serialized_log.append(&mut log.0); + serialized_log.push(log.1.len().into()); + serialized_log.append(&mut log.1); + serialized_log + }) + .unwrap_or_default(), + "pop_l2_to_l1_message" => self + .testing_state + .lock() + .unwrap() + .logs + .get_mut(&input[0]) + .and_then(|logs| logs.l2_to_l1_messages.pop_front()) + .map(|mut log| { + let mut serialized_log = Vec::new(); + serialized_log.push(log.0); + serialized_log.push(log.1.len().into()); + serialized_log.append(&mut log.1); + serialized_log + }) + .unwrap_or_default(), + _ => vec![], + } + } } lazy_static! { @@ -364,7 +508,7 @@ fn get_block_hash() { "get_block_hash", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -389,7 +533,7 @@ fn get_execution_info() { "get_execution_info", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -468,7 +612,7 @@ fn get_execution_info_v2() { "get_execution_info_v2", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -560,7 +704,7 @@ fn deploy() { "deploy", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -592,7 +736,7 @@ fn replace_class() { "replace_class", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -615,7 +759,7 @@ fn library_call() { "library_call", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -651,7 +795,7 @@ fn call_contract() { "call_contract", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -687,7 +831,7 @@ fn storage_read() { "storage_read", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -715,7 +859,7 @@ fn storage_write() { "storage_write", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -741,7 +885,7 @@ fn emit_event() { "emit_event", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -764,7 +908,7 @@ fn send_message_to_l1() { "send_message_to_l1", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -787,7 +931,7 @@ fn keccak() { "keccak", &[], Some(u128::MAX), - Some(SyscallHandler), + Some(SyscallHandler::new()), ); assert_eq_sorted!( @@ -805,3 +949,211 @@ fn keccak() { }, ); } + +#[test] +fn set_sequencer_address() { + let address = Felt::THREE; + + let state = Arc::new(Mutex::new(TestingState::default())); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "set_sequencer_address", + &[JitValue::Felt252(address)], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(vec![])], + debug_name: Some("core::array::Span::".to_string()) + } + ); + + let actual_address = state.lock().unwrap().sequencer_address; + assert_eq!(address, actual_address); +} + +#[test] +fn set_max_fee() { + let max_fee = 3; + + let state = Arc::new(Mutex::new(TestingState::default())); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "set_max_fee", + &[JitValue::Felt252(Felt::from(max_fee))], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(vec![])], + debug_name: Some("core::array::Span::".to_string()) + } + ); + + let actual_max_fee = state.lock().unwrap().max_fee; + assert_eq!(max_fee, actual_max_fee); +} + +#[test] +fn set_signature() { + let signature = vec![Felt::ONE, Felt::TWO, Felt::THREE]; + + let signature_jit = signature + .clone() + .into_iter() + .map(JitValue::Felt252) + .collect_vec(); + + let state = Arc::new(Mutex::new(TestingState::default())); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "set_signature", + &[JitValue::Array(signature_jit)], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(vec![])], + debug_name: Some("core::array::Span::".to_string()) + } + ); + + let actual_signature = state.lock().unwrap().signature.clone(); + assert_eq_sorted!(signature, actual_signature); +} + +#[test] +fn pop_log() { + let log_index = Felt::ONE; + let mut log = (vec![Felt::ONE, Felt::TWO], vec![Felt::THREE]); + + let state = Arc::new(Mutex::new(TestingState::default())); + + let logs = ContractLogs { + l2_to_l1_messages: VecDeque::new(), + events: VecDeque::from(vec![log.clone()]), + }; + + state.lock().unwrap().logs.insert(log_index, logs); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "pop_log", + &[JitValue::Felt252(log_index)], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + let mut serialized_log = Vec::new(); + serialized_log.push(log.0.len().into()); + serialized_log.append(&mut log.0); + serialized_log.push(log.1.len().into()); + serialized_log.append(&mut log.1); + + let serialized_log_jit = serialized_log + .into_iter() + .map(JitValue::Felt252) + .collect_vec(); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(serialized_log_jit)], + debug_name: Some("core::array::Span::".to_string()) + } + ); + + assert!(state + .lock() + .unwrap() + .logs + .get(&log_index) + .unwrap() + .events + .is_empty()); +} + +#[test] +fn pop_log_empty() { + let log_index = Felt::ONE; + + let state = Arc::new(Mutex::new(TestingState::default())); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "pop_log", + &[JitValue::Felt252(log_index)], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(Vec::new())], + debug_name: Some("core::array::Span::".to_string()) + } + ); +} + +#[test] +fn pop_l2_to_l1_message() { + let log_index = Felt::ONE; + let mut message = (Felt::ONE, vec![Felt::TWO, Felt::THREE]); + + let state = Arc::new(Mutex::new(TestingState::default())); + + let logs = ContractLogs { + l2_to_l1_messages: VecDeque::from(vec![message.clone()]), + events: VecDeque::new(), + }; + + state.lock().unwrap().logs.insert(log_index, logs); + + let result = run_native_program( + &SYSCALLS_PROGRAM, + "pop_l2_to_l1_message", + &[JitValue::Felt252(log_index)], + Some(u128::MAX), + Some(SyscallHandler::with(state.clone())), + ); + + let mut serialized_message = Vec::new(); + serialized_message.push(message.0); + serialized_message.push(message.1.len().into()); + serialized_message.append(&mut message.1); + + let serialized_message_jit = serialized_message + .into_iter() + .map(JitValue::Felt252) + .collect_vec(); + + assert_eq_sorted!( + result.return_value, + JitValue::Struct { + fields: vec![JitValue::Array(serialized_message_jit)], + debug_name: Some("core::array::Span::".to_string()) + } + ); + + assert!(state + .lock() + .unwrap() + .logs + .get(&log_index) + .unwrap() + .events + .is_empty()); +}