Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make dicts clone on write. #964

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
272 changes: 156 additions & 116 deletions runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::{
ffi::{c_int, c_void},
fs::File,
io::Write,
mem::ManuallyDrop,
mem::{forget, ManuallyDrop},
os::fd::FromRawFd,
ptr::{self, null, null_mut},
rc::Rc,
};
use std::{ops::Mul, vec::IntoIter};

Expand Down Expand Up @@ -151,31 +152,124 @@ pub struct FeltDict {
pub layout: Layout,
pub elements: *mut (),

pub dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
pub drop_fn: Option<extern "C" fn(*mut c_void)>,

pub count: u64,
}

impl Clone for FeltDict {
fn clone(&self) -> Self {
let mut new_dict = FeltDict {
mappings: HashMap::with_capacity(self.mappings.len()),

layout: self.layout,
elements: if self.mappings.is_empty() {
null_mut()
} else {
unsafe {
alloc(Layout::from_size_align_unchecked(
self.layout.pad_to_align().size() * self.mappings.len(),
self.layout.align(),
))
.cast()
}
},

dup_fn: self.dup_fn,
drop_fn: self.drop_fn,

// TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too.
edg-l marked this conversation as resolved.
Show resolved Hide resolved
count: 0,
};

for (&key, &old_index) in self.mappings.iter() {
let old_value_ptr = unsafe {
self.elements
.byte_add(self.layout.pad_to_align().size() * old_index)
};

let new_index = new_dict.mappings.len();
let new_value_ptr = unsafe {
new_dict
.elements
.byte_add(new_dict.layout.pad_to_align().size() * new_index)
};

new_dict.mappings.insert(key, new_index);
match self.dup_fn {
Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()),
None => unsafe {
ptr::copy_nonoverlapping::<u8>(
old_value_ptr.cast(),
new_value_ptr.cast(),
self.layout.size(),
)
},
}
}

new_dict
}
}

impl Drop for FeltDict {
fn drop(&mut self) {
// Free the entries manually.
if let Some(drop_fn) = self.drop_fn {
for (_, &index) in self.mappings.iter() {
let value_ptr = unsafe {
self.elements
.byte_add(self.layout.pad_to_align().size() * index)
};

drop_fn(value_ptr.cast());
}
}

// Free the value data.
if !self.elements.is_null() {
unsafe {
dealloc(
self.elements.cast(),
Layout::from_size_align_unchecked(
self.layout.pad_to_align().size() * self.mappings.capacity(),
self.layout.align(),
),
)
};
}
}
}

/// Allocate a new dictionary.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut FeltDict {
Box::into_raw(Box::new(FeltDict {
pub unsafe extern "C" fn cairo_native__dict_new(
size: u64,
align: u64,
dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
drop_fn: Option<extern "C" fn(*mut c_void)>,
) -> *const FeltDict {
Rc::into_raw(Rc::new(FeltDict {
mappings: HashMap::default(),

layout: Layout::from_size_align_unchecked(size as usize, align as usize),
elements: null_mut(),

dup_fn,
drop_fn,

count: 0,
}))
}

/// Free a dictionary using an optional callback to drop each element.
///
/// The `drop_fn` callback is present when the value implements `Drop`.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
Expand All @@ -184,88 +278,23 @@ pub unsafe extern "C" fn cairo_native__dict_new(size: u64, align: u64) -> *mut F
// pointer optimization. Check out
// https://doc.rust-lang.org/nomicon/ffi.html#the-nullable-pointer-optimization for more info.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_drop(
ptr: *mut FeltDict,
drop_fn: Option<extern "C" fn(*mut c_void)>,
) {
let dict = Box::from_raw(ptr);

// Free the entries manually.
if let Some(drop_fn) = drop_fn {
for (_, &index) in dict.mappings.iter() {
let value_ptr = dict
.elements
.byte_add(dict.layout.pad_to_align().size() * index);

drop_fn(value_ptr.cast());
}
}

// Free the value data.
if !dict.elements.is_null() {
dealloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
dict.layout.align(),
),
);
}
pub unsafe extern "C" fn cairo_native__dict_drop(ptr: *const FeltDict) {
drop(Rc::from_raw(ptr));
}

/// Duplicate a dictionary using a provided callback to clone each element.
///
/// The `dup_fn` callback is present when the value is not `Copy`, but `Clone`. The first argument
/// is the original value while the second is the target pointer.
///
/// # Safety
///
/// This function is intended to be called from MLIR, deals with pointers, and is therefore
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_dup(
old_dict: &FeltDict,
dup_fn: Option<extern "C" fn(*mut c_void, *mut c_void)>,
) -> *mut FeltDict {
let mut new_dict = Box::new(FeltDict {
mappings: HashMap::with_capacity(old_dict.mappings.len()),

layout: old_dict.layout,
elements: if old_dict.mappings.is_empty() {
null_mut()
} else {
alloc(Layout::from_size_align_unchecked(
old_dict.layout.pad_to_align().size() * old_dict.mappings.len(),
old_dict.layout.align(),
))
.cast()
},

// TODO: Check if `0` is fine or otherwise we should copy the value from `old_dict` too.
count: 0,
});

for (new_index, (&key, &old_index)) in old_dict.mappings.iter().enumerate() {
let old_value_ptr = old_dict
.elements
.byte_add(old_dict.layout.pad_to_align().size() * old_index);

let new_value_ptr = new_dict
.elements
.byte_add(new_dict.layout.pad_to_align().size() * new_index);

new_dict.mappings.insert(key, new_index);
match dup_fn {
Some(dup_fn) => dup_fn(old_value_ptr.cast(), new_value_ptr.cast()),
None => ptr::copy_nonoverlapping::<u8>(
old_value_ptr.cast(),
new_value_ptr.cast(),
old_dict.layout.size(),
),
}
}
pub unsafe extern "C" fn cairo_native__dict_dup(dict_ptr: *const FeltDict) -> *const FeltDict {
let old_dict = Rc::from_raw(dict_ptr);
let new_dict = Rc::clone(&old_dict);

Box::into_raw(new_dict)
forget(old_dict);
Rc::into_raw(new_dict)
}

/// Return a pointer to the entry's value pointer for a given key, inserting a null pointer if not
Expand All @@ -280,45 +309,46 @@ pub unsafe extern "C" fn cairo_native__dict_dup(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_get(
dict: &mut FeltDict,
dict: *const FeltDict,
key: &[u8; 32],
value_ptr: *mut *mut c_void,
) -> c_int {
let mut key = *key;
key[31] &= 0x0F; // Filter out first 4 bits (they're outside an i252).

let old_capacity = dict.mappings.capacity();
let index = dict.mappings.len();
let (index, is_present) = match dict.mappings.entry(key) {
Entry::Occupied(entry) => (*entry.get(), 1),
Entry::Vacant(entry) => {
entry.insert(index);
let mut dict_rc = Rc::from_raw(dict);
let dict = Rc::make_mut(&mut dict_rc);

// Reallocate `mem_data` to match the slab's capacity.
if old_capacity != dict.mappings.capacity() {
dict.elements = realloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * old_capacity,
dict.layout.align(),
),
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
)
.cast();
}
let num_mappings = dict.mappings.len();
let has_capacity = num_mappings != dict.mappings.capacity();

(index, 0)
let (is_present, index) = match dict.mappings.entry(*key) {
Entry::Occupied(entry) => (true, *entry.get()),
Entry::Vacant(entry) => {
entry.insert(num_mappings);
(false, num_mappings)
}
};

value_ptr.write(
dict.elements
.byte_add(dict.layout.pad_to_align().size() * index)
.cast(),
);
// Maybe realloc (conditions: !has_capacity && !is_present).
if !has_capacity && !is_present {
dict.elements = realloc(
dict.elements.cast(),
Layout::from_size_align_unchecked(
dict.layout.pad_to_align().size() * dict.mappings.len(),
dict.layout.align(),
),
dict.layout.pad_to_align().size() * dict.mappings.capacity(),
)
.cast();
}

*value_ptr = dict
.elements
.byte_add(dict.layout.pad_to_align().size() * index)
.cast();

dict.count += 1;
forget(dict_rc);

is_present
is_present as c_int
}

/// Compute the total gas refund for the dictionary at squash time.
Expand All @@ -329,8 +359,12 @@ pub unsafe extern "C" fn cairo_native__dict_get(
/// definitely unsafe to use manually.
#[no_mangle]
pub unsafe extern "C" fn cairo_native__dict_gas_refund(ptr: *const FeltDict) -> u64 {
let dict = &*ptr;
(dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS
let dict = Rc::from_raw(ptr);
let amount =
(dict.count.saturating_sub(dict.mappings.len() as u64)) * *DICT_GAS_REFUND_PER_ACCESS;

forget(dict);
amount
}

/// Compute `ec_point_from_x_nz(x)` and store it.
Expand Down Expand Up @@ -850,21 +884,27 @@ mod tests {

#[test]
fn test_dict() {
let dict =
unsafe { cairo_native__dict_new(size_of::<u64>() as u64, align_of::<u64>() as u64) };
let dict = unsafe {
cairo_native__dict_new(
size_of::<u64>() as u64,
align_of::<u64>() as u64,
None,
None,
)
};

let key = Felt::ONE.to_bytes_le();
let mut ptr = null_mut::<u64>();

assert_eq!(
unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) },
0,
);
assert!(!ptr.is_null());
unsafe { *ptr = 24 };

assert_eq!(
unsafe { cairo_native__dict_get(&mut *dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(dict, &key, (&raw mut ptr).cast()) },
1,
);
assert!(!ptr.is_null());
Expand All @@ -874,17 +914,17 @@ mod tests {
let refund = unsafe { cairo_native__dict_gas_refund(dict) };
assert_eq!(refund, 4050);

let cloned_dict = unsafe { cairo_native__dict_dup(&*dict, None) };
unsafe { cairo_native__dict_drop(dict, None) };
let cloned_dict = unsafe { cairo_native__dict_dup(&*dict) };
unsafe { cairo_native__dict_drop(dict) };

assert_eq!(
unsafe { cairo_native__dict_get(&mut *cloned_dict, &key, (&raw mut ptr).cast()) },
unsafe { cairo_native__dict_get(cloned_dict, &key, (&raw mut ptr).cast()) },
1,
);
assert!(!ptr.is_null());
assert_eq!(unsafe { *ptr }, 42);

unsafe { cairo_native__dict_drop(cloned_dict, None) };
unsafe { cairo_native__dict_drop(cloned_dict) };
}

#[test]
Expand Down
Loading
Loading