Skip to content

Commit

Permalink
Fix Nullable<T> snapshot segfault (#718)
Browse files Browse the repository at this point in the history
* make nullable deep copy too

* format

* try

* fix

* clippy

* add test
  • Loading branch information
edg-l authored Jul 5, 2024
1 parent 00bb1f3 commit 905d695
Showing 1 changed file with 95 additions and 34 deletions.
129 changes: 95 additions & 34 deletions src/types/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,13 @@ use cairo_lang_sierra::{
},
program_registry::ProgramRegistry,
};
use melior::dialect::cf;
use melior::{
dialect::{
llvm::{self, r#type::pointer},
ods, scf,
},
ir::{
attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Region, Type,
Value,
ods,
},
ir::{attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Type, Value},
Context,
};

Expand Down Expand Up @@ -73,7 +71,13 @@ fn snapshot_take<'ctx, 'this>(
metadata.insert(ReallocBindingsMeta::new(context, helper));
}

let elem_layout = registry.get_type(&info.ty)?.layout(registry)?;
let inner_snapshot_take = metadata
.get::<SnapshotClonesMeta>()
.and_then(|meta| meta.wrap_invoke(&info.ty));

let inner_type = registry.get_type(&info.ty)?;
let inner_layout = inner_type.layout(registry)?;
let inner_ty = inner_type.build(context, helper, registry, metadata, info.self_ty())?;

let null_ptr = entry
.append_op_result(ods::llvm::mlir_zero(context, pointer(context, 0), location).into())?;
Expand All @@ -90,46 +94,103 @@ fn snapshot_take<'ctx, 'this>(
.into(),
)?;

let value = entry
.append_operation(scf::r#if(
is_null,
&[llvm::r#type::pointer(context, 0)],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
let mut block_not_null = helper.append_block(Block::new(&[]));
let block_finish = helper.append_block(Block::new(&[(pointer(context, 0), location)]));

block.append_operation(scf::r#yield(&[null_ptr], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));
entry.append_operation(cf::cond_br(
context,
is_null,
block_finish,
block_not_null,
&[null_ptr],
&[],
location,
));

{
let value_len =
block_not_null.const_int(context, location, inner_layout.pad_to_align().size(), 64)?;

let dst_ptr = block_not_null.append_op_result(ReallocBindingsMeta::realloc(
context, null_ptr, value_len, location,
))?;

let alloc_len = block.const_int(context, location, elem_layout.size(), 64)?;
match inner_snapshot_take {
Some(inner_snapshot_take) => {
let value = block_not_null.load(context, location, src_value, inner_ty)?;

let cloned_ptr = block.append_op_result(ReallocBindingsMeta::realloc(
context, null_ptr, alloc_len, location,
))?;
let (next_block, value) = inner_snapshot_take(
context,
registry,
block_not_null,
location,
helper,
metadata,
value,
)?;
block_not_null = next_block;

block.append_operation(
block_not_null.store(context, location, dst_ptr, value)?;
}
None => {
block_not_null.append_operation(
ods::llvm::intr_memcpy(
context,
cloned_ptr,
dst_ptr,
src_value,
alloc_len,
value_len,
IntegerAttribute::new(IntegerType::new(context, 1).into(), 0),
location,
)
.into(),
);
}
}
block_not_null.append_operation(cf::br(block_finish, &[dst_ptr], location));
}

block.append_operation(scf::r#yield(&[cloned_ptr], location));
region
},
location,
))
.result(0)?
.into();
let value = block_finish.argument(0)?.into();

Ok((block_finish, value))
}

#[cfg(test)]
mod test {
use crate::{
utils::test::{jit_enum, jit_struct, load_cairo, run_program},
values::JitValue,
};
use pretty_assertions_sorted::assert_eq;

#[test]
fn test_nullable_deep_clone() {
let program = load_cairo! {
use core::array::ArrayTrait;
use core::NullableTrait;

fn run_test() -> @Nullable<Array<felt252>> {
let mut x = NullableTrait::new(array![1, 2, 3]);
let x_s = @x;

let mut y = NullableTrait::deref(x);
y.append(4);

Ok((entry, value))
x_s
}

};
let result = run_program(&program, "run_test", &[]).return_value;

assert_eq!(
result,
jit_enum!(
0,
jit_struct!(JitValue::Array(vec![
JitValue::Felt252(1.into()),
JitValue::Felt252(2.into()),
JitValue::Felt252(3.into()),
]))
),
);
}
}

0 comments on commit 905d695

Please sign in to comment.