Skip to content

Commit

Permalink
Fix snapshot_take for arrays. (#716)
Browse files Browse the repository at this point in the history
* Fix `snapshot_take` for arrays.

* Add test.

* Fix clippy issues.
  • Loading branch information
azteca1998 authored Jul 4, 2024
1 parent 218131f commit 00bb1f3
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 15 deletions.
13 changes: 9 additions & 4 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,6 @@ fn parse_result(
((ret_registers[1] as u128) << 64) | ret_registers[0] as u128,
)),
},
CoreTypeConcrete::Uint128MulGuarantee(_) => todo!(),
CoreTypeConcrete::Sint8(_) => match return_ptr {
Some(return_ptr) => Ok(JitValue::Sint8(unsafe { *return_ptr.cast().as_ref() })),
None => Ok(JitValue::Sint8(ret_registers[0] as i8)),
Expand Down Expand Up @@ -845,6 +844,10 @@ fn parse_result(
Ok(value)
},

CoreTypeConcrete::Snapshot(info) => {
parse_result(&info.ty, registry, return_ptr, ret_registers)
}

// Builtins are handled before the call to parse_result
// and should not be reached here.
CoreTypeConcrete::Bitwise(_)
Expand All @@ -855,14 +858,16 @@ fn parse_result(
| CoreTypeConcrete::RangeCheck(_)
| CoreTypeConcrete::Pedersen(_)
| CoreTypeConcrete::Poseidon(_)
| CoreTypeConcrete::SegmentArena(_) => unreachable!(),
| CoreTypeConcrete::SegmentArena(_)
| CoreTypeConcrete::StarkNet(StarkNetTypeConcrete::System(_)) => unreachable!(),

CoreTypeConcrete::Felt252DictEntry(_)
| CoreTypeConcrete::Span(_)
| CoreTypeConcrete::Snapshot(_)
| CoreTypeConcrete::BoundedInt(_)
| CoreTypeConcrete::Uninitialized(_)
| CoreTypeConcrete::Coupon(_)
| CoreTypeConcrete::StarkNet(_) => todo!(),
| CoreTypeConcrete::StarkNet(_)
| CoreTypeConcrete::Uint128MulGuarantee(_) => todo!(),
}
}

Expand Down
117 changes: 106 additions & 11 deletions src/types/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ use cairo_lang_sierra::{
},
program_registry::ProgramRegistry,
};
use melior::dialect::scf;
use melior::ir::Region;
use melior::{
dialect::{
arith, cf,
Expand All @@ -46,6 +48,7 @@ use melior::{
},
Context,
};
use std::cell::Cell;

/// Build the MLIR type.
///
Expand Down Expand Up @@ -199,20 +202,73 @@ fn snapshot_take<'ctx, 'this>(

match elem_snapshot_take {
Some(elem_snapshot_take) => {
let value = block_realloc.load(context, location, src_ptr, elem_ty)?;
let k0 = block_realloc.const_int(context, location, 0, 64)?;
block_realloc.append_operation(scf::r#for(
k0,
dst_len_bytes,
elem_stride,
{
let region = Region::new();
let block = region.append_block(Block::new(&[(
IntegerType::new(context, 64).into(),
location,
)]));

let (block_relloc, value) = elem_snapshot_take(
context,
registry,
block_realloc,
let i = block.argument(0)?.into();
block.append_operation(scf::execute_region(
&[],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));

let src_ptr =
block.append_op_result(llvm::get_element_ptr_dynamic(
context,
src_ptr,
&[i],
IntegerType::new(context, 8).into(),
llvm::r#type::pointer(context, 0),
location,
))?;
let dst_ptr =
block.append_op_result(llvm::get_element_ptr_dynamic(
context,
dst_ptr,
&[i],
IntegerType::new(context, 8).into(),
llvm::r#type::pointer(context, 0),
location,
))?;

let helper = LibfuncHelper {
module: helper.module,
init_block: helper.init_block,
region: &region,
blocks_arena: helper.blocks_arena,
last_block: Cell::new(&block),
branches: Vec::new(),
results: Vec::new(),
};

let value = block.load(context, location, src_ptr, elem_ty)?;
let (block, value) = elem_snapshot_take(
context, registry, &block, location, &helper, metadata, value,
)?;
block.store(context, location, dst_ptr, value)?;

block.append_operation(scf::r#yield(&[], location));
region
},
location,
));

block.append_operation(scf::r#yield(&[], location));
region
},
location,
helper,
metadata,
value,
)?;
));

block_relloc.store(context, location, dst_ptr, value)?;
block_relloc.append_operation(cf::br(block_finish, &[dst_ptr], location));
block_realloc.append_operation(cf::br(block_finish, &[dst_ptr], location));
}
None => {
block_realloc.append_operation(
Expand Down Expand Up @@ -242,3 +298,42 @@ fn snapshot_take<'ctx, 'this>(

Ok((block_finish, dst_value))
}

#[cfg(test)]
mod test {
use crate::{
utils::test::{load_cairo, run_program},
values::JitValue,
};
use pretty_assertions_sorted::assert_eq;

#[test]
fn test_array_snapshot_deep_clone() {
let program = load_cairo! {
fn run_test() -> @Array<Array<felt252>> {
let mut inputs: Array<Array<felt252>> = ArrayTrait::new();
inputs.append(array![1, 2, 3]);
inputs.append(array![4, 5, 6]);

@inputs
}
};
let result = run_program(&program, "run_test", &[]).return_value;

assert_eq!(
result,
JitValue::Array(vec![
JitValue::Array(vec![
JitValue::Felt252(1.into()),
JitValue::Felt252(2.into()),
JitValue::Felt252(3.into()),
]),
JitValue::Array(vec![
JitValue::Felt252(4.into()),
JitValue::Felt252(5.into()),
JitValue::Felt252(6.into()),
]),
]),
);
}
}

0 comments on commit 00bb1f3

Please sign in to comment.