Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
edg-l committed Feb 13, 2024
1 parent a42d045 commit 36e78fc
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 17 deletions.
2 changes: 1 addition & 1 deletion crates/concrete_codegen_mlir/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use melior::{
ir::{
attribute::{FlatSymbolRefAttribute, FloatAttribute, StringAttribute, TypeAttribute},
r#type::{FunctionType, IntegerType, MemRefType},
Attribute, Block, Location, Module as MeliorModule, Region, Type, Value,
Attribute, Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike,
},
Context as MeliorContext,
};
Expand Down
117 changes: 101 additions & 16 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) {
}

fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) {
let discriminator = lower_expression(builder, &info.value, Some(TyKind::Bool));
let disc_type =
find_expression_type(builder, &info.value).expect("failed to find discriminator type");
let discriminator = lower_expression(builder, &info.value, Some(disc_type.clone()));

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -354,7 +356,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) {
});

let targets = SwitchTargets {
values: vec![ValueTree::Leaf(ConstValue::Bool(false))],
values: vec![disc_type.get_falsy_value()],
targets: vec![first_else_block_idx, first_then_block_idx],
};

Expand Down Expand Up @@ -414,8 +416,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) {
})
}

fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, type_hint: Option<TyKind>) {
let value = lower_expression(builder, &info.value, type_hint);
fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, ret_type_hint: Option<TyKind>) {
let value = lower_expression(builder, &info.value, ret_type_hint);
builder.statements.push(Statement {
span: None,
kind: StatementKind::Assign(
Expand All @@ -437,6 +439,62 @@ fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, type_hint: Optio
});
}

fn find_expression_type(builder: &mut FnBodyBuilder, info: &Expression) -> Option<TyKind> {
match info {
Expression::Value(value) => match value {
ValueExpr::ConstBool(_) => Some(TyKind::Bool),
ValueExpr::ConstChar(_) => Some(TyKind::Char),
ValueExpr::ConstInt(_) => None,
ValueExpr::ConstFloat(_) => None,
ValueExpr::ConstStr(_) => Some(TyKind::String),
ValueExpr::Path(path) => {
let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments
Some(local.ty.kind.clone())
}
ValueExpr::Deref(path) => {
let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments
Some(local.ty.kind.clone())
}
ValueExpr::AsRef { path, ref_type } => {
let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments
Some(TyKind::Ref(
Box::new(local.ty.kind.clone()),
match ref_type {
RefType::Borrow => Mutability::Not,
RefType::MutBorrow => Mutability::Mut,
},
))
}
},
Expression::FnCall(info) => {
let fn_id = {
let mod_body = builder.get_module_body();

if let Some(id) = mod_body.symbols.functions.get(&info.target.name) {
*id
} else {
*mod_body
.imports
.get(&info.target.name)
.expect("function call not found")
}
};
let fn_sig = builder.ctx.body.function_signatures.get(&fn_id).unwrap();
Some(fn_sig.1.kind.clone())
}
Expression::Match(_) => None,
Expression::If(_) => None,
Expression::UnaryOp(_, info) => find_expression_type(builder, info),
Expression::BinaryOp(lhs, op, rhs) => {
if matches!(op, BinaryOp::Logic(_)) {
Some(TyKind::Bool)
} else {
find_expression_type(builder, lhs).or(find_expression_type(builder, rhs))
}
}
}
}

fn lower_expression(
builder: &mut FnBodyBuilder,
info: &Expression,
Expand Down Expand Up @@ -532,16 +590,35 @@ fn lower_binary_op(
rhs: &Expression,
type_hint: Option<TyKind>,
) -> Rvalue {
let expr_type = type_hint.clone().expect("type hint needed");
let lhs = lower_expression(builder, lhs, type_hint.clone());
let rhs = lower_expression(builder, rhs, type_hint.clone());
let (lhs, lhs_ty) = if type_hint.is_none() {
let ty = find_expression_type(builder, lhs);
(lower_expression(builder, lhs, ty.clone()), ty)
} else {
(
lower_expression(builder, lhs, type_hint.clone()),
type_hint.clone(),
)
};
let (rhs, rhs_ty) = if type_hint.is_none() {
let ty = find_expression_type(builder, rhs);
(lower_expression(builder, rhs, ty.clone()), ty)
} else {
(
lower_expression(builder, rhs, type_hint.clone()),
type_hint.clone(),
)
};

let local_ty = Ty {
let lhs_ty = lhs_ty.or(rhs_ty.clone()).expect("type not found");
let rhs_ty = rhs_ty.unwrap_or(lhs_ty.clone());
let lhs_local = builder.add_local(Local::temp(Ty {
span: None,
kind: expr_type.clone(),
};
let lhs_local = builder.add_local(Local::temp(local_ty.clone()));
let rhs_local = builder.add_local(Local::temp(local_ty.clone()));
kind: lhs_ty.clone(),
}));
let rhs_local = builder.add_local(Local::temp(Ty {
span: None,
kind: rhs_ty.clone(),
}));
let lhs_place = Place {
local: lhs_local,
projection: vec![],
Expand Down Expand Up @@ -665,10 +742,18 @@ fn lower_value_expr(
})),
ValueExpr::ConstFloat(value) => Rvalue::Use(Operand::Const(match type_hint {
Some(ty) => ConstData {
ty,
data: ConstKind::Value(ValueTree::Leaf(ConstValue::F32(
value.parse().expect("error parsing float"),
))),
ty: ty.clone(),
data: ConstKind::Value(ValueTree::Leaf(match &ty {
TyKind::Float(ty) => match ty {
FloatTy::F32 => {
ConstValue::F32(value.parse().expect("error parsing float"))
}
FloatTy::F64 => {
ConstValue::F64(value.parse().expect("error parsing float"))
}
},
_ => unreachable!(),
})),
},
None => ConstData {
ty: TyKind::Float(FloatTy::F64),
Expand Down

0 comments on commit 36e78fc

Please sign in to comment.