Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Jan 15, 2024
1 parent f42c92f commit 1c2298d
Showing 1 changed file with 107 additions and 151 deletions.
258 changes: 107 additions & 151 deletions crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::{collections::HashMap, error::Error};
use bumpalo::Bump;
use concrete_ast::{
common::{Ident, Span},
expressions::{ArithOp, BinaryOp, CmpOp, Expression, IfExpr, LogicOp, PathOp, SimpleExpr},
expressions::{
ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, SimpleExpr,
},
functions::FunctionDef,
modules::{Module, ModuleDefItem},
statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement},
Expand Down Expand Up @@ -228,31 +230,8 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(

for stmt in &info.body {
if let Some(block) = fn_block {
match stmt {
Statement::Assign(info) => {
compile_assign_stmt(session, context, &mut scope_ctx, &helper, block, info)?
}
Statement::Match(_) => todo!(),
Statement::For(_) => todo!(),
Statement::If(info) => {
fn_block = compile_if_expr(
session,
context,
&mut scope_ctx,
&helper,
block,
info,
)?;
}
Statement::Let(info) => {
compile_let_stmt(session, context, &mut scope_ctx, &helper, block, info)?
}
Statement::Return(info) => {
compile_return_stmt(session, context, &mut scope_ctx, &helper, block, info)?
}
Statement::While(_) => todo!(),
Statement::FnCall(_) => todo!(),
}
fn_block =
compile_statement(session, context, &mut scope_ctx, &helper, block, stmt)?;
}
}
}
Expand All @@ -267,6 +246,36 @@ fn compile_function_def<'ctx, 'parent: 'ctx>(
))
}

fn compile_statement<'c, 'this: 'c>(
session: &Session,
context: &'c MeliorContext,
scope_ctx: &mut ScopeContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
block: &'this BlockRef<'c, 'this>,
info: &Statement,
) -> Result<Option<&'this BlockRef<'c, 'this>>, Box<dyn Error>> {
match info {
Statement::Assign(info) => {
compile_assign_stmt(session, context, scope_ctx, helper, block, info)?
}
Statement::Match(_) => todo!(),
Statement::For(_) => todo!(),
Statement::If(info) => {
return compile_if_expr(session, context, scope_ctx, helper, block, info);
}
Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?,
Statement::Return(info) => {
compile_return_stmt(session, context, scope_ctx, helper, block, info)?
}
Statement::While(_) => todo!(),
Statement::FnCall(info) => {
compile_fn_call(session, context, scope_ctx, helper, block, info)?;
}
}

Ok(Some(block))
}

/// Compile a if expression / statement
///
/// This returns a block if any branch doesn't have a function return terminator.
Expand All @@ -285,7 +294,7 @@ fn compile_if_expr<'c, 'this: 'c>(
context: &'c MeliorContext,
scope_ctx: &mut ScopeContext<'c, 'this>,
helper: &BlockHelper<'c, 'this>,
block: &'this Block<'c>,
block: &'this BlockRef<'c, 'this>,
info: &IfExpr,
) -> Result<Option<&'this BlockRef<'c, 'this>>, Box<dyn Error>> {
let condition = compile_expression(
Expand Down Expand Up @@ -323,46 +332,14 @@ fn compile_if_expr<'c, 'this: 'c>(
let mut then_scope_ctx = scope_ctx.clone();
for stmt in &info.contents {
if let Some(then_successor_block) = then_successor {
match stmt {
Statement::Assign(info) => compile_assign_stmt(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
info,
)?,
Statement::Match(_) => todo!(),
Statement::For(_) => todo!(),
Statement::If(info) => {
then_successor = compile_if_expr(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
info,
)?;
}
Statement::Let(info) => compile_let_stmt(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
info,
)?,
Statement::Return(info) => compile_return_stmt(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
info,
)?,
Statement::While(_) => todo!(),
Statement::FnCall(_) => todo!(),
}
then_successor = compile_statement(
session,
context,
&mut then_scope_ctx,
helper,
then_successor_block,
stmt,
)?;
}
}
}
Expand All @@ -371,46 +348,14 @@ fn compile_if_expr<'c, 'this: 'c>(
let mut else_scope_ctx = scope_ctx.clone();
for stmt in else_contents {
if let Some(else_successor_block) = else_successor {
match stmt {
Statement::Assign(info) => compile_assign_stmt(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
info,
)?,
Statement::Match(_) => todo!(),
Statement::For(_) => todo!(),
Statement::If(info) => {
else_successor = compile_if_expr(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
info,
)?;
}
Statement::Let(info) => compile_let_stmt(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
info,
)?,
Statement::Return(info) => compile_return_stmt(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
info,
)?,
Statement::While(_) => todo!(),
Statement::FnCall(_) => todo!(),
}
else_successor = compile_statement(
session,
context,
&mut else_scope_ctx,
helper,
else_successor_block,
stmt,
)?;
}
}
}
Expand Down Expand Up @@ -634,50 +579,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
SimpleExpr::Path(value) => compile_path_op(session, context, scope_ctx, block, value),
},
Expression::FnCall(value) => {
let mut args = Vec::with_capacity(value.args.len());
let location = get_location(context, session, value.target.span.from);

let target_fn = scope_ctx
.functions
.get(&value.target.name)
.expect("function not found")
.clone();

assert_eq!(
value.args.len(),
target_fn.decl.params.len(),
"parameter length doesnt match"
);

for (arg, arg_info) in value.args.iter().zip(&target_fn.decl.params) {
let value = compile_expression(
session,
context,
scope_ctx,
_helper,
block,
arg,
Some(&arg_info.r#type),
)?;
args.push(value);
}

let return_type = if let Some(ret_type) = &target_fn.decl.ret_type {
vec![scope_ctx.resolve_type_spec(context, ret_type)?]
} else {
vec![]
};

Ok(block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &value.target.name),
&args,
&return_type,
location,
))
.result(0)?
.into())
compile_fn_call(session, context, scope_ctx, _helper, block, value)
}
Expression::Match(_) => todo!(),
Expression::If(_) => todo!(),
Expand Down Expand Up @@ -781,6 +683,60 @@ fn compile_expression<'ctx, 'parent: 'ctx>(
}
}

fn compile_fn_call<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
scope_ctx: &mut ScopeContext<'ctx, 'parent>,
_helper: &BlockHelper<'ctx, 'parent>,
block: &'parent Block<'ctx>,
info: &FnCallOp,
) -> Result<Value<'ctx, 'parent>, Box<dyn Error>> {
let mut args = Vec::with_capacity(info.args.len());
let location = get_location(context, session, info.target.span.from);

let target_fn = scope_ctx
.functions
.get(&info.target.name)
.expect("function not found")
.clone();

assert_eq!(
info.args.len(),
target_fn.decl.params.len(),
"parameter length doesnt match"
);

for (arg, arg_info) in info.args.iter().zip(&target_fn.decl.params) {
let value = compile_expression(
session,
context,
scope_ctx,
_helper,
block,
arg,
Some(&arg_info.r#type),
)?;
args.push(value);
}

let return_type = if let Some(ret_type) = &target_fn.decl.ret_type {
vec![scope_ctx.resolve_type_spec(context, ret_type)?]
} else {
vec![]
};

Ok(block
.append_operation(func::call(
context,
FlatSymbolRefAttribute::new(context, &info.target.name),
&args,
&return_type,
location,
))
.result(0)?
.into())
}

fn compile_path_op<'ctx, 'parent: 'ctx>(
session: &Session,
context: &'ctx MeliorContext,
Expand Down

0 comments on commit 1c2298d

Please sign in to comment.