Skip to content

Commit

Permalink
Fix dict snapshot (#724)
Browse files Browse the repository at this point in the history
* wip fix dict clone

* fix dict clone

* remove comment

* fix dict to_jit

* progress

* dict snapshot take with deep cloning works
  • Loading branch information
edg-l authored Aug 2, 2024
1 parent 9cb9b37 commit 3f710d6
Show file tree
Hide file tree
Showing 5 changed files with 417 additions and 32 deletions.
55 changes: 45 additions & 10 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation(
op2.copy_from_slice(&state[2].to_bytes_be());
}

/// Felt252 type used in cairo native runtime
pub type FeltDict = (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64);

/// Allocates a new dictionary. Internally a rust hashmap: `HashMap<[u8; 32], NonNull<()>`
///
/// # Safety
Expand All @@ -157,7 +160,7 @@ pub unsafe extern "C" fn cairo_native__libfunc__hades_permutation(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void {
Box::into_raw(Box::<(HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64)>::default()) as _
Box::into_raw(Box::<FeltDict>::default()) as _
}

/// Frees the dictionary.
Expand All @@ -167,9 +170,7 @@ pub unsafe extern "C" fn cairo_native__alloc_dict() -> *mut std::ffi::c_void {
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_free(
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
) {
pub unsafe extern "C" fn cairo_native__dict_free(ptr: *mut FeltDict) {
let mut map = Box::from_raw(ptr);

// Free the entries manually.
Expand All @@ -178,6 +179,42 @@ pub unsafe extern "C" fn cairo_native__dict_free(
}
}

/// Needed for the correct alignment,
/// since the key [u8; 32] in rust has 8 byte alignment but its a felt,
/// so in reality it has 16.
#[repr(C, align(16))]
pub struct DictValuesArrayAbi {
pub key: [u8; 32],
pub value: std::ptr::NonNull<libc::c_void>,
}

/// Returns a array over the values of the dict, used for deep cloning.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_values(
ptr: *mut FeltDict,
len: *mut u64,
) -> *mut DictValuesArrayAbi {
let dict: &mut FeltDict = &mut *ptr;

let values: Vec<_> = dict
.0
.clone()
.into_iter()
// make it ffi safe for use within MLIR.
.map(|x| DictValuesArrayAbi {
key: x.0,
value: x.1,
})
.collect();
*len = values.len() as u64;
values.leak::<'static>().as_mut_ptr()
}

/// Gets the value for a given key, the returned pointer is null if not found.
/// Increments the access count.
///
Expand All @@ -187,10 +224,10 @@ pub unsafe extern "C" fn cairo_native__dict_free(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_get(
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
ptr: *mut FeltDict,
key: &[u8; 32],
) -> *mut std::ffi::c_void {
let dict: &mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64) = &mut *ptr;
let dict: &mut FeltDict = &mut *ptr;
let map = &dict.0;
dict.1 += 1;

Expand All @@ -209,7 +246,7 @@ pub unsafe extern "C" fn cairo_native__dict_get(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_insert(
ptr: *mut (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
ptr: *mut FeltDict,
key: &[u8; 32],
value: NonNull<std::ffi::c_void>,
) -> *mut std::ffi::c_void {
Expand All @@ -230,9 +267,7 @@ pub unsafe extern "C" fn cairo_native__dict_insert(
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_gas_refund(
ptr: *const (HashMap<[u8; 32], NonNull<std::ffi::c_void>>, u64),
) -> u64 {
pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> u64 {
let dict = &*ptr;
(dict.1 - dict.0.len() as u64) * *DICT_GAS_REFUND_PER_ACCESS
}
Expand Down
62 changes: 56 additions & 6 deletions src/metadata/runtime_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@ use std::{collections::HashSet, marker::PhantomData};

#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
enum RuntimeBinding {
DebugPrint,
Pedersen,
HadesPermutation,
EcPointFromXNz,
EcPointTryNewNz,
EcStateAdd,
EcStateAddMul,
EcStateTryFinalizeNz,
EcStateAddMul,
EcStateAdd,
EcPointTryNewNz,
EcPointFromXNz,
DictNew,
DictInsert,
DictGet,
DictGasRefund,
DictInsert,
DictFree,
DictValues,
DebugPrint,
#[cfg(feature = "with-cheatcode")]
VtableCheatcode,
}
Expand Down Expand Up @@ -502,6 +503,55 @@ impl RuntimeBindingsMeta {
)))
}

/// Register if necessary, then invoke the `dict_clone()` function.
///
/// Returns a opaque pointer as the result.
#[allow(clippy::too_many_arguments)]
pub fn dict_values<'c, 'a>(
&mut self,
context: &'c Context,
module: &Module,
ptr: Value<'c, 'a>,
len_ptr: Value<'c, 'a>,
block: &'a Block<'c>,
location: Location<'c>,
) -> Result<OperationRef<'c, 'a>>
where
'c: 'a,
{
if self.active_map.insert(RuntimeBinding::DictValues) {
module.body().append_operation(func::func(
context,
StringAttribute::new(context, "cairo_native__dict_values"),
TypeAttribute::new(
FunctionType::new(
context,
&[
llvm::r#type::pointer(context, 0),
llvm::r#type::pointer(context, 0),
], // ptr to dict, out ptr to length
&[llvm::r#type::pointer(context, 0)], // ptr to array of struct (key, value_ptr)
)
.into(),
),
Region::new(),
&[(
Identifier::new(context, "sym_visibility"),
StringAttribute::new(context, "private").into(),
)],
Location::unknown(context),
));
}

Ok(block.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, "cairo_native__dict_values"),
&[ptr, len_ptr],
&[llvm::r#type::pointer(context, 0)],
location,
)))
}

/// Register if necessary, then invoke the `dict_get()` function.
///
/// Gets the value for a given key, the returned pointer is null if not found.
Expand Down
Loading

0 comments on commit 3f710d6

Please sign in to comment.