From 36e78fc83449311c818f1452af7a498def264f99 Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 13 Feb 2024 14:12:03 +0100 Subject: [PATCH] fixes --- crates/concrete_codegen_mlir/src/codegen.rs | 2 +- crates/concrete_ir/src/lowering.rs | 117 +++++++++++++++++--- 2 files changed, 102 insertions(+), 17 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index acdbc4b..596a636 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -10,7 +10,7 @@ use melior::{ ir::{ attribute::{FlatSymbolRefAttribute, FloatAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Attribute, Block, Location, Module as MeliorModule, Region, Type, Value, + Attribute, Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike, }, Context as MeliorContext, }; diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index a812341..7919cce 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -284,7 +284,9 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) { } fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) { - let discriminator = lower_expression(builder, &info.value, Some(TyKind::Bool)); + let disc_type = + find_expression_type(builder, &info.value).expect("failed to find discriminator type"); + let discriminator = lower_expression(builder, &info.value, Some(disc_type.clone())); let local = builder.add_temp_local(TyKind::Bool); let place = Place { @@ -354,7 +356,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) { }); let targets = SwitchTargets { - values: vec![ValueTree::Leaf(ConstValue::Bool(false))], + values: vec![disc_type.get_falsy_value()], targets: vec![first_else_block_idx, first_then_block_idx], }; @@ -414,8 +416,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) { }) } -fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, type_hint: Option) { - let value = lower_expression(builder, &info.value, type_hint); +fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, ret_type_hint: Option) { + let value = lower_expression(builder, &info.value, ret_type_hint); builder.statements.push(Statement { span: None, kind: StatementKind::Assign( @@ -437,6 +439,62 @@ fn lower_return(builder: &mut FnBodyBuilder, info: &ReturnStmt, type_hint: Optio }); } +fn find_expression_type(builder: &mut FnBodyBuilder, info: &Expression) -> Option { + match info { + Expression::Value(value) => match value { + ValueExpr::ConstBool(_) => Some(TyKind::Bool), + ValueExpr::ConstChar(_) => Some(TyKind::Char), + ValueExpr::ConstInt(_) => None, + ValueExpr::ConstFloat(_) => None, + ValueExpr::ConstStr(_) => Some(TyKind::String), + ValueExpr::Path(path) => { + let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments + Some(local.ty.kind.clone()) + } + ValueExpr::Deref(path) => { + let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments + Some(local.ty.kind.clone()) + } + ValueExpr::AsRef { path, ref_type } => { + let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments + Some(TyKind::Ref( + Box::new(local.ty.kind.clone()), + match ref_type { + RefType::Borrow => Mutability::Not, + RefType::MutBorrow => Mutability::Mut, + }, + )) + } + }, + Expression::FnCall(info) => { + let fn_id = { + let mod_body = builder.get_module_body(); + + if let Some(id) = mod_body.symbols.functions.get(&info.target.name) { + *id + } else { + *mod_body + .imports + .get(&info.target.name) + .expect("function call not found") + } + }; + let fn_sig = builder.ctx.body.function_signatures.get(&fn_id).unwrap(); + Some(fn_sig.1.kind.clone()) + } + Expression::Match(_) => None, + Expression::If(_) => None, + Expression::UnaryOp(_, info) => find_expression_type(builder, info), + Expression::BinaryOp(lhs, op, rhs) => { + if matches!(op, BinaryOp::Logic(_)) { + Some(TyKind::Bool) + } else { + find_expression_type(builder, lhs).or(find_expression_type(builder, rhs)) + } + } + } +} + fn lower_expression( builder: &mut FnBodyBuilder, info: &Expression, @@ -532,16 +590,35 @@ fn lower_binary_op( rhs: &Expression, type_hint: Option, ) -> Rvalue { - let expr_type = type_hint.clone().expect("type hint needed"); - let lhs = lower_expression(builder, lhs, type_hint.clone()); - let rhs = lower_expression(builder, rhs, type_hint.clone()); + let (lhs, lhs_ty) = if type_hint.is_none() { + let ty = find_expression_type(builder, lhs); + (lower_expression(builder, lhs, ty.clone()), ty) + } else { + ( + lower_expression(builder, lhs, type_hint.clone()), + type_hint.clone(), + ) + }; + let (rhs, rhs_ty) = if type_hint.is_none() { + let ty = find_expression_type(builder, rhs); + (lower_expression(builder, rhs, ty.clone()), ty) + } else { + ( + lower_expression(builder, rhs, type_hint.clone()), + type_hint.clone(), + ) + }; - let local_ty = Ty { + let lhs_ty = lhs_ty.or(rhs_ty.clone()).expect("type not found"); + let rhs_ty = rhs_ty.unwrap_or(lhs_ty.clone()); + let lhs_local = builder.add_local(Local::temp(Ty { span: None, - kind: expr_type.clone(), - }; - let lhs_local = builder.add_local(Local::temp(local_ty.clone())); - let rhs_local = builder.add_local(Local::temp(local_ty.clone())); + kind: lhs_ty.clone(), + })); + let rhs_local = builder.add_local(Local::temp(Ty { + span: None, + kind: rhs_ty.clone(), + })); let lhs_place = Place { local: lhs_local, projection: vec![], @@ -665,10 +742,18 @@ fn lower_value_expr( })), ValueExpr::ConstFloat(value) => Rvalue::Use(Operand::Const(match type_hint { Some(ty) => ConstData { - ty, - data: ConstKind::Value(ValueTree::Leaf(ConstValue::F32( - value.parse().expect("error parsing float"), - ))), + ty: ty.clone(), + data: ConstKind::Value(ValueTree::Leaf(match &ty { + TyKind::Float(ty) => match ty { + FloatTy::F32 => { + ConstValue::F32(value.parse().expect("error parsing float")) + } + FloatTy::F64 => { + ConstValue::F64(value.parse().expect("error parsing float")) + } + }, + _ => unreachable!(), + })), }, None => ConstData { ty: TyKind::Float(FloatTy::F64),