From 799df9bbad9bba8728c95f24f6d8e6145dcfd652 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 21:03:39 +0100 Subject: [PATCH] make code simpler --- crates/concrete_codegen_mlir/src/codegen.rs | 162 +++++++------------- 1 file changed, 55 insertions(+), 107 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 53662cc..c3a9415 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -229,7 +229,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( { let mut scope_ctx = scope_ctx.clone(); - let fn_block = ®ion.append_block(Block::new(&args)); + let mut fn_block = ®ion.append_block(Block::new(&args)); let blocks_arena = Bump::new(); let helper = BlockHelper { @@ -245,13 +245,9 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( ); } - let mut fn_block = Some(fn_block); - for stmt in &info.body { - if let Some(block) = fn_block { - fn_block = - compile_statement(session, context, &mut scope_ctx, &helper, block, stmt)?; - } + fn_block = + compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?; } } @@ -270,9 +266,9 @@ fn compile_statement<'c, 'this: 'c>( context: &'c MeliorContext, scope_ctx: &mut ScopeContext<'c, 'this>, helper: &BlockHelper<'c, 'this>, - block: &'this BlockRef<'c, 'this>, + mut block: &'this BlockRef<'c, 'this>, info: &Statement, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { match info { Statement::Assign(info) => { compile_assign_stmt(session, context, scope_ctx, helper, block, info)? @@ -280,21 +276,21 @@ fn compile_statement<'c, 'this: 'c>( Statement::Match(_) => todo!(), Statement::For(_) => todo!(), Statement::If(info) => { - return compile_if_expr(session, context, scope_ctx, helper, block, info); + block = compile_if_expr(session, context, scope_ctx, helper, block, info)?; } Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?, Statement::Return(info) => { compile_return_stmt(session, context, scope_ctx, helper, block, info)? } Statement::While(info) => { - return compile_while(session, context, scope_ctx, helper, block, info); + block = compile_while(session, context, scope_ctx, helper, block, info)?; } Statement::FnCall(info) => { compile_fn_call(session, context, scope_ctx, helper, block, info)?; } } - Ok(Some(block)) + Ok(block) } /// Compile a if expression / statement @@ -317,7 +313,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this BlockRef<'c, 'this>, info: &IfExpr, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { let condition = compile_expression( session, context, @@ -328,8 +324,8 @@ fn compile_if_expr<'c, 'this: 'c>( None, )?; - let then_successor = helper.append_block(Block::new(&[])); - let else_successor = helper.append_block(Block::new(&[])); + let mut then_successor = helper.append_block(Block::new(&[])); + let mut else_successor = helper.append_block(Block::new(&[])); block.append_operation(cf::cond_br( context, @@ -341,94 +337,50 @@ fn compile_if_expr<'c, 'this: 'c>( get_named_location(context, "if"), )); - let mut then_successor = Some(then_successor); - let mut else_successor = Some(else_successor); - { let mut then_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { - if let Some(then_successor_block) = then_successor { - then_successor = compile_statement( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - stmt, - )?; - } + then_successor = compile_statement( + session, + context, + &mut then_scope_ctx, + helper, + then_successor, + stmt, + )?; } } if let Some(else_contents) = info.r#else.as_ref() { let mut else_scope_ctx = scope_ctx.clone(); for stmt in else_contents { - if let Some(else_successor_block) = else_successor { - else_successor = compile_statement( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - stmt, - )?; - } + else_successor = compile_statement( + session, + context, + &mut else_scope_ctx, + helper, + else_successor, + stmt, + )?; } } - Ok(match (then_successor, else_successor) { - (None, None) => None, - (None, Some(else_successor)) => { - if else_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - else_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - Some(final_block) - } - } - (Some(then_successor), None) => { - if then_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - then_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - Some(final_block) - } - } - (Some(then_successor), Some(else_successor)) => { - if then_successor.terminator().is_some() && else_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - if then_successor.terminator().is_none() { - then_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - } + // both branches return + if then_successor.terminator().is_some() && else_successor.terminator().is_some() { + return Ok(then_successor); + } - if else_successor.terminator().is_none() { - else_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - } + let merge_block = helper.append_block(Block::new(&[])); - Some(final_block) - } - } - }) + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + Ok(merge_block) } fn compile_while<'c, 'this: 'c>( @@ -438,7 +390,7 @@ fn compile_while<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this BlockRef<'c, 'this>, info: &WhileStmt, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { let location = Location::unknown(context); let check_block = helper.append_block(Block::new(&[])); @@ -468,31 +420,27 @@ fn compile_while<'c, 'this: 'c>( location, )); - let mut body_block = Some(body_block); + let mut body_block = body_block; { let mut body_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { - if let Some(then_successor_block) = body_block { - body_block = compile_statement( - session, - context, - &mut body_scope_ctx, - helper, - then_successor_block, - stmt, - )?; - } + body_block = compile_statement( + session, + context, + &mut body_scope_ctx, + helper, + body_block, + stmt, + )?; } } - match body_block { - Some(body_block) => { - body_block.append_operation(cf::br(check_block, &[], location)); - Ok(Some(merge_block)) - } - None => Ok(Some(merge_block)), + if body_block.terminator().is_none() { + body_block.append_operation(cf::br(check_block, &[], location)); } + + Ok(merge_block) } fn compile_let_stmt<'ctx, 'parent: 'ctx>(