diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 42c07f2..2eb1fff 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, collections::HashMap, error::Error}; +use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ @@ -38,11 +38,29 @@ pub fn compile_program( } #[derive(Debug, Clone)] -pub struct LocalVar<'c, 'op> { +pub struct LocalVar<'ctx, 'parent: 'ctx> { pub type_spec: TypeSpec, // If it's none its on a register, otherwise allocated on the stack. - pub memref_type: Option>, - pub value: Value<'c, 'op>, + pub alloca: bool, + pub value: Value<'ctx, 'parent>, +} + +impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: false, + } + } + + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: true, + } + } } #[derive(Debug, Clone)] @@ -54,17 +72,13 @@ struct CompilerContext<'c, 'this: 'c> { struct BlockHelper<'ctx, 'this: 'ctx> { region: &'this Region<'ctx>, blocks_arena: &'this Bump, - last_block: Cell<&'this BlockRef<'ctx, 'this>>, } impl<'ctx, 'this> BlockHelper<'ctx, 'this> { pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> { - let block = self - .region - .insert_block_after(*self.last_block.get(), block); + let block = self.region.append_block(block); let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block); - self.last_block.set(block_ref); block_ref } @@ -81,6 +95,8 @@ impl<'c, 'this> CompilerContext<'c, 'this> { "u32" | "i32" => IntegerType::new(context, 32).into(), "u16" | "i16" => IntegerType::new(context, 16).into(), "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), "bool" => IntegerType::new(context, 1).into(), _ => todo!("custom type lookup"), }) @@ -155,6 +171,7 @@ fn compile_function_def<'c, 'this: 'c>( compiler_ctx: &mut CompilerContext<'c, 'this>, info: &FunctionDef, ) -> Result, Box> { + tracing::debug!("compiling function {:?}", info.decl.name.name); let location = get_location(context, session, &info.decl.name.span); // Setup function arguments @@ -188,18 +205,13 @@ fn compile_function_def<'c, 'this: 'c>( let helper = BlockHelper { region: ®ion, blocks_arena: &blocks_arena, - last_block: Cell::new(fn_block), }; // Push arguments into locals for (i, param) in info.decl.params.iter().enumerate() { fn_compiler_ctx.locals.insert( param.name.name.clone(), - LocalVar { - type_spec: param.r#type.clone(), - value: fn_block.argument(i)?.into(), - memref_type: None, - }, + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), ); } @@ -443,14 +455,9 @@ fn compile_let_stmt<'c, 'this: 'c>( .into(); block.append_operation(memref::store(value, alloca, &[k0], location)); - compiler_ctx.locals.insert( - name.name.clone(), - LocalVar { - type_spec: r#type.clone(), - memref_type: Some(memref_type), - value: alloca, - }, - ); + compiler_ctx + .locals + .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); Ok(()) } @@ -474,10 +481,7 @@ fn compile_assign_stmt<'c, 'this: 'c>( .expect("local should exist") .clone(); - assert!( - local.memref_type.is_some(), - "can only mutate local stack variables" - ); + assert!(local.alloca, "can only mutate local stack variables"); let location = get_location(context, session, &info.target.first.span); @@ -529,7 +533,7 @@ fn compile_expression<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, + _helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, info: &Expression, type_info: Option<&TypeSpec>, @@ -568,7 +572,7 @@ fn compile_expression<'c, 'this: 'c>( SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), SimpleExpr::Path(value) => { - compile_path_op(session, context, compiler_ctx, helper, block, value) + compile_path_op(session, context, compiler_ctx, block, value) } }, Expression::FnCall(value) => { @@ -592,7 +596,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, arg, Some(&arg_info.r#type), @@ -625,7 +629,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, lhs, type_info, @@ -634,7 +638,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, rhs, type_info, @@ -737,7 +741,6 @@ fn compile_path_op<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, path: &PathOp, ) -> Result, Box> { @@ -751,7 +754,7 @@ fn compile_path_op<'c, 'this: 'c>( let location = get_location(context, session, &path.first.span); - if let Some(_memref_type) = local.memref_type { + if local.alloca { let k0 = block .append_operation(arith::constant( context, diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 1f1cabf..70961fd 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -1,7 +1,7 @@ use std::error::Error; use concrete_ast::Program; -use concrete_session::{config::DebugInfo, Session}; +use concrete_session::Session; use melior::{ dialect::DialectRegistry, ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule},