diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 1cb27c2..81716e9 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -8,7 +8,7 @@ use concrete_ast::{ }, functions::FunctionDef, modules::{Module, ModuleDefItem}, - statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement}, + statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement, WhileStmt}, types::TypeSpec, Program, }; @@ -267,7 +267,9 @@ fn compile_statement<'c, 'this: 'c>( Statement::Return(info) => { compile_return_stmt(session, context, scope_ctx, helper, block, info)? } - Statement::While(_) => todo!(), + Statement::While(info) => { + return compile_while(session, context, scope_ctx, helper, block, info); + } Statement::FnCall(info) => { compile_fn_call(session, context, scope_ctx, helper, block, info)?; } @@ -304,12 +306,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper, block, &info.value, - Some(&TypeSpec::Simple { - name: Ident { - name: "bool".to_string(), - span: Span::new(0, 0), - }, - }), + None, )?; let then_successor = helper.append_block(Block::new(&[])); @@ -415,6 +412,70 @@ fn compile_if_expr<'c, 'this: 'c>( }) } +fn compile_while<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &WhileStmt, +) -> Result>, Box> { + let location = Location::unknown(context); + + let check_block = helper.append_block(Block::new(&[])); + + block.append_operation(cf::br(check_block, &[], location)); + + let body_block = helper.append_block(Block::new(&[])); + let merge_block = helper.append_block(Block::new(&[])); + + let condition = compile_expression( + session, + context, + scope_ctx, + helper, + check_block, + &info.value, + None, + )?; + + check_block.append_operation(cf::cond_br( + context, + condition, + body_block, + merge_block, + &[], + &[], + location, + )); + + let mut body_block = Some(body_block); + + { + let mut body_scope_ctx = scope_ctx.clone(); + for stmt in &info.contents { + if let Some(then_successor_block) = body_block { + body_block = compile_statement( + session, + context, + &mut body_scope_ctx, + helper, + then_successor_block, + stmt, + )?; + } + } + } + + match body_block { + Some(body_block) => { + body_block.append_operation(cf::br(check_block, &[], location)); + Ok(Some(merge_block)) + } + None => Ok(Some(merge_block)), + } +} + fn compile_let_stmt<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, diff --git a/examples/while.con b/examples/while.con new file mode 100644 index 0000000..6ab6742 --- /dev/null +++ b/examples/while.con @@ -0,0 +1,17 @@ +mod Simple { + fn main() -> i64 { + return my_func(4); + } + + fn my_func(times: i64) -> i64 { + let mut n: i64 = times; + let mut result: i64 = 1; + + while n > 0 { + result = result + result; + n = n - 1; + } + + return result; + } +}