Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array Codegen #119

Merged
merged 11 commits into from
May 7, 2024
103 changes: 96 additions & 7 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::collections::HashMap;

use concrete_ir::{
BinOp, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, PlaceElem, ProgramBody, Rvalue,
Span, Ty, TyKind, ValueTree,
BinOp, ConstValue, DefId, FnBody, LocalKind, ModuleBody, Operand, Place, PlaceElem,
ProgramBody, Rvalue, Span, Ty, TyKind, ValueTree,
};
use concrete_session::Session;
use melior::{
Expand Down Expand Up @@ -1012,8 +1012,49 @@ fn compile_store_place<'c: 'b, 'b>(
_ => unreachable!(),
}
}
PlaceElem::Index(_) => todo!(),
PlaceElem::ConstantIndex(_) => todo!(),
PlaceElem::Index(local) => {
local_ty = match local_ty.kind {
TyKind::Array(inner, _) => *inner,
_ => unreachable!(),
};

let place = Place {
local: *local,
projection: vec![],
};

let (index, _) = compile_load_place(ctx, block, &place, locals)?;

ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
ctx.context(),
ptr,
&[index],
compile_type(ctx.module_ctx, &local_ty),
opaque_pointer(ctx.context()),
Location::unknown(ctx.context()),
))
.result(0)?
.into();
}
PlaceElem::ConstantIndex(index) => {
local_ty = match local_ty.kind {
TyKind::Array(inner, _) => *inner,
_ => unreachable!(),
};

ptr = block
.append_operation(llvm::get_element_ptr(
ctx.context(),
ptr,
DenseI32ArrayAttribute::new(ctx.context(), &[(*index).try_into().unwrap()]),
compile_type(ctx.module_ctx, &local_ty),
opaque_pointer(ctx.context()),
Location::unknown(ctx.context()),
))
.result(0)?
.into();
}
}
}

Expand Down Expand Up @@ -1084,8 +1125,48 @@ fn compile_load_place<'c: 'b, 'b>(
_ => unreachable!(),
}
}
PlaceElem::Index(_) => todo!(),
PlaceElem::ConstantIndex(_) => todo!(),
PlaceElem::Index(local) => {
local_ty = match local_ty.kind {
TyKind::Array(inner, _) => *inner,
_ => unreachable!(),
};

let place = Place {
local: *local,
projection: Default::default(),
};

let (index, _) = compile_load_place(ctx, block, &place, locals)?;

ptr = block
.append_operation(llvm::get_element_ptr_dynamic(
ctx.context(),
ptr,
&[index],
compile_type(ctx.module_ctx, &local_ty),
opaque_pointer(ctx.context()),
Location::unknown(ctx.context()),
))
.result(0)?
.into();
}
PlaceElem::ConstantIndex(index) => {
local_ty = match local_ty.kind {
TyKind::Array(inner, _) => *inner,
_ => unreachable!(),
};
ptr = block
.append_operation(llvm::get_element_ptr(
ctx.context(),
ptr,
DenseI32ArrayAttribute::new(ctx.context(), &[(*index).try_into().unwrap()]),
compile_type(ctx.module_ctx, &local_ty),
opaque_pointer(ctx.context()),
Location::unknown(ctx.context()),
))
.result(0)?
.into();
}
}
}

Expand Down Expand Up @@ -1272,7 +1353,15 @@ fn compile_type<'c>(ctx: ModuleCodegenCtx<'c>, ty: &Ty) -> Type<'c> {
concrete_ir::FloatTy::F64 => Type::float64(ctx.ctx.mlir_context),
},
concrete_ir::TyKind::String => todo!(),
concrete_ir::TyKind::Array(_, _) => todo!(),
concrete_ir::TyKind::Array(inner_type, length) => {
let inner_type = compile_type(ctx, inner_type);
let length = match length.data {
concrete_ir::ConstKind::Value(ValueTree::Leaf(ConstValue::U64(length))) => length,
_ => unimplemented!(),
};

melior::dialect::llvm::r#type::array(inner_type, length as u32)
}
concrete_ir::TyKind::Ref(_inner_ty, _) | concrete_ir::TyKind::Ptr(_inner_ty, _) => {
llvm::r#type::opaque_pointer(ctx.ctx.mlir_context)
}
Expand Down
1 change: 1 addition & 0 deletions crates/concrete_driver/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod common;
#[test_case(include_str!("../../../examples/if_if_false.con"), "if_if_false", false, 7 ; "if_if_false.con")]
#[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")]
#[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")]
#[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")]
fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) {
assert_eq!(
status_code,
Expand Down
7 changes: 5 additions & 2 deletions examples/arrays.con
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
mod Example {
fn main() -> i32 {
let array: [i32; 4] = [1, 2, 3, 4];
let nested_array: [[i32; 2]; 2] = [[1, 2], [3, 4]];
let mut array: [i32; 4] = [1, 9, 3, 4];
let nested_array: [[i32; 2]; 2] = [[1, 2], [9, 9]];

array[1] = 2;
nested_array[1] = [3, 4];

let a: i32 = array[1];
let b: i32 = nested_array[1][0];
Expand Down
Loading