From 3f710d641deeed34223fa51231587a837526e3a3 Mon Sep 17 00:00:00 2001 From: Edgar Date: Fri, 2 Aug 2024 12:58:55 +0200 Subject: [PATCH] Fix dict snapshot (#724) * wip fix dict clone * fix dict clone * remove comment * fix dict to_jit * progress * dict snapshot take with deep cloning works --- runtime/src/lib.rs | 55 ++++-- src/metadata/runtime_bindings.rs | 62 ++++++- src/types/felt252_dict.rs | 299 ++++++++++++++++++++++++++++++- src/utils.rs | 23 ++- src/values.rs | 10 +- 5 files changed, 417 insertions(+), 32 deletions(-) diff --git a/runtime/src/lib.rs b/runtime/src/lib.rs index a129545f4..cc41b324f 100644 --- a/runtime/src/lib.rs +++ b/runtime/src/lib.rs @@ -149,6 +149,9 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation( op2.copy_from_slice(&state[2].to_bytes_be()); } +/// Felt252 type used in cairo native runtime +pub type FeltDict = (HashMap<[u8; 32], NonNull>, u64); + /// Allocates a new dictionary. Internally a rust hashmap: `HashMap<[u8; 32], NonNull<()>` /// /// # Safety @@ -157,7 +160,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void { - Box::into_raw(Box::<(HashMap<[u8; 32], NonNull>, u64)>::default()) as _ + Box::into_raw(Box::::default()) as _ } /// Frees the dictionary. @@ -167,9 +170,7 @@ pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void { /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_free( - ptr: *mut (HashMap<[u8; 32], NonNull>, u64), -) { +pub unsafe extern "C" fn cairo_native__dict_free(ptr: *mut FeltDict) { let mut map = Box::from_raw(ptr); // Free the entries manually. @@ -178,6 +179,42 @@ pub unsafe extern "C" fn cairo_native__dict_free( } } +/// Needed for the correct alignment, +/// since the key [u8; 32] in rust has 8 byte alignment but its a felt, +/// so in reality it has 16. +#[repr(C, align(16))] +pub struct DictValuesArrayAbi { + pub key: [u8; 32], + pub value: std::ptr::NonNull, +} + +/// Returns a array over the values of the dict, used for deep cloning. +/// +/// # Safety +/// +/// This function is intended to be called from MLIR, deals with pointers, and is therefore +/// definitely unsafe to use manually. +#[no_mangle] +pub unsafe extern "C" fn cairo_native__dict_values( + ptr: *mut FeltDict, + len: *mut u64, +) -> *mut DictValuesArrayAbi { + let dict: &mut FeltDict = &mut *ptr; + + let values: Vec<_> = dict + .0 + .clone() + .into_iter() + // make it ffi safe for use within MLIR. + .map(|x| DictValuesArrayAbi { + key: x.0, + value: x.1, + }) + .collect(); + *len = values.len() as u64; + values.leak::<'static>().as_mut_ptr() +} + /// Gets the value for a given key, the returned pointer is null if not found. /// Increments the access count. /// @@ -187,10 +224,10 @@ pub unsafe extern "C" fn cairo_native__dict_free( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_get( - ptr: *mut (HashMap<[u8; 32], NonNull>, u64), + ptr: *mut FeltDict, key: &[u8; 32], ) -> *mut std::ffi::c_void { - let dict: &mut (HashMap<[u8; 32], NonNull>, u64) = &mut *ptr; + let dict: &mut FeltDict = &mut *ptr; let map = &dict.0; dict.1 += 1; @@ -209,7 +246,7 @@ pub unsafe extern "C" fn cairo_native__dict_get( /// definitely unsafe to use manually. #[no_mangle] pub unsafe extern "C" fn cairo_native__dict_insert( - ptr: *mut (HashMap<[u8; 32], NonNull>, u64), + ptr: *mut FeltDict, key: &[u8; 32], value: NonNull, ) -> *mut std::ffi::c_void { @@ -230,9 +267,7 @@ pub unsafe extern "C" fn cairo_native__dict_insert( /// This function is intended to be called from MLIR, deals with pointers, and is therefore /// definitely unsafe to use manually. #[no_mangle] -pub unsafe extern "C" fn cairo_native__dict_gas_refund( - ptr: *const (HashMap<[u8; 32], NonNull>, u64), -) -> u64 { +pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> u64 { let dict = &*ptr; (dict.1 - dict.0.len() as u64) * *DICT_GAS_REFUND_PER_ACCESS } diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index e58a57a81..4a487f22f 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -17,19 +17,20 @@ use std::{collections::HashSet, marker::PhantomData}; #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] enum RuntimeBinding { - DebugPrint, Pedersen, HadesPermutation, - EcPointFromXNz, - EcPointTryNewNz, - EcStateAdd, - EcStateAddMul, EcStateTryFinalizeNz, + EcStateAddMul, + EcStateAdd, + EcPointTryNewNz, + EcPointFromXNz, DictNew, + DictInsert, DictGet, DictGasRefund, - DictInsert, DictFree, + DictValues, + DebugPrint, #[cfg(feature = "with-cheatcode")] VtableCheatcode, } @@ -502,6 +503,55 @@ impl RuntimeBindingsMeta { ))) } + /// Register if necessary, then invoke the `dict_clone()` function. + /// + /// Returns a opaque pointer as the result. + #[allow(clippy::too_many_arguments)] + pub fn dict_values<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + ptr: Value<'c, 'a>, + len_ptr: Value<'c, 'a>, + block: &'a Block<'c>, + location: Location<'c>, + ) -> Result> + where + 'c: 'a, + { + if self.active_map.insert(RuntimeBinding::DictValues) { + module.body().append_operation(func::func( + context, + StringAttribute::new(context, "cairo_native__dict_values"), + TypeAttribute::new( + FunctionType::new( + context, + &[ + llvm::r#type::pointer(context, 0), + llvm::r#type::pointer(context, 0), + ], // ptr to dict, out ptr to length + &[llvm::r#type::pointer(context, 0)], // ptr to array of struct (key, value_ptr) + ) + .into(), + ), + Region::new(), + &[( + Identifier::new(context, "sym_visibility"), + StringAttribute::new(context, "private").into(), + )], + Location::unknown(context), + )); + } + + Ok(block.append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, "cairo_native__dict_values"), + &[ptr, len_ptr], + &[llvm::r#type::pointer(context, 0)], + location, + ))) + } + /// Register if necessary, then invoke the `dict_get()` function. /// /// Gets the value for a given key, the returned pointer is null if not found. diff --git a/src/types/felt252_dict.rs b/src/types/felt252_dict.rs index a6d2c201d..e5661d0a1 100644 --- a/src/types/felt252_dict.rs +++ b/src/types/felt252_dict.rs @@ -6,8 +6,19 @@ //! used to count accesses to the dictionary. The type is interacted through the runtime functions to //! insert, get elements and increment the access counter. +use std::cell::Cell; + use super::WithSelf; -use crate::{error::Result, metadata::MetadataStorage}; +use crate::{ + block_ext::BlockExt, + error::Result, + libfuncs::LibfuncHelper, + metadata::{ + realloc_bindings::ReallocBindingsMeta, runtime_bindings::RuntimeBindingsMeta, + snapshot_clones::SnapshotClonesMeta, MetadataStorage, + }, + types::TypeBuilder, +}; use cairo_lang_sierra::{ extensions::{ core::{CoreLibfunc, CoreType}, @@ -16,8 +27,14 @@ use cairo_lang_sierra::{ program_registry::ProgramRegistry, }; use melior::{ - dialect::llvm, - ir::{Module, Type}, + dialect::{ + llvm::{self, r#type::pointer}, + ods, scf, + }, + ir::{ + attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Region, Type, + Value, + }, Context, }; @@ -28,22 +45,292 @@ pub fn build<'ctx>( context: &'ctx Context, _module: &Module<'ctx>, _registry: &ProgramRegistry, - _metadata: &mut MetadataStorage, - _info: WithSelf, + metadata: &mut MetadataStorage, + info: WithSelf, ) -> Result> { + metadata + .get_or_insert_with::(SnapshotClonesMeta::default) + .register( + info.self_ty().clone(), + snapshot_take, + InfoAndTypeConcreteType { + info: info.info.clone(), + ty: info.ty.clone(), + }, + ); + Ok(llvm::r#type::pointer(context, 0)) } +#[allow(clippy::too_many_arguments)] +fn snapshot_take<'ctx, 'this>( + context: &'ctx Context, + registry: &ProgramRegistry, + entry: &'this Block<'ctx>, + location: Location<'ctx>, + helper: &LibfuncHelper<'ctx, 'this>, + metadata: &mut MetadataStorage, + info: WithSelf, + src_value: Value<'ctx, 'this>, +) -> Result<(&'this Block<'ctx>, Value<'ctx, 'this>)> { + if metadata.get::().is_none() { + metadata.insert(ReallocBindingsMeta::new(context, helper)); + } + + let elem_snapshot_take = metadata + .get::() + .and_then(|meta| meta.wrap_invoke(&info.ty)); + + let elem_ty = registry.get_type(&info.ty)?; + let elem_layout = elem_ty.layout(registry)?; + let elem_ty = elem_ty.build(context, helper, registry, metadata, &info.ty)?; + + let location = Location::name(context, "dict_snapshot_clone", location); + + let runtime_bindings = metadata + .get_mut::() + .expect("Runtime library not available."); + + let len_ptr = helper.init_block().alloca_int(context, location, 64)?; + let u64_ty = IntegerType::new(context, 64).into(); + + let entry_values_type = llvm::r#type::r#struct( + context, + &[IntegerType::new(context, 252).into(), pointer(context, 0)], // key, value ptr + false, + ); + + // ptr to array of entry_values_type + let entries_ptr = runtime_bindings + .dict_values(context, helper, src_value, len_ptr, entry, location)? + .result(0)? + .into(); + + let array_len = entry.load(context, location, len_ptr, u64_ty)?; + + let k0 = entry.const_int(context, location, 0, 64)?; + let k1 = entry.const_int(context, location, 1, 64)?; + let elem_stride_bytes = + entry.const_int(context, location, elem_layout.pad_to_align().size(), 64)?; + let nullptr = entry.append_op_result(llvm::zero(pointer(context, 0), location))?; + + let cloned_dict_ptr = runtime_bindings + .dict_alloc_new(context, helper, entry, location)? + .result(0)? + .into(); + + entry.append_operation(scf::r#for( + k0, + array_len, + k1, + { + let region = Region::new(); + let block = region.append_block(Block::new(&[( + IntegerType::new(context, 64).into(), + location, + )])); + + let i = block.argument(0)?.into(); + block.append_operation(scf::execute_region( + &[], + { + let region = Region::new(); + let block = region.append_block(Block::new(&[])); + + let entry_ptr = block.append_op_result(llvm::get_element_ptr_dynamic( + context, + entries_ptr, + &[i], + entry_values_type, + llvm::r#type::pointer(context, 0), + location, + ))?; + + let helper = LibfuncHelper { + module: helper.module, + init_block: helper.init_block, + region: ®ion, + blocks_arena: helper.blocks_arena, + last_block: Cell::new(&block), + branches: Vec::new(), + results: Vec::new(), + }; + + let entry_value = + block.load(context, location, entry_ptr, entry_values_type)?; + + let key = block.extract_value( + context, + location, + entry_value, + IntegerType::new(context, 252).into(), + 0, + )?; + let key_ptr = helper.init_block().alloca_int(context, location, 252)?; + block.store(context, location, key_ptr, key)?; + let value_ptr = block.extract_value( + context, + location, + entry_value, + pointer(context, 0), + 1, + )?; + + match elem_snapshot_take { + Some(elem_snapshot_take) => { + let value = block.load(context, location, value_ptr, elem_ty)?; + let (block, cloned_value) = elem_snapshot_take( + context, registry, &block, location, &helper, metadata, value, + )?; + + let cloned_value_ptr = + block.append_op_result(ReallocBindingsMeta::realloc( + context, + nullptr, + elem_stride_bytes, + location, + ))?; + + block.store(context, location, cloned_value_ptr, cloned_value)?; + + // needed due to mut borrow + let runtime_bindings = metadata + .get_mut::() + .expect("Runtime library not available."); + runtime_bindings.dict_insert( + context, + &helper, + block, + cloned_dict_ptr, + key_ptr, + cloned_value_ptr, + location, + )?; + block.append_operation(scf::r#yield(&[], location)); + } + None => { + let cloned_value_ptr = + block.append_op_result(ReallocBindingsMeta::realloc( + context, + nullptr, + elem_stride_bytes, + location, + ))?; + block.append_operation( + ods::llvm::intr_memcpy( + context, + cloned_value_ptr, + value_ptr, + elem_stride_bytes, + IntegerAttribute::new(IntegerType::new(context, 1).into(), 0), + location, + ) + .into(), + ); + runtime_bindings.dict_insert( + context, + &helper, + &block, + cloned_dict_ptr, + key_ptr, + cloned_value_ptr, + location, + )?; + block.append_operation(scf::r#yield(&[], location)); + } + } + + region + }, + location, + )); + + block.append_operation(scf::r#yield(&[], location)); + region + }, + location, + )); + + Ok((entry, cloned_dict_ptr)) +} + #[cfg(test)] mod test { use crate::{ - utils::test::{load_cairo, run_program}, + utils::test::{jit_dict, load_cairo, run_program}, values::JitValue, }; use pretty_assertions_sorted::assert_eq; use starknet_types_core::felt::Felt; use std::collections::HashMap; + #[test] + fn dict_snapshot_take() { + let program = load_cairo! { + fn run_test() -> @Felt252Dict { + let mut dict: Felt252Dict = Default::default(); + dict.insert(2, 1_u32); + + @dict + } + }; + let result = run_program(&program, "run_test", &[]).return_value; + + assert_eq!( + result, + jit_dict!( + 2 => 1u32 + ), + ); + } + + #[test] + fn dict_snapshot_take_complex() { + let program = load_cairo! { + fn run_test() -> @Felt252Dict>> { + let mut dict: Felt252Dict>> = Default::default(); + dict.insert(2, NullableTrait::new(array![3, 4])); + + @dict + } + + }; + let result = run_program(&program, "run_test", &[]).return_value; + + assert_eq!( + result, + jit_dict!( + 2 => JitValue::Array(vec![3u32.into(), 4u32.into()]) + ), + ); + } + + #[test] + fn dict_snapshot_take_compare() { + let program = load_cairo! { + fn run_test() -> @Felt252Dict>> { + let mut dict: Felt252Dict>> = Default::default(); + dict.insert(2, NullableTrait::new(array![3, 4])); + + @dict + } + + }; + let program2 = load_cairo! { + fn run_test() -> Felt252Dict>> { + let mut dict: Felt252Dict>> = Default::default(); + dict.insert(2, NullableTrait::new(array![3, 4])); + + dict + } + + }; + let result1 = run_program(&program, "run_test", &[]).return_value; + let result2 = run_program(&program2, "run_test", &[]).return_value; + + assert_eq!(result1, result2); + } + /// Ensure that a dictionary of booleans compiles. #[test] fn dict_type_bool() { diff --git a/src/utils.rs b/src/utils.rs index 79f78cc04..0622e650f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -253,6 +253,8 @@ pub fn run_pass_manager(context: &Context, module: &mut Module) -> Result<(), Er #[cfg(feature = "with-runtime")] pub fn register_runtime_symbols(engine: &ExecutionEngine) { + use cairo_native_runtime::FeltDict; + unsafe { engine.register_symbol( "cairo_native__libfunc__debug__print", @@ -307,20 +309,29 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) { engine.register_symbol( "cairo_native__alloc_dict", - cairo_native_runtime::cairo_native__alloc_dict as *const fn() -> *mut std::ffi::c_void + cairo_native_runtime::cairo_native__alloc_dict as *const fn() -> *mut FeltDict as *mut (), ); engine.register_symbol( "cairo_native__dict_free", - cairo_native_runtime::cairo_native__dict_free as *const fn(*mut std::ffi::c_void) -> () + cairo_native_runtime::cairo_native__dict_free as *const fn(*mut FeltDict) -> () as *mut (), ); + engine.register_symbol( + "cairo_native__dict_values", + cairo_native_runtime::cairo_native__dict_values + as *const fn( + *mut FeltDict, + *mut u64, + ) -> *mut ([u8; 32], std::ptr::NonNull) as *mut (), + ); + engine.register_symbol( "cairo_native__dict_get", cairo_native_runtime::cairo_native__dict_get - as *const fn(*mut std::ffi::c_void, &[u8; 32]) -> *mut std::ffi::c_void + as *const fn(*mut FeltDict, &[u8; 32]) -> *mut std::ffi::c_void as *mut (), ); @@ -328,16 +339,16 @@ pub fn register_runtime_symbols(engine: &ExecutionEngine) { "cairo_native__dict_insert", cairo_native_runtime::cairo_native__dict_insert as *const fn( - *mut std::ffi::c_void, + *mut FeltDict, &[u8; 32], NonNull, + usize, ) -> *mut std::ffi::c_void as *mut (), ); engine.register_symbol( "cairo_native__dict_gas_refund", - cairo_native_runtime::cairo_native__dict_gas_refund - as *const fn(*const std::ffi::c_void, NonNull) -> u64 + cairo_native_runtime::cairo_native__dict_gas_refund as *const fn(*const FeltDict) -> u64 as *mut (), ); diff --git a/src/values.rs b/src/values.rs index e254c8c25..edadb2015 100644 --- a/src/values.rs +++ b/src/values.rs @@ -18,6 +18,7 @@ use cairo_lang_sierra::{ ids::ConcreteTypeId, program_registry::ProgramRegistry, }; +use cairo_native_runtime::FeltDict; use educe::Educe; use num_bigint::{BigInt, Sign, ToBigInt}; use num_traits::Euclid; @@ -386,7 +387,7 @@ impl JitValue { let elem_ty = registry.get_type(&info.ty).unwrap(); let elem_layout = elem_ty.layout(registry).unwrap().pad_to_align(); - let mut value_map = HashMap::<[u8; 32], NonNull>::new(); + let mut value_map = Box::::default(); // next key must be called before next_value @@ -402,7 +403,7 @@ impl JitValue { elem_layout.size(), ); - value_map.insert( + value_map.0.insert( key, NonNull::new(value_malloc_ptr) .expect("allocation failure") @@ -410,7 +411,7 @@ impl JitValue { ); } - NonNull::new_unchecked(Box::into_raw(Box::new(value_map))).cast() + NonNull::new_unchecked(Box::into_raw(value_map)).cast() } else { Err(Error::UnexpectedValue(format!( "expected value of type {:?} but got a felt dict", @@ -705,7 +706,7 @@ impl JitValue { let (map, _) = *Box::from_raw( ptr.cast::>() .as_ref() - .cast::<(HashMap<[u8; 32], NonNull>, u64)>() + .cast::() .as_ptr(), ); @@ -714,6 +715,7 @@ impl JitValue { for (key, val_ptr) in map.iter() { let key = Felt::from_bytes_le(key); output_map.insert(key, Self::from_jit(val_ptr.cast(), &info.ty, registry)); + libc::free(val_ptr.as_ptr()); } JitValue::Felt252Dict {