From 905d6959662edfec4829568f995206fcb8bf64d0 Mon Sep 17 00:00:00 2001 From: Edgar Date: Fri, 5 Jul 2024 15:10:12 +0200 Subject: [PATCH] Fix Nullable snapshot segfault (#718) * make nullable deep copy too * format * try * fix * clippy * add test --- src/types/nullable.rs | 129 +++++++++++++++++++++++++++++++----------- 1 file changed, 95 insertions(+), 34 deletions(-) diff --git a/src/types/nullable.rs b/src/types/nullable.rs index 658ebae89..62955766b 100644 --- a/src/types/nullable.rs +++ b/src/types/nullable.rs @@ -21,15 +21,13 @@ use cairo_lang_sierra::{ }, program_registry::ProgramRegistry, }; +use melior::dialect::cf; use melior::{ dialect::{ llvm::{self, r#type::pointer}, - ods, scf, - }, - ir::{ - attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Region, Type, - Value, + ods, }, + ir::{attribute::IntegerAttribute, r#type::IntegerType, Block, Location, Module, Type, Value}, Context, }; @@ -73,7 +71,13 @@ fn snapshot_take<'ctx, 'this>( metadata.insert(ReallocBindingsMeta::new(context, helper)); } - let elem_layout = registry.get_type(&info.ty)?.layout(registry)?; + let inner_snapshot_take = metadata + .get::() + .and_then(|meta| meta.wrap_invoke(&info.ty)); + + let inner_type = registry.get_type(&info.ty)?; + let inner_layout = inner_type.layout(registry)?; + let inner_ty = inner_type.build(context, helper, registry, metadata, info.self_ty())?; let null_ptr = entry .append_op_result(ods::llvm::mlir_zero(context, pointer(context, 0), location).into())?; @@ -90,46 +94,103 @@ fn snapshot_take<'ctx, 'this>( .into(), )?; - let value = entry - .append_operation(scf::r#if( - is_null, - &[llvm::r#type::pointer(context, 0)], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); + let mut block_not_null = helper.append_block(Block::new(&[])); + let block_finish = helper.append_block(Block::new(&[(pointer(context, 0), location)])); - block.append_operation(scf::r#yield(&[null_ptr], location)); - region - }, - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); + entry.append_operation(cf::cond_br( + context, + is_null, + block_finish, + block_not_null, + &[null_ptr], + &[], + location, + )); + + { + let value_len = + block_not_null.const_int(context, location, inner_layout.pad_to_align().size(), 64)?; + + let dst_ptr = block_not_null.append_op_result(ReallocBindingsMeta::realloc( + context, null_ptr, value_len, location, + ))?; - let alloc_len = block.const_int(context, location, elem_layout.size(), 64)?; + match inner_snapshot_take { + Some(inner_snapshot_take) => { + let value = block_not_null.load(context, location, src_value, inner_ty)?; - let cloned_ptr = block.append_op_result(ReallocBindingsMeta::realloc( - context, null_ptr, alloc_len, location, - ))?; + let (next_block, value) = inner_snapshot_take( + context, + registry, + block_not_null, + location, + helper, + metadata, + value, + )?; + block_not_null = next_block; - block.append_operation( + block_not_null.store(context, location, dst_ptr, value)?; + } + None => { + block_not_null.append_operation( ods::llvm::intr_memcpy( context, - cloned_ptr, + dst_ptr, src_value, - alloc_len, + value_len, IntegerAttribute::new(IntegerType::new(context, 1).into(), 0), location, ) .into(), ); + } + } + block_not_null.append_operation(cf::br(block_finish, &[dst_ptr], location)); + } - block.append_operation(scf::r#yield(&[cloned_ptr], location)); - region - }, - location, - )) - .result(0)? - .into(); + let value = block_finish.argument(0)?.into(); + + Ok((block_finish, value)) +} + +#[cfg(test)] +mod test { + use crate::{ + utils::test::{jit_enum, jit_struct, load_cairo, run_program}, + values::JitValue, + }; + use pretty_assertions_sorted::assert_eq; + + #[test] + fn test_nullable_deep_clone() { + let program = load_cairo! { + use core::array::ArrayTrait; + use core::NullableTrait; + + fn run_test() -> @Nullable> { + let mut x = NullableTrait::new(array![1, 2, 3]); + let x_s = @x; + + let mut y = NullableTrait::deref(x); + y.append(4); - Ok((entry, value)) + x_s + } + + }; + let result = run_program(&program, "run_test", &[]).return_value; + + assert_eq!( + result, + jit_enum!( + 0, + jit_struct!(JitValue::Array(vec![ + JitValue::Felt252(1.into()), + JitValue::Felt252(2.into()), + JitValue::Felt252(3.into()), + ])) + ), + ); + } }