From bbea9eb52364bc6c60ecd3f6db52f95785811e67 Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 6 Feb 2024 17:18:03 +0100 Subject: [PATCH] lower if --- crates/concrete_ir/src/lib.rs | 7 +- crates/concrete_ir/src/lowering.rs | 169 ++++++- simple.con.ir | 741 ----------------------------- 3 files changed, 153 insertions(+), 764 deletions(-) delete mode 100644 simple.con.ir diff --git a/crates/concrete_ir/src/lib.rs b/crates/concrete_ir/src/lib.rs index 41ac6df..0bb726b 100644 --- a/crates/concrete_ir/src/lib.rs +++ b/crates/concrete_ir/src/lib.rs @@ -71,7 +71,7 @@ pub struct Terminator { #[derive(Debug, Clone)] pub enum TerminatorKind { Goto { - target: BasicBlock, + target: BlockIndex, }, Return, Unreachable, @@ -87,10 +87,11 @@ pub enum TerminatorKind { }, } +/// Used for ifs, match #[derive(Debug, Clone)] pub struct SwitchTargets { - pub values: Vec, - pub targets: Vec, // last target is the otherwise block + pub values: Vec, + pub targets: Vec, // last target is the otherwise block (no value matched) } #[derive(Debug, Clone)] diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index 0cc3f6a..8544bbc 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -2,7 +2,9 @@ use std::collections::HashMap; use common::{BuildCtx, FnBodyBuilder, IdGenerator}; use concrete_ast::{ - expressions::{ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, PathOp, ValueExpr}, + expressions::{ + ArithOp, BinaryOp, BitwiseOp, CmpOp, Expression, FnCallOp, IfExpr, PathOp, ValueExpr, + }, functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{self, AssignStmt, LetStmt, LetStmtTarget, ReturnStmt}, @@ -13,7 +15,7 @@ use concrete_ast::{ use crate::{ BasicBlock, BinOp, ConstData, ConstKind, ConstValue, DefId, FloatTy, FnBody, IntTy, Local, LocalKind, ModuleBody, Mutability, Operand, Place, PlaceElem, ProgramBody, Rvalue, Statement, - StatementKind, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree, + StatementKind, SwitchTargets, Terminator, TerminatorKind, Ty, TyKind, UintTy, ValueTree, }; pub mod common; @@ -153,24 +155,11 @@ fn lower_func(ctx: ModuleBody, func: &FunctionDef, module_id: DefId) -> ModuleBo } for stmt in &func.body { - match stmt { - statements::Statement::Assign(info) => lower_assign(&mut builder, info), - statements::Statement::Match(_) => todo!(), - statements::Statement::For(_) => todo!(), - statements::Statement::If(_) => todo!(), - statements::Statement::Let(info) => lower_let(&mut builder, info), - statements::Statement::Return(info) => { - lower_return( - &mut builder, - info, - func.decl.ret_type.as_ref().map(|x| lower_type(x).kind), - ); - } - statements::Statement::While(_) => todo!(), - statements::Statement::FnCall(info) => { - lower_fn_call(&mut builder, info); - } - } + lower_statement( + &mut builder, + stmt, + func.decl.ret_type.as_ref().map(|x| lower_type(x).kind), + ); } let (mut ctx, body) = (builder.ctx, builder.body); @@ -179,6 +168,145 @@ fn lower_func(ctx: ModuleBody, func: &FunctionDef, module_id: DefId) -> ModuleBo ctx } +fn lower_statement( + builder: &mut FnBodyBuilder, + info: &concrete_ast::statements::Statement, + ret_type: Option, +) { + match info { + statements::Statement::Assign(info) => lower_assign(builder, info), + statements::Statement::Match(_) => todo!(), + statements::Statement::For(_) => todo!(), + statements::Statement::If(info) => lower_if_statement(builder, info), + statements::Statement::Let(info) => lower_let(builder, info), + statements::Statement::Return(info) => { + lower_return(builder, info, ret_type); + } + statements::Statement::While(_) => todo!(), + statements::Statement::FnCall(info) => { + lower_fn_call(builder, info); + } + } +} + +fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) { + let discriminator = lower_expression(builder, &info.value, Some(TyKind::Bool)); + + let local = builder.add_local(Local { + span: None, + ty: Ty { + span: None, + kind: TyKind::Bool, + }, + kind: LocalKind::Temp, + }); + let place = Place { + local, + projection: vec![], + }; + + builder.statements.push(Statement { + span: None, + kind: StatementKind::Assign(place.clone(), discriminator), + }); + + // keep idx to change terminator + let current_block_idx = builder.body.basic_blocks.len(); + + let statements = std::mem::take(&mut builder.statements); + builder.body.basic_blocks.push(BasicBlock { + statements, + terminator: Box::new(Terminator { + span: None, + kind: TerminatorKind::Unreachable, + }), + }); + + // keep idx for switch targets + let first_then_block_idx = builder.body.basic_blocks.len(); + + for stmt in &info.contents { + lower_statement( + builder, + stmt, + Some(builder.body.locals[builder.ret_local].ty.kind.clone()), + ); + } + + // keet idx to change terminator + let last_then_block_idx = builder.body.basic_blocks.len(); + let statements = std::mem::take(&mut builder.statements); + builder.body.basic_blocks.push(BasicBlock { + statements, + terminator: Box::new(Terminator { + span: None, + kind: TerminatorKind::Unreachable, + }), + }); + + let first_else_block_idx = builder.body.basic_blocks.len(); + + if let Some(contents) = &info.r#else { + for stmt in contents { + lower_statement( + builder, + stmt, + Some(builder.body.locals[builder.ret_local].ty.kind.clone()), + ); + } + } + + let last_else_block_idx = builder.body.basic_blocks.len(); + let statements = std::mem::take(&mut builder.statements); + builder.body.basic_blocks.push(BasicBlock { + statements, + terminator: Box::new(Terminator { + span: None, + kind: TerminatorKind::Unreachable, + }), + }); + + // Needed to ease codegen + let otherwise_block_idx = builder.body.basic_blocks.len(); + builder.body.basic_blocks.push(BasicBlock { + statements: vec![], + terminator: Box::new(Terminator { + span: None, + kind: TerminatorKind::Unreachable, + }), + }); + + let targets = SwitchTargets { + values: vec![ + ValueTree::Leaf(ConstValue::Bool(true)), + ValueTree::Leaf(ConstValue::Bool(false)), + ], + targets: vec![ + first_then_block_idx, + first_else_block_idx, + otherwise_block_idx, + ], + }; + + let kind = TerminatorKind::SwitchInt { + discriminator: Operand::Place(place), + targets, + }; + builder.body.basic_blocks[current_block_idx].terminator.kind = kind; + + let next_block_idx = builder.body.basic_blocks.len(); + builder.body.basic_blocks[last_then_block_idx] + .terminator + .kind = TerminatorKind::Goto { + target: next_block_idx, + }; + builder.body.basic_blocks[last_else_block_idx] + .terminator + .kind = TerminatorKind::Goto { + target: next_block_idx, + }; +} + fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) { match &info.target { LetStmtTarget::Simple { name, r#type } => { @@ -460,6 +588,7 @@ fn lower_value_expr( } UintTy::U128 => ConstValue::U128(*value), }, + TyKind::Bool => ConstValue::Bool(*value != 0), _ => unreachable!(), })), }, diff --git a/simple.con.ir b/simple.con.ir deleted file mode 100644 index 4c1357d..0000000 --- a/simple.con.ir +++ /dev/null @@ -1,741 +0,0 @@ -ProgramBody { - module_names: { - "Simple": DefId { - program_id: 0, - id: 1, - }, - }, - modules: { - DefId { - program_id: 0, - id: 1, - }: ModuleBody { - id: DefId { - program_id: 0, - id: 1, - }, - parent_id: None, - symbols: SymbolTable { - symbols: {}, - modules: {}, - functions: { - "add_plus_two": DefId { - program_id: 0, - id: 3, - }, - "main": DefId { - program_id: 0, - id: 2, - }, - }, - constants: {}, - structs: {}, - types: {}, - }, - functions: { - DefId { - program_id: 0, - id: 2, - }: FnBody { - id: DefId { - program_id: 0, - id: 2, - }, - basic_blocks: [ - BasicBlock { - statements: [ - Statement { - span: Some( - Span { - from: 48, - to: 49, - }, - ), - kind: StorageLive( - 1, - ), - }, - Statement { - span: Some( - Span { - from: 48, - to: 49, - }, - ), - kind: Assign( - Place { - local: 1, - projection: [], - }, - Use( - Const( - ConstData { - ty: Int( - I32, - ), - data: Value( - Leaf( - I32( - 2, - ), - ), - ), - }, - ), - ), - ), - }, - Statement { - span: Some( - Span { - from: 72, - to: 73, - }, - ), - kind: StorageLive( - 2, - ), - }, - Statement { - span: Some( - Span { - from: 72, - to: 73, - }, - ), - kind: Assign( - Place { - local: 2, - projection: [], - }, - Use( - Const( - ConstData { - ty: Int( - I32, - ), - data: Value( - Leaf( - I32( - 4, - ), - ), - ), - }, - ), - ), - ), - }, - ], - terminator: Terminator { - span: Some( - Span { - from: 99, - to: 111, - }, - ), - kind: Call { - func: DefId { - program_id: 0, - id: 3, - }, - args: [ - Use( - Place( - Place { - local: 1, - projection: [], - }, - ), - ), - Use( - Place( - Place { - local: 2, - projection: [], - }, - ), - ), - ], - destination: Place { - local: 3, - projection: [], - }, - target: Some( - 1, - ), - }, - }, - }, - BasicBlock { - statements: [ - Statement { - span: None, - kind: Assign( - Place { - local: 0, - projection: [], - }, - Use( - Place( - Place { - local: 3, - projection: [], - }, - ), - ), - ), - }, - ], - terminator: Terminator { - span: None, - kind: Return, - }, - }, - ], - locals: [ - Local { - span: None, - ty: Ty { - span: Some( - Span { - from: 30, - to: 33, - }, - ), - kind: Int( - I32, - ), - }, - kind: ReturnPointer, - }, - Local { - span: Some( - Span { - from: 48, - to: 49, - }, - ), - ty: Ty { - span: Some( - Span { - from: 51, - to: 54, - }, - ), - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: Some( - Span { - from: 72, - to: 73, - }, - ), - ty: Ty { - span: Some( - Span { - from: 75, - to: 78, - }, - ), - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: Some( - Span { - from: 165, - to: 168, - }, - ), - kind: Int( - I32, - ), - }, - kind: Temp, - }, - ], - }, - DefId { - program_id: 0, - id: 3, - }: FnBody { - id: DefId { - program_id: 0, - id: 3, - }, - basic_blocks: [ - BasicBlock { - statements: [ - Statement { - span: Some( - Span { - from: 187, - to: 188, - }, - ), - kind: StorageLive( - 3, - ), - }, - Statement { - span: Some( - Span { - from: 187, - to: 188, - }, - ), - kind: Assign( - Place { - local: 3, - projection: [], - }, - Use( - Const( - ConstData { - ty: Int( - I32, - ), - data: Value( - Leaf( - I32( - 1, - ), - ), - ), - }, - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 4, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 4, - projection: [], - }, - Use( - Place( - Place { - local: 3, - projection: [], - }, - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 5, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 4, - projection: [], - }, - Use( - Const( - ConstData { - ty: Int( - I32, - ), - data: Value( - Leaf( - I32( - 1, - ), - ), - ), - }, - ), - ), - ), - }, - Statement { - span: Some( - Span { - from: 207, - to: 208, - }, - ), - kind: Assign( - Place { - local: 3, - projection: [], - }, - BinaryOp( - Add, - ( - Place( - Place { - local: 4, - projection: [], - }, - ), - Place( - Place { - local: 4, - projection: [], - }, - ), - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 6, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 6, - projection: [], - }, - Use( - Place( - Place { - local: 1, - projection: [], - }, - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 7, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 6, - projection: [], - }, - Use( - Place( - Place { - local: 2, - projection: [], - }, - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 8, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 8, - projection: [], - }, - BinaryOp( - Add, - ( - Place( - Place { - local: 6, - projection: [], - }, - ), - Place( - Place { - local: 6, - projection: [], - }, - ), - ), - ), - ), - }, - Statement { - span: None, - kind: StorageLive( - 9, - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 8, - projection: [], - }, - Use( - Place( - Place { - local: 3, - projection: [], - }, - ), - ), - ), - }, - Statement { - span: None, - kind: Assign( - Place { - local: 0, - projection: [], - }, - BinaryOp( - Add, - ( - Place( - Place { - local: 8, - projection: [], - }, - ), - Place( - Place { - local: 8, - projection: [], - }, - ), - ), - ), - ), - }, - ], - terminator: Terminator { - span: None, - kind: Return, - }, - }, - ], - locals: [ - Local { - span: None, - ty: Ty { - span: Some( - Span { - from: 165, - to: 168, - }, - ), - kind: Int( - I32, - ), - }, - kind: ReturnPointer, - }, - Local { - span: Some( - Span { - from: 146, - to: 147, - }, - ), - ty: Ty { - span: Some( - Span { - from: 149, - to: 152, - }, - ), - kind: Int( - I32, - ), - }, - kind: Arg, - }, - Local { - span: Some( - Span { - from: 154, - to: 155, - }, - ), - ty: Ty { - span: Some( - Span { - from: 157, - to: 160, - }, - ), - kind: Int( - I32, - ), - }, - kind: Arg, - }, - Local { - span: Some( - Span { - from: 187, - to: 188, - }, - ), - ty: Ty { - span: Some( - Span { - from: 190, - to: 193, - }, - ), - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - Local { - span: None, - ty: Ty { - span: None, - kind: Int( - I32, - ), - }, - kind: Temp, - }, - ], - }, - }, - function_signatures: { - DefId { - program_id: 0, - id: 3, - }: ( - [ - Ty { - span: Some( - Span { - from: 149, - to: 152, - }, - ), - kind: Int( - I32, - ), - }, - Ty { - span: Some( - Span { - from: 157, - to: 160, - }, - ), - kind: Int( - I32, - ), - }, - ], - Ty { - span: Some( - Span { - from: 165, - to: 168, - }, - ), - kind: Int( - I32, - ), - }, - ), - DefId { - program_id: 0, - id: 2, - }: ( - [], - Ty { - span: Some( - Span { - from: 30, - to: 33, - }, - ), - kind: Int( - I32, - ), - }, - ), - }, - modules: {}, - }, - }, -}