diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 72c0339..21a0d12 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -13,12 +13,12 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - func, + func, memref, }, ir::{ attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, - r#type::{FunctionType, IntegerType}, - Block, Location, Module as MeliorModule, Region, Type, Value, + r#type::{FunctionType, IntegerType, MemRefType}, + Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike, }, Context as MeliorContext, }; @@ -38,6 +38,8 @@ pub fn compile_program( #[derive(Debug, Clone)] pub struct LocalVar<'c, 'op> { pub type_spec: TypeSpec, + // If it's none its on a register, otherwise allocated on the stack. + pub memref_type: Option>, pub value: Value<'c, 'op>, } @@ -152,6 +154,7 @@ fn compile_function_def<'c, 'op>( LocalVar { type_spec: param.r#type.clone(), value: fn_block.argument(i)?.into(), + memref_type: None, }, ); } @@ -216,11 +219,37 @@ fn compile_let_stmt<'c, 'op>( Some(r#type), )?; + let location = get_location(context, session, &name.span); + + let memref_type = MemRefType::new(value.r#type(), &[1], None, None); + + let alloca: Value = block + .append_operation(memref::alloca( + context, + memref_type, + &[], + &[], + None, + location, + )) + .result(0)? + .into(); + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + block.append_operation(memref::store(value, alloca, &[k0], location)); + compiler_ctx.locals.insert( name.name.clone(), LocalVar { type_spec: r#type.clone(), - value, + memref_type: Some(memref_type), + value: alloca, }, ); @@ -244,6 +273,14 @@ fn compile_assign_stmt<'c, 'op>( .get(&info.target.first.name) .expect("local should exist") .clone(); + + assert!( + local.memref_type.is_some(), + "can only mutate local stack variables" + ); + + let location = get_location(context, session, &info.target.first.span); + let value = compile_expression( session, context, @@ -252,13 +289,16 @@ fn compile_assign_stmt<'c, 'op>( &info.value, Some(&local.type_spec), )?; - compiler_ctx.locals.insert( - info.target.first.name.clone(), - LocalVar { - type_spec: local.type_spec, - value, - }, - ); + + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + block.append_operation(memref::store(value, local.value, &[k0], location)); Ok(()) } @@ -316,7 +356,9 @@ fn compile_expression<'c, 'op>( } SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), - SimpleExpr::Path(value) => compile_path_op(context, compiler_ctx, block, value), + SimpleExpr::Path(value) => { + compile_path_op(session, context, compiler_ctx, block, value) + } }, Expression::FnCall(value) => { let mut args = Vec::with_capacity(value.args.len()); @@ -464,16 +506,37 @@ fn compile_expression<'c, 'op>( } fn compile_path_op<'c, 'op>( - _context: &MeliorContext, + session: &Session, + context: &'c MeliorContext, compiler_ctx: &mut CompilerContext<'c, 'op>, - _block: &Block<'c>, + block: &'op Block<'c>, path: &PathOp, ) -> Result, Box> { // For now only simple variables work. // TODO: implement properly, this requires having structs implemented. - Ok(compiler_ctx + + let local = compiler_ctx .locals .get(&path.first.name) - .map(|x| x.value) - .expect("local not found")) + .expect("local not found"); + + let location = get_location(context, session, &path.first.span); + + if let Some(_memref_type) = local.memref_type { + let k0 = block + .append_operation(arith::constant( + context, + IntegerAttribute::new(0, Type::index(context)).into(), + location, + )) + .result(0)? + .into(); + let value = block + .append_operation(memref::load(local.value, &[k0], location)) + .result(0)? + .into(); + Ok(value) + } else { + Ok(local.value) + } }