Skip to content

Commit

Permalink
make code simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 15, 2024
1 parent 9617498 commit 799df9b
Showing 1 changed file with 55 additions and 107 deletions.
162 changes: 55 additions & 107 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(

{
let mut scope_ctx = scope_ctx.clone();
let fn_block = &region.append_block(Block::new(&args));
let mut fn_block = &region.append_block(Block::new(&args));

let blocks_arena = Bump::new();
let helper = BlockHelper {
Expand All @@ -245,13 +245,9 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(
);
}

let mut fn_block = Some(fn_block);

for stmt in &info.body {
if let Some(block) = fn_block {
fn_block =
compile_statement(session, context, &mut scope_ctx, &helper, block, stmt)?;
}
fn_block =
compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?;
}
}

Expand All @@ -270,31 +266,31 @@ fn compile_statement<'c, 'this: 'c>(
context: &'c MeliorContext,
scope_ctx: &mut ScopeContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
block: &'this BlockRef<'c, 'this>,
mut block: &'this BlockRef<'c, 'this>,
info: &Statement,
) -> Result<Option<&'this BlockRef<'c, 'this>>, Box<dyn Error>> {
) -> Result<&'this BlockRef<'c, 'this>, Box<dyn Error>> {
match info {
Statement::Assign(info) => {
compile_assign_stmt(session, context, scope_ctx, helper, block, info)?
}
Statement::Match(_) => todo!(),
Statement::For(_) => todo!(),
Statement::If(info) => {
return compile_if_expr(session, context, scope_ctx, helper, block, info);
block = compile_if_expr(session, context, scope_ctx, helper, block, info)?;
}
Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?,
Statement::Return(info) => {
compile_return_stmt(session, context, scope_ctx, helper, block, info)?
}
Statement::While(info) => {
return compile_while(session, context, scope_ctx, helper, block, info);
block = compile_while(session, context, scope_ctx, helper, block, info)?;
}
Statement::FnCall(info) => {
compile_fn_call(session, context, scope_ctx, helper, block, info)?;
}
}

Ok(Some(block))
Ok(block)
}

/// Compile a if expression / statement
Expand All @@ -317,7 +313,7 @@ fn compile_if_expr<'c, 'this: 'c>(
helper: &BlockHelper<'c, 'this>,
block: &'this BlockRef<'c, 'this>,
info: &IfExpr,
) -> Result<Option<&'this BlockRef<'c, 'this>>, Box<dyn Error>> {
) -> Result<&'this BlockRef<'c, 'this>, Box<dyn Error>> {
let condition = compile_expression(
session,
context,
Expand All @@ -328,8 +324,8 @@ fn compile_if_expr<'c, 'this: 'c>(
None,
)?;

let then_successor = helper.append_block(Block::new(&[]));
let else_successor = helper.append_block(Block::new(&[]));
let mut then_successor = helper.append_block(Block::new(&[]));
let mut else_successor = helper.append_block(Block::new(&[]));

block.append_operation(cf::cond_br(
context,
Expand All @@ -341,94 +337,50 @@ fn compile_if_expr<'c, 'this: 'c>(
get_named_location(context, "if"),
));

let mut then_successor = Some(then_successor);
let mut else_successor = Some(else_successor);

{
let mut then_scope_ctx = scope_ctx.clone();
for stmt in &info.contents {
if let Some(then_successor_block) = then_successor {
then_successor = compile_statement(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
stmt,
)?;
}
then_successor = compile_statement(
session,
context,
&mut then_scope_ctx,
helper,
then_successor,
stmt,
)?;
}
}

if let Some(else_contents) = info.r#else.as_ref() {
let mut else_scope_ctx = scope_ctx.clone();
for stmt in else_contents {
if let Some(else_successor_block) = else_successor {
else_successor = compile_statement(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
stmt,
)?;
}
else_successor = compile_statement(
session,
context,
&mut else_scope_ctx,
helper,
else_successor,
stmt,
)?;
}
}

Ok(match (then_successor, else_successor) {
(None, None) => None,
(None, Some(else_successor)) => {
if else_successor.terminator().is_some() {
None
} else {
let final_block = helper.append_block(Block::new(&[]));
else_successor.append_operation(cf::br(
final_block,
&[],
Location::unknown(context),
));
Some(final_block)
}
}
(Some(then_successor), None) => {
if then_successor.terminator().is_some() {
None
} else {
let final_block = helper.append_block(Block::new(&[]));
then_successor.append_operation(cf::br(
final_block,
&[],
Location::unknown(context),
));
Some(final_block)
}
}
(Some(then_successor), Some(else_successor)) => {
if then_successor.terminator().is_some() && else_successor.terminator().is_some() {
None
} else {
let final_block = helper.append_block(Block::new(&[]));
if then_successor.terminator().is_none() {
then_successor.append_operation(cf::br(
final_block,
&[],
Location::unknown(context),
));
}
// both branches return
if then_successor.terminator().is_some() && else_successor.terminator().is_some() {
return Ok(then_successor);
}

if else_successor.terminator().is_none() {
else_successor.append_operation(cf::br(
final_block,
&[],
Location::unknown(context),
));
}
let merge_block = helper.append_block(Block::new(&[]));

Some(final_block)
}
}
})
if then_successor.terminator().is_none() {
then_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context)));
}

if else_successor.terminator().is_none() {
else_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context)));
}

Ok(merge_block)
}

fn compile_while<'c, 'this: 'c>(
Expand All @@ -438,7 +390,7 @@ fn compile_while<'c, 'this: 'c>(
helper: &BlockHelper<'c, 'this>,
block: &'this BlockRef<'c, 'this>,
info: &WhileStmt,
) -> Result<Option<&'this BlockRef<'c, 'this>>, Box<dyn Error>> {
) -> Result<&'this BlockRef<'c, 'this>, Box<dyn Error>> {
let location = Location::unknown(context);

let check_block = helper.append_block(Block::new(&[]));
Expand Down Expand Up @@ -468,31 +420,27 @@ fn compile_while<'c, 'this: 'c>(
location,
));

let mut body_block = Some(body_block);
let mut body_block = body_block;

{
let mut body_scope_ctx = scope_ctx.clone();
for stmt in &info.contents {
if let Some(then_successor_block) = body_block {
body_block = compile_statement(
session,
context,
&mut body_scope_ctx,
helper,
then_successor_block,
stmt,
)?;
}
body_block = compile_statement(
session,
context,
&mut body_scope_ctx,
helper,
body_block,
stmt,
)?;
}
}

match body_block {
Some(body_block) => {
body_block.append_operation(cf::br(check_block, &[], location));
Ok(Some(merge_block))
}
None => Ok(Some(merge_block)),
if body_block.terminator().is_none() {
body_block.append_operation(cf::br(check_block, &[], location));
}

Ok(merge_block)
}

fn compile_let_stmt<'ctx, 'parent: 'ctx>(
Expand Down

0 comments on commit 799df9b

Please sign in to comment.