diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 8e5f8c0..cb1e1d4 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use concrete_ir::{ - BinOp, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, PlaceElem, ProgramBody, Rvalue, - Span, Ty, TyKind, ValueTree, + BinOp, ConstValue, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, PlaceElem, + ProgramBody, Rvalue, Span, Ty, TyKind, ValueTree, }; use concrete_session::Session; use melior::{ @@ -1012,8 +1012,49 @@ fn compile_store_place<'c: 'b, 'b>( _ => unreachable!(), } } - PlaceElem::Index(_) => todo!(), - PlaceElem::ConstantIndex(_) => todo!(), + PlaceElem::Index(local) => { + local_ty = match local_ty.kind { + TyKind::Array(inner, _) => *inner, + _ => unreachable!(), + }; + + let place = Place { + local: *local, + projection: vec![], + }; + + let (index, _) = compile_load_place(ctx, block, &place, locals)?; + + ptr = block + .append_operation(llvm::get_element_ptr_dynamic( + ctx.context(), + ptr, + &[index], + compile_type(ctx.module_ctx, &local_ty), + opaque_pointer(ctx.context()), + Location::unknown(ctx.context()), + )) + .result(0)? + .into(); + } + PlaceElem::ConstantIndex(index) => { + local_ty = match local_ty.kind { + TyKind::Array(inner, _) => *inner, + _ => unreachable!(), + }; + + ptr = block + .append_operation(llvm::get_element_ptr( + ctx.context(), + ptr, + DenseI32ArrayAttribute::new(ctx.context(), &[(*index).try_into().unwrap()]), + compile_type(ctx.module_ctx, &local_ty), + opaque_pointer(ctx.context()), + Location::unknown(ctx.context()), + )) + .result(0)? + .into(); + } } } @@ -1084,8 +1125,48 @@ fn compile_load_place<'c: 'b, 'b>( _ => unreachable!(), } } - PlaceElem::Index(_) => todo!(), - PlaceElem::ConstantIndex(_) => todo!(), + PlaceElem::Index(local) => { + local_ty = match local_ty.kind { + TyKind::Array(inner, _) => *inner, + _ => unreachable!(), + }; + + let place = Place { + local: *local, + projection: Default::default(), + }; + + let (index, _) = compile_load_place(ctx, block, &place, locals)?; + + ptr = block + .append_operation(llvm::get_element_ptr_dynamic( + ctx.context(), + ptr, + &[index], + compile_type(ctx.module_ctx, &local_ty), + opaque_pointer(ctx.context()), + Location::unknown(ctx.context()), + )) + .result(0)? + .into(); + } + PlaceElem::ConstantIndex(index) => { + local_ty = match local_ty.kind { + TyKind::Array(inner, _) => *inner, + _ => unreachable!(), + }; + ptr = block + .append_operation(llvm::get_element_ptr( + ctx.context(), + ptr, + DenseI32ArrayAttribute::new(ctx.context(), &[(*index).try_into().unwrap()]), + compile_type(ctx.module_ctx, &local_ty), + opaque_pointer(ctx.context()), + Location::unknown(ctx.context()), + )) + .result(0)? + .into(); + } } } @@ -1272,7 +1353,15 @@ fn compile_type<'c>(ctx: ModuleCodegenCtx<'c>, ty: &Ty) -> Type<'c> { concrete_ir::FloatTy::F64 => Type::float64(ctx.ctx.mlir_context), }, concrete_ir::TyKind::String => todo!(), - concrete_ir::TyKind::Array(_, _) => todo!(), + concrete_ir::TyKind::Array(inner_type, length) => { + let inner_type = compile_type(ctx, inner_type); + let length = match length.data { + concrete_ir::ConstKind::Value(ValueTree::Leaf(ConstValue::U64(length))) => length, + _ => unimplemented!(), + }; + + melior::dialect::llvm::r#type::array(inner_type, length as u32) + } concrete_ir::TyKind::Ref(_inner_ty, _) | concrete_ir::TyKind::Ptr(_inner_ty, _) => { llvm::r#type::opaque_pointer(ctx.ctx.mlir_context) } diff --git a/crates/concrete_driver/tests/examples.rs b/crates/concrete_driver/tests/examples.rs index 8ca3e6c..455f58e 100644 --- a/crates/concrete_driver/tests/examples.rs +++ b/crates/concrete_driver/tests/examples.rs @@ -20,6 +20,7 @@ mod common; #[test_case(include_str!("../../../examples/if_if_false.con"), "if_if_false", false, 7 ; "if_if_false.con")] #[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")] #[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")] +#[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")] fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) { assert_eq!( status_code, diff --git a/examples/arrays.con b/examples/arrays.con index 9721fd7..52ade15 100644 --- a/examples/arrays.con +++ b/examples/arrays.con @@ -1,7 +1,10 @@ mod Example { fn main() -> i32 { - let array: [i32; 4] = [1, 2, 3, 4]; - let nested_array: [[i32; 2]; 2] = [[1, 2], [3, 4]]; + let mut array: [i32; 4] = [1, 9, 3, 4]; + let nested_array: [[i32; 2]; 2] = [[1, 2], [9, 9]]; + + array[1] = 2; + nested_array[1] = [3, 4]; let a: i32 = array[1]; let b: i32 = nested_array[1][0];