From f69aef9da9016989b8b5196c37a039151dd90c64 Mon Sep 17 00:00:00 2001 From: Edgar Date: Thu, 11 Jan 2024 18:19:50 +0100 Subject: [PATCH 01/20] progress --- Cargo.lock | 1 + crates/concrete_codegen_mlir/Cargo.toml | 1 + crates/concrete_codegen_mlir/src/codegen.rs | 273 ++++++++++++++------ examples/simple_if.con | 13 + 4 files changed, 215 insertions(+), 73 deletions(-) create mode 100644 examples/simple_if.con diff --git a/Cargo.lock b/Cargo.lock index 1a27727..5defe42 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -330,6 +330,7 @@ dependencies = [ name = "concrete_codegen_mlir" version = "0.1.0" dependencies = [ + "bumpalo", "cc", "concrete_ast", "concrete_session", diff --git a/crates/concrete_codegen_mlir/Cargo.toml b/crates/concrete_codegen_mlir/Cargo.toml index 39f7262..8264ed6 100644 --- a/crates/concrete_codegen_mlir/Cargo.toml +++ b/crates/concrete_codegen_mlir/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bumpalo = "3.14.0" concrete_ast = { path = "../concrete_ast"} concrete_session = { path = "../concrete_session"} llvm-sys = "170.0.1" diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 21a0d12..3bb0b97 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,8 +1,9 @@ -use std::{collections::HashMap, error::Error}; +use std::{cell::Cell, collections::HashMap, error::Error}; +use bumpalo::Bump; use concrete_ast::{ - common::Span, - expressions::{ArithOp, BinaryOp, CmpOp, Expression, LogicOp, PathOp, SimpleExpr}, + common::{Ident, Span}, + expressions::{ArithOp, BinaryOp, CmpOp, Expression, IfExpr, LogicOp, PathOp, SimpleExpr}, functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement}, @@ -13,12 +14,13 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - func, memref, + cf, func, memref, }, ir::{ attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, + ValueLike, }, Context as MeliorContext, }; @@ -43,13 +45,32 @@ pub struct LocalVar<'c, 'op> { pub value: Value<'c, 'op>, } -#[derive(Debug, Clone, Default)] -struct CompilerContext<'c, 'op> { - pub locals: HashMap>, +#[derive(Debug, Clone)] +struct CompilerContext<'c, 'this: 'c> { + pub locals: HashMap>, pub functions: HashMap, } -impl<'c, 'op> CompilerContext<'c, 'op> { +struct BlockHelper<'ctx, 'this: 'ctx> { + region: &'this Region<'ctx>, + blocks_arena: &'this Bump, + last_block: Cell<&'this BlockRef<'ctx, 'this>>, +} + +impl<'ctx, 'this> BlockHelper<'ctx, 'this> { + pub fn append_block(&self, block: Block<'ctx>) -> &'this Block<'ctx> { + let block = self + .region + .insert_block_after(*self.last_block.get(), block); + + let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block); + self.last_block.set(block_ref); + + block_ref + } +} + +impl<'c, 'this> CompilerContext<'c, 'this> { fn resolve_type( &self, context: &'c MeliorContext, @@ -88,10 +109,13 @@ fn compile_module( ) -> Result<(), Box> { // todo: handle imports - let mut compiler_ctx: CompilerContext = Default::default(); - let body = mlir_module.body(); + let mut compiler_ctx: CompilerContext = CompilerContext { + functions: Default::default(), + locals: Default::default(), + }; + // save all function signatures for statement in &module.contents { if let ModuleDefItem::Function(info) = statement { @@ -105,7 +129,8 @@ fn compile_module( match statement { ModuleDefItem::Constant(_) => todo!(), ModuleDefItem::Function(info) => { - compile_function_def(session, context, &mut compiler_ctx, &body, info)?; + let op = compile_function_def(session, context, &mut compiler_ctx, info)?; + body.append_operation(op); } ModuleDefItem::Record(_) => todo!(), ModuleDefItem::Type(_) => todo!(), @@ -120,15 +145,12 @@ fn get_location<'c>(context: &'c MeliorContext, session: &Session, span: &Span) Location::new(context, &session.file_path.display().to_string(), line, col) } -fn compile_function_def<'c, 'op>( +fn compile_function_def<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, info: &FunctionDef, -) -> Result<(), Box> { - let region = Region::new(); - +) -> Result, Box> { let location = get_location(context, session, &info.decl.name.span); // Setup function arguments @@ -142,25 +164,11 @@ fn compile_function_def<'c, 'op>( fn_args_types.push(param_type); } - let fn_block = Block::new(&args); - // Create the function context - let mut fn_compiler_ctx = compiler_ctx.clone(); - - // Push arguments into locals - for (i, param) in info.decl.params.iter().enumerate() { - fn_compiler_ctx.locals.insert( - param.name.name.clone(), - LocalVar { - type_spec: param.r#type.clone(), - value: fn_block.argument(i)?.into(), - memref_type: None, - }, - ); - } + let region = Region::new(); let return_type = if let Some(ret_type) = &info.decl.ret_type { - vec![fn_compiler_ctx.resolve_type_spec(context, ret_type)?] + vec![compiler_ctx.resolve_type_spec(context, ret_type)?] } else { vec![] }; @@ -168,44 +176,132 @@ fn compile_function_def<'c, 'op>( let func_type = TypeAttribute::new(FunctionType::new(context, &fn_args_types, &return_type).into()); - for stmt in &info.body { - match stmt { - Statement::Assign(info) => { - compile_assign_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? - } - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(_) => todo!(), - Statement::Let(info) => { - compile_let_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? - } - Statement::Return(info) => { - compile_return_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? + { + let mut fn_compiler_ctx = compiler_ctx.clone(); + let fn_block = region.append_block(Block::new(&args)); + + let blocks_arena = Bump::new(); + let helper = BlockHelper { + region: ®ion, + blocks_arena: &blocks_arena, + last_block: Cell::new(&fn_block), + }; + + // Push arguments into locals + for (i, param) in info.decl.params.iter().enumerate() { + fn_compiler_ctx.locals.insert( + param.name.name.clone(), + LocalVar { + type_spec: param.r#type.clone(), + value: fn_block.argument(i)?.into(), + memref_type: None, + }, + ); + } + + for stmt in &info.body { + match stmt { + Statement::Assign(info) => compile_assign_stmt( + session, + context, + &mut fn_compiler_ctx, + &helper, + &fn_block, + info, + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + compile_if_expr( + session, + context, + &mut fn_compiler_ctx, + &helper, + &fn_block, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( + session, + context, + &mut fn_compiler_ctx, + &helper, + &fn_block, + info, + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut fn_compiler_ctx, + &helper, + &fn_block, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), } - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), } } - region.append_block(fn_block); - - block.append_operation(func::func( + Ok(func::func( context, StringAttribute::new(context, &info.decl.name.name), func_type, region, &[], location, + )) +} + +fn compile_if_expr<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, + info: &IfExpr, +) -> Result<(), Box> { + let condition = compile_expression( + session, + context, + compiler_ctx, + helper, + block, + &info.value, + Some(&TypeSpec::Simple { + name: Ident { + name: "bool".to_string(), + span: Span::new(0, 0), + }, + }), + )?; + + let true_successor = helper.append_block(Block::new(&[])); + let false_successor = helper.append_block(Block::new(&[])); + + let final_block = helper.append_block(Block::new(&[])); + true_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); + false_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); + + block.append_operation(cf::cond_br( + context, + condition, + true_successor, + false_successor, + &[], + &[], + Location::unknown(context), )); Ok(()) } -fn compile_let_stmt<'c, 'op>( +fn compile_let_stmt<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, info: &LetStmt, ) -> Result<(), Box> { match &info.target { @@ -214,6 +310,7 @@ fn compile_let_stmt<'c, 'op>( session, context, compiler_ctx, + helper, block, &info.value, Some(r#type), @@ -259,11 +356,12 @@ fn compile_let_stmt<'c, 'op>( } } -fn compile_assign_stmt<'c, 'op>( +fn compile_assign_stmt<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, info: &AssignStmt, ) -> Result<(), Box> { // todo: implement properly for structs, right now only really works for simple variables. @@ -285,6 +383,7 @@ fn compile_assign_stmt<'c, 'op>( session, context, compiler_ctx, + helper, block, &info.value, Some(&local.type_spec), @@ -303,26 +402,36 @@ fn compile_assign_stmt<'c, 'op>( Ok(()) } -fn compile_return_stmt<'c, 'op>( +fn compile_return_stmt<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, info: &ReturnStmt, ) -> Result<(), Box> { - let value = compile_expression(session, context, compiler_ctx, block, &info.value, None)?; + let value = compile_expression( + session, + context, + compiler_ctx, + helper, + block, + &info.value, + None, + )?; block.append_operation(func::r#return(&[value], Location::unknown(context))); Ok(()) } -fn compile_expression<'c, 'op>( +fn compile_expression<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, info: &Expression, type_info: Option<&TypeSpec>, -) -> Result, Box> { +) -> Result, Box> { let location = Location::unknown(context); match info { Expression::Simple(simple) => match simple { @@ -357,7 +466,7 @@ fn compile_expression<'c, 'op>( SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), SimpleExpr::Path(value) => { - compile_path_op(session, context, compiler_ctx, block, value) + compile_path_op(session, context, compiler_ctx, helper, block, value) } }, Expression::FnCall(value) => { @@ -381,6 +490,7 @@ fn compile_expression<'c, 'op>( session, context, compiler_ctx, + helper, block, arg, Some(&arg_info.r#type), @@ -409,8 +519,24 @@ fn compile_expression<'c, 'op>( Expression::If(_) => todo!(), Expression::UnaryOp(_, _) => todo!(), Expression::BinaryOp(lhs, op, rhs) => { - let lhs = compile_expression(session, context, compiler_ctx, block, lhs, type_info)?; - let rhs = compile_expression(session, context, compiler_ctx, block, rhs, type_info)?; + let lhs = compile_expression( + session, + context, + compiler_ctx, + helper, + block, + lhs, + type_info, + )?; + let rhs = compile_expression( + session, + context, + compiler_ctx, + helper, + block, + rhs, + type_info, + )?; let op = match op { // todo: check signedness @@ -505,13 +631,14 @@ fn compile_expression<'c, 'op>( } } -fn compile_path_op<'c, 'op>( +fn compile_path_op<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + compiler_ctx: &mut CompilerContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this Block<'c>, path: &PathOp, -) -> Result, Box> { +) -> Result, Box> { // For now only simple variables work. // TODO: implement properly, this requires having structs implemented. diff --git a/examples/simple_if.con b/examples/simple_if.con new file mode 100644 index 0000000..2506005 --- /dev/null +++ b/examples/simple_if.con @@ -0,0 +1,13 @@ +mod Simple { + fn main() -> bool { + let y: bool = check(4); + return y; + } + + fn check(x: i64) -> bool { + if x == 0 { + return true; + } + return false; + } +} From c0f837db8a6497a20a5084a4c2af3ebcc9249192 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Sat, 13 Jan 2024 12:18:49 +0100 Subject: [PATCH 02/20] if works --- crates/concrete_codegen_mlir/src/codegen.rs | 140 +++++++++++++++++--- crates/concrete_codegen_mlir/src/context.rs | 13 +- crates/concrete_codegen_mlir/src/lib.rs | 1 + examples/simple_if.con | 5 +- 4 files changed, 134 insertions(+), 25 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 3bb0b97..42c07f2 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -58,7 +58,7 @@ struct BlockHelper<'ctx, 'this: 'ctx> { } impl<'ctx, 'this> BlockHelper<'ctx, 'this> { - pub fn append_block(&self, block: Block<'ctx>) -> &'this Block<'ctx> { + pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> { let block = self .region .insert_block_after(*self.last_block.get(), block); @@ -145,6 +145,10 @@ fn get_location<'c>(context: &'c MeliorContext, session: &Session, span: &Span) Location::new(context, &session.file_path.display().to_string(), line, col) } +fn get_named_location<'c>(context: &'c MeliorContext, name: &str) -> Location<'c> { + Location::name(context, name, Location::unknown(context)) +} + fn compile_function_def<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, @@ -178,13 +182,13 @@ fn compile_function_def<'c, 'this: 'c>( { let mut fn_compiler_ctx = compiler_ctx.clone(); - let fn_block = region.append_block(Block::new(&args)); + let mut fn_block = ®ion.append_block(Block::new(&args)); let blocks_arena = Bump::new(); let helper = BlockHelper { region: ®ion, blocks_arena: &blocks_arena, - last_block: Cell::new(&fn_block), + last_block: Cell::new(fn_block), }; // Push arguments into locals @@ -206,18 +210,18 @@ fn compile_function_def<'c, 'this: 'c>( context, &mut fn_compiler_ctx, &helper, - &fn_block, + fn_block, info, )?, Statement::Match(_) => todo!(), Statement::For(_) => todo!(), Statement::If(info) => { - compile_if_expr( + fn_block = compile_if_expr( session, context, &mut fn_compiler_ctx, &helper, - &fn_block, + fn_block, info, )?; } @@ -226,7 +230,7 @@ fn compile_function_def<'c, 'this: 'c>( context, &mut fn_compiler_ctx, &helper, - &fn_block, + fn_block, info, )?, Statement::Return(info) => compile_return_stmt( @@ -234,7 +238,7 @@ fn compile_function_def<'c, 'this: 'c>( context, &mut fn_compiler_ctx, &helper, - &fn_block, + fn_block, info, )?, Statement::While(_) => todo!(), @@ -260,7 +264,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, info: &IfExpr, -) -> Result<(), Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { let condition = compile_expression( session, context, @@ -276,24 +280,122 @@ fn compile_if_expr<'c, 'this: 'c>( }), )?; - let true_successor = helper.append_block(Block::new(&[])); - let false_successor = helper.append_block(Block::new(&[])); - - let final_block = helper.append_block(Block::new(&[])); - true_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); - false_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); + let mut then_successor = helper.append_block(Block::new(&[])); + let mut else_successor = helper.append_block(Block::new(&[])); block.append_operation(cf::cond_br( context, condition, - true_successor, - false_successor, + then_successor, + else_successor, &[], &[], - Location::unknown(context), + get_named_location(context, "if"), )); - Ok(()) + { + let mut true_compiler_ctx = compiler_ctx.clone(); + for stmt in &info.contents { + match stmt { + Statement::Assign(info) => compile_assign_stmt( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor, + info, + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + then_successor = compile_if_expr( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor, + info, + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), + } + } + } + + if let Some(else_contents) = info.r#else.as_ref() { + let mut else_compiler_ctx = compiler_ctx.clone(); + for stmt in else_contents { + match stmt { + Statement::Assign(info) => compile_assign_stmt( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor, + info, + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + else_successor = compile_if_expr( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor, + info, + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), + } + } + } + + let final_block = helper.append_block(Block::new(&[])); + + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); + } + + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); + } + + Ok(final_block) } fn compile_let_stmt<'c, 'this: 'c>( diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 63e97a4..1f1cabf 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -1,10 +1,10 @@ use std::error::Error; use concrete_ast::Program; -use concrete_session::Session; +use concrete_session::{config::DebugInfo, Session}; use melior::{ dialect::DialectRegistry, - ir::{Location, Module as MeliorModule}, + ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule}, utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}, Context as MeliorContext, }; @@ -43,11 +43,16 @@ impl Context { super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?; + let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true); tracing::debug!( - "MLIR Code before passes:\n{:#?}", - melior_module.as_operation() + "MLIR Code before passes:\n{}", + melior_module + .as_operation() + .to_string_with_flags(print_flags)? ); + assert!(melior_module.as_operation().verify()); + // TODO: Add proper error handling. run_pass_manager(&self.melior_context, &mut melior_module).unwrap(); diff --git a/crates/concrete_codegen_mlir/src/lib.rs b/crates/concrete_codegen_mlir/src/lib.rs index e30fa58..959807b 100644 --- a/crates/concrete_codegen_mlir/src/lib.rs +++ b/crates/concrete_codegen_mlir/src/lib.rs @@ -40,6 +40,7 @@ mod pass_manager; pub fn compile(session: &Session, program: &Program) -> Result> { let context = Context::new(); let mlir_module = context.compile(session, program)?; + assert!(mlir_module.melior_module.as_operation().verify()); let object_path = compile_to_object(session, &mlir_module)?; diff --git a/examples/simple_if.con b/examples/simple_if.con index 2506005..2388ec9 100644 --- a/examples/simple_if.con +++ b/examples/simple_if.con @@ -1,11 +1,12 @@ mod Simple { fn main() -> bool { - let y: bool = check(4); + let y: bool = check(0); return y; } fn check(x: i64) -> bool { - if x == 0 { + let y: i64 = 0; + if x == y { return true; } return false; From 8e70167328439994c67c645d694570174e622f4a Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 11:49:34 +0100 Subject: [PATCH 03/20] cleanup --- crates/concrete_codegen_mlir/src/codegen.rs | 71 +++++++++++---------- crates/concrete_codegen_mlir/src/context.rs | 2 +- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 42c07f2..2eb1fff 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, collections::HashMap, error::Error}; +use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ @@ -38,11 +38,29 @@ pub fn compile_program( } #[derive(Debug, Clone)] -pub struct LocalVar<'c, 'op> { +pub struct LocalVar<'ctx, 'parent: 'ctx> { pub type_spec: TypeSpec, // If it's none its on a register, otherwise allocated on the stack. - pub memref_type: Option>, - pub value: Value<'c, 'op>, + pub alloca: bool, + pub value: Value<'ctx, 'parent>, +} + +impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: false, + } + } + + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: true, + } + } } #[derive(Debug, Clone)] @@ -54,17 +72,13 @@ struct CompilerContext<'c, 'this: 'c> { struct BlockHelper<'ctx, 'this: 'ctx> { region: &'this Region<'ctx>, blocks_arena: &'this Bump, - last_block: Cell<&'this BlockRef<'ctx, 'this>>, } impl<'ctx, 'this> BlockHelper<'ctx, 'this> { pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> { - let block = self - .region - .insert_block_after(*self.last_block.get(), block); + let block = self.region.append_block(block); let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block); - self.last_block.set(block_ref); block_ref } @@ -81,6 +95,8 @@ impl<'c, 'this> CompilerContext<'c, 'this> { "u32" | "i32" => IntegerType::new(context, 32).into(), "u16" | "i16" => IntegerType::new(context, 16).into(), "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), "bool" => IntegerType::new(context, 1).into(), _ => todo!("custom type lookup"), }) @@ -155,6 +171,7 @@ fn compile_function_def<'c, 'this: 'c>( compiler_ctx: &mut CompilerContext<'c, 'this>, info: &FunctionDef, ) -> Result, Box> { + tracing::debug!("compiling function {:?}", info.decl.name.name); let location = get_location(context, session, &info.decl.name.span); // Setup function arguments @@ -188,18 +205,13 @@ fn compile_function_def<'c, 'this: 'c>( let helper = BlockHelper { region: ®ion, blocks_arena: &blocks_arena, - last_block: Cell::new(fn_block), }; // Push arguments into locals for (i, param) in info.decl.params.iter().enumerate() { fn_compiler_ctx.locals.insert( param.name.name.clone(), - LocalVar { - type_spec: param.r#type.clone(), - value: fn_block.argument(i)?.into(), - memref_type: None, - }, + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), ); } @@ -443,14 +455,9 @@ fn compile_let_stmt<'c, 'this: 'c>( .into(); block.append_operation(memref::store(value, alloca, &[k0], location)); - compiler_ctx.locals.insert( - name.name.clone(), - LocalVar { - type_spec: r#type.clone(), - memref_type: Some(memref_type), - value: alloca, - }, - ); + compiler_ctx + .locals + .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); Ok(()) } @@ -474,10 +481,7 @@ fn compile_assign_stmt<'c, 'this: 'c>( .expect("local should exist") .clone(); - assert!( - local.memref_type.is_some(), - "can only mutate local stack variables" - ); + assert!(local.alloca, "can only mutate local stack variables"); let location = get_location(context, session, &info.target.first.span); @@ -529,7 +533,7 @@ fn compile_expression<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, + _helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, info: &Expression, type_info: Option<&TypeSpec>, @@ -568,7 +572,7 @@ fn compile_expression<'c, 'this: 'c>( SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), SimpleExpr::Path(value) => { - compile_path_op(session, context, compiler_ctx, helper, block, value) + compile_path_op(session, context, compiler_ctx, block, value) } }, Expression::FnCall(value) => { @@ -592,7 +596,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, arg, Some(&arg_info.r#type), @@ -625,7 +629,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, lhs, type_info, @@ -634,7 +638,7 @@ fn compile_expression<'c, 'this: 'c>( session, context, compiler_ctx, - helper, + _helper, block, rhs, type_info, @@ -737,7 +741,6 @@ fn compile_path_op<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, path: &PathOp, ) -> Result, Box> { @@ -751,7 +754,7 @@ fn compile_path_op<'c, 'this: 'c>( let location = get_location(context, session, &path.first.span); - if let Some(_memref_type) = local.memref_type { + if local.alloca { let k0 = block .append_operation(arith::constant( context, diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 1f1cabf..70961fd 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -1,7 +1,7 @@ use std::error::Error; use concrete_ast::Program; -use concrete_session::{config::DebugInfo, Session}; +use concrete_session::Session; use melior::{ dialect::DialectRegistry, ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule}, From 9ea4627e43c4f0ff11059f8582a706f0e2aaebdb Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 12:25:30 +0100 Subject: [PATCH 04/20] factorial with ifs --- Cargo.lock | 133 ++++++++- crates/concrete_codegen_mlir/src/codegen.rs | 300 ++++++++++++-------- crates/concrete_driver/Cargo.toml | 1 + crates/concrete_driver/src/lib.rs | 11 +- crates/concrete_parser/Cargo.toml | 3 +- crates/concrete_parser/src/error.rs | 125 ++++++-- crates/concrete_parser/src/lib.rs | 2 + crates/concrete_session/Cargo.toml | 1 + crates/concrete_session/src/lib.rs | 12 +- examples/{fib.con => factorial.con} | 0 examples/factorial_if.con | 14 + examples/simple_if.con | 14 - 12 files changed, 433 insertions(+), 183 deletions(-) rename examples/{fib.con => factorial.con} (100%) create mode 100644 examples/factorial_if.con delete mode 100644 examples/simple_if.con diff --git a/Cargo.lock b/Cargo.lock index 5defe42..d04781c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,17 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "ariadne" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd002a6223f12c7a95cdd4b1cb3a0149d22d37f7a9ecdb2cb691a071fe236c29" +dependencies = [ + "concolor", + "unicode-width", + "yansi", +] + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -312,6 +323,26 @@ dependencies = [ "xdg", ] +[[package]] +name = "concolor" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b946244a988c390a94667ae0e3958411fa40cc46ea496a929b263d883f5f9c3" +dependencies = [ + "bitflags 1.3.2", + "concolor-query", + "is-terminal", +] + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "concrete" version = "0.1.0" @@ -344,6 +375,7 @@ dependencies = [ name = "concrete_driver" version = "0.1.0" dependencies = [ + "ariadne", "clap", "concrete_ast", "concrete_codegen_mlir", @@ -358,11 +390,12 @@ dependencies = [ name = "concrete_parser" version = "0.1.0" dependencies = [ + "ariadne", "concrete_ast", + "itertools 0.12.0", "lalrpop", "lalrpop-util", "logos", - "owo-colors", "salsa-2022", "tracing", ] @@ -370,6 +403,9 @@ dependencies = [ [[package]] name = "concrete_session" version = "0.1.0" +dependencies = [ + "ariadne", +] [[package]] name = "concrete_type_checker" @@ -754,6 +790,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -771,7 +816,7 @@ dependencies = [ "diff", "ena", "is-terminal", - "itertools", + "itertools 0.10.5", "lalrpop-util", "petgraph", "pico-args", @@ -1042,12 +1087,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "owo-colors" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caff54706df99d2a78a5a4e3455ff45448d81ef1bb63c22cd14052ca0e993a3f" - [[package]] name = "parking_lot" version = "0.12.1" @@ -1664,6 +1703,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "unicode-xid" version = "0.2.4" @@ -1813,6 +1858,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -1831,6 +1885,21 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -1861,6 +1930,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -1873,6 +1948,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -1885,6 +1966,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -1897,6 +1984,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -1909,6 +2002,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -1921,6 +2020,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -1933,6 +2038,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -1960,6 +2071,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "zerocopy" version = "0.7.32" diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 2eb1fff..9fed47a 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -156,9 +156,14 @@ fn compile_module( Ok(()) } -fn get_location<'c>(context: &'c MeliorContext, session: &Session, span: &Span) -> Location<'c> { - let (line, col) = session.get_line_and_column(span.from); - Location::new(context, &session.file_path.display().to_string(), line, col) +fn get_location<'c>(context: &'c MeliorContext, session: &Session, offset: usize) -> Location<'c> { + let (_, line, col) = session.source.get_offset_line(offset).unwrap(); + Location::new( + context, + &session.file_path.display().to_string(), + line + 1, + col + 1, + ) } fn get_named_location<'c>(context: &'c MeliorContext, name: &str) -> Location<'c> { @@ -172,7 +177,7 @@ fn compile_function_def<'c, 'this: 'c>( info: &FunctionDef, ) -> Result, Box> { tracing::debug!("compiling function {:?}", info.decl.name.name); - let location = get_location(context, session, &info.decl.name.span); + let location = get_location(context, session, info.decl.name.span.from); // Setup function arguments let mut args = Vec::with_capacity(info.decl.params.len()); @@ -180,7 +185,7 @@ fn compile_function_def<'c, 'this: 'c>( for param in &info.decl.params { let param_type = compiler_ctx.resolve_type_spec(context, ¶m.r#type)?; - let loc = get_location(context, session, ¶m.name.span); + let loc = get_location(context, session, param.name.span.from); args.push((param_type, loc)); fn_args_types.push(param_type); } @@ -199,7 +204,7 @@ fn compile_function_def<'c, 'this: 'c>( { let mut fn_compiler_ctx = compiler_ctx.clone(); - let mut fn_block = ®ion.append_block(Block::new(&args)); + let fn_block = ®ion.append_block(Block::new(&args)); let blocks_arena = Bump::new(); let helper = BlockHelper { @@ -215,46 +220,50 @@ fn compile_function_def<'c, 'this: 'c>( ); } + let mut fn_block = Some(fn_block); + for stmt in &info.body { - match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - fn_block, - info, - )?, - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - fn_block = compile_if_expr( + if let Some(block) = fn_block { + match stmt { + Statement::Assign(info) => compile_assign_stmt( + session, + context, + &mut fn_compiler_ctx, + &helper, + block, + info, + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + fn_block = compile_if_expr( + session, + context, + &mut fn_compiler_ctx, + &helper, + block, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( session, context, &mut fn_compiler_ctx, &helper, - fn_block, + block, info, - )?; + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut fn_compiler_ctx, + &helper, + block, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - fn_block, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - fn_block, - info, - )?, - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), } } } @@ -276,7 +285,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, info: &IfExpr, -) -> Result<&'this BlockRef<'c, 'this>, Box> { +) -> Result>, Box> { let condition = compile_expression( session, context, @@ -292,8 +301,8 @@ fn compile_if_expr<'c, 'this: 'c>( }), )?; - let mut then_successor = helper.append_block(Block::new(&[])); - let mut else_successor = helper.append_block(Block::new(&[])); + let then_successor = helper.append_block(Block::new(&[])); + let else_successor = helper.append_block(Block::new(&[])); block.append_operation(cf::cond_br( context, @@ -305,48 +314,53 @@ fn compile_if_expr<'c, 'this: 'c>( get_named_location(context, "if"), )); + let mut then_successor = Some(then_successor); + let mut else_successor = Some(else_successor); + { let mut true_compiler_ctx = compiler_ctx.clone(); for stmt in &info.contents { - match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut true_compiler_ctx, - helper, - then_successor, - info, - )?, - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - then_successor = compile_if_expr( + if let Some(then_successor_block) = then_successor { + match stmt { + Statement::Assign(info) => compile_assign_stmt( session, context, &mut true_compiler_ctx, helper, - then_successor, + then_successor_block, info, - )?; + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + then_successor = compile_if_expr( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor_block, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor_block, + info, + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut true_compiler_ctx, + helper, + then_successor_block, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut true_compiler_ctx, - helper, - then_successor, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut true_compiler_ctx, - helper, - then_successor, - info, - )?, - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), } } } @@ -354,60 +368,104 @@ fn compile_if_expr<'c, 'this: 'c>( if let Some(else_contents) = info.r#else.as_ref() { let mut else_compiler_ctx = compiler_ctx.clone(); for stmt in else_contents { - match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut else_compiler_ctx, - helper, - else_successor, - info, - )?, - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - else_successor = compile_if_expr( + if let Some(else_successor_block) = else_successor { + match stmt { + Statement::Assign(info) => compile_assign_stmt( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor_block, + info, + )?, + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + else_successor = compile_if_expr( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor_block, + info, + )?; + } + Statement::Let(info) => compile_let_stmt( session, context, &mut else_compiler_ctx, helper, - else_successor, + else_successor_block, info, - )?; + )?, + Statement::Return(info) => compile_return_stmt( + session, + context, + &mut else_compiler_ctx, + helper, + else_successor_block, + info, + )?, + Statement::While(_) => todo!(), + Statement::FnCall(_) => todo!(), } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut else_compiler_ctx, - helper, - else_successor, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut else_compiler_ctx, - helper, - else_successor, - info, - )?, - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), } } } - let final_block = helper.append_block(Block::new(&[])); - - if then_successor.terminator().is_none() { - then_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); - } + Ok(match (then_successor, else_successor) { + (None, None) => None, + (None, Some(else_successor)) => { + if else_successor.terminator().is_some() { + None + } else { + let final_block = helper.append_block(Block::new(&[])); + else_successor.append_operation(cf::br( + final_block, + &[], + Location::unknown(context), + )); + Some(final_block) + } + } + (Some(then_successor), None) => { + if then_successor.terminator().is_some() { + None + } else { + let final_block = helper.append_block(Block::new(&[])); + then_successor.append_operation(cf::br( + final_block, + &[], + Location::unknown(context), + )); + Some(final_block) + } + } + (Some(then_successor), Some(else_successor)) => { + if then_successor.terminator().is_some() && else_successor.terminator().is_some() { + None + } else { + let final_block = helper.append_block(Block::new(&[])); + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br( + final_block, + &[], + Location::unknown(context), + )); + } - if else_successor.terminator().is_none() { - else_successor.append_operation(cf::br(final_block, &[], Location::unknown(context))); - } + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br( + final_block, + &[], + Location::unknown(context), + )); + } - Ok(final_block) + Some(final_block) + } + } + }) } fn compile_let_stmt<'c, 'this: 'c>( @@ -430,7 +488,7 @@ fn compile_let_stmt<'c, 'this: 'c>( Some(r#type), )?; - let location = get_location(context, session, &name.span); + let location = get_location(context, session, name.span.from); let memref_type = MemRefType::new(value.r#type(), &[1], None, None); @@ -483,7 +541,7 @@ fn compile_assign_stmt<'c, 'this: 'c>( assert!(local.alloca, "can only mutate local stack variables"); - let location = get_location(context, session, &info.target.first.span); + let location = get_location(context, session, info.target.first.span.from); let value = compile_expression( session, @@ -577,7 +635,7 @@ fn compile_expression<'c, 'this: 'c>( }, Expression::FnCall(value) => { let mut args = Vec::with_capacity(value.args.len()); - let location = get_location(context, session, &value.target.span); + let location = get_location(context, session, value.target.span.from); let target_fn = compiler_ctx .functions @@ -752,7 +810,7 @@ fn compile_path_op<'c, 'this: 'c>( .get(&path.first.name) .expect("local not found"); - let location = get_location(context, session, &path.first.span); + let location = get_location(context, session, path.first.span.from); if local.alloca { let k0 = block diff --git a/crates/concrete_driver/Cargo.toml b/crates/concrete_driver/Cargo.toml index d329044..321a170 100644 --- a/crates/concrete_driver/Cargo.toml +++ b/crates/concrete_driver/Cargo.toml @@ -14,3 +14,4 @@ concrete_parser = { path = "../concrete_parser"} concrete_session = { path = "../concrete_session"} concrete_codegen_mlir = { path = "../concrete_codegen_mlir"} salsa = { git = "https://github.com/salsa-rs/salsa.git", package = "salsa-2022" } +ariadne = { version = "0.4.0", features = ["auto-color"] } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 437bc91..927097e 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -1,3 +1,4 @@ +use ariadne::Source; use clap::Parser; use concrete_codegen_mlir::linker::{link_binary, link_shared_lib}; use concrete_parser::{error::Diagnostics, ProgramSource}; @@ -32,7 +33,11 @@ pub fn main() -> Result<(), Box> { let args = CompilerArgs::parse(); let db = crate::db::Database::default(); - let source = ProgramSource::new(&db, std::fs::read_to_string(args.input.clone())?); + let source = ProgramSource::new( + &db, + std::fs::read_to_string(&args.input)?, + args.input.display().to_string(), + ); tracing::debug!("source code:\n{}", source.input(&db)); let program = match concrete_parser::parse_ast(&db, source) { Some(x) => x, @@ -44,7 +49,7 @@ pub fn main() -> Result<(), Box> { &db, source, ), ); - panic!(); + std::process::exit(1); } }; @@ -70,7 +75,7 @@ pub fn main() -> Result<(), Box> { } else { OptLevel::None }, - source: source.input(&db).to_string(), + source: Source::from(source.input(&db).to_string()), library: args.library, target_dir, output_file, diff --git a/crates/concrete_parser/Cargo.toml b/crates/concrete_parser/Cargo.toml index 7f03898..239eba6 100644 --- a/crates/concrete_parser/Cargo.toml +++ b/crates/concrete_parser/Cargo.toml @@ -10,8 +10,9 @@ lalrpop-util = { version = "0.20.0", features = ["unicode"] } logos = "0.13.0" tracing = { workspace = true } concrete_ast = { path = "../concrete_ast"} -owo-colors = "4.0.0" salsa = { git = "https://github.com/salsa-rs/salsa.git", package = "salsa-2022" } +ariadne = { version = "0.4.0", features = ["auto-color"] } +itertools = "0.12.0" [build-dependencies] lalrpop = "0.20.0" diff --git a/crates/concrete_parser/src/error.rs b/crates/concrete_parser/src/error.rs index 81c1e95..2d2665e 100644 --- a/crates/concrete_parser/src/error.rs +++ b/crates/concrete_parser/src/error.rs @@ -1,6 +1,12 @@ -use crate::{db::Db, lexer::LexicalError, tokens::Token, ProgramSource}; +use crate::{ + db::Db, + lexer::LexicalError, + tokens::{self, Token}, + ProgramSource, +}; +use ariadne::{ColorGenerator, Label, Report, ReportKind, Source}; +use itertools::Itertools; use lalrpop_util::ParseError; -use owo_colors::OwoColorize; pub type Error = ParseError; @@ -9,36 +15,103 @@ pub struct Diagnostics(Error); impl Diagnostics { pub fn dump(db: &dyn Db, source: ProgramSource, errors: &[Error]) { + let path = source.path(db); let source = source.input(db); - for err in errors { - match &err { - ParseError::InvalidToken { .. } => todo!(), + for error in errors { + let mut colors = ColorGenerator::new(); + let report = match error { + ParseError::InvalidToken { location } => { + let loc = *location; + Report::build(ReportKind::Error, path, loc) + .with_code("P1") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_color(colors.next()) + .with_message("invalid token"), + ) + .with_label( + Label::new((path, (loc.saturating_sub(10))..(loc + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() + } ParseError::UnrecognizedEof { location, expected } => { - let location = *location; - let before = &source[0..location]; - let after = &source[location..]; - - print!("{}", before); - print!("$Got EOF, expected {:?}$", expected.green().bold()); - print!("{}", after); + let loc = *location; + Report::build(ReportKind::Error, path, loc) + .with_code("P2") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_message("unrecognized eof") + .with_color(colors.next()), + ) + .with_note(format!( + "expected one of the following: {}", + expected.iter().join(", ") + )) + .with_label( + Label::new((path, (loc.saturating_sub(10))..(loc + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() } ParseError::UnrecognizedToken { token, expected } => { - let (l, ref tok, r) = *token; - let before = &source[0..l]; - let after = &source[r..]; - - print!("{}", before); - print!( - "$Got {:?}, expected {:?}$", - tok.bold().red(), - expected.green().bold() - ); - print!("{}", after); + Report::build(ReportKind::Error, path, token.0) + .with_code(3) + .with_label( + Label::new((path, token.0..token.2)) + .with_message(format!("unrecognized token {:?}", token.1)) + .with_color(colors.next()), + ) + .with_note(format!( + "expected one of the following: {}", + expected.iter().join(", ") + )) + .with_label( + Label::new((path, (token.0.saturating_sub(10))..(token.2 + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() } - ParseError::ExtraToken { .. } => todo!(), - ParseError::User { .. } => todo!(), - } + ParseError::ExtraToken { token } => Report::build(ReportKind::Error, path, token.0) + .with_code("P3") + .with_message("Extra token") + .with_label( + Label::new((path, token.0..token.2)) + .with_message(format!("unexpected extra token {:?}", token.1)), + ) + .finish(), + ParseError::User { error } => match error { + LexicalError::InvalidToken(err, range) => match err { + tokens::LexingError::NumberParseError => { + Report::build(ReportKind::Error, path, range.start) + .with_code(4) + .with_message("Error parsing literal number") + .with_label( + Label::new((path, range.start..range.end)) + .with_message("error parsing literal number") + .with_color(colors.next()), + ) + .finish() + } + tokens::LexingError::Other => { + Report::build(ReportKind::Error, path, range.start) + .with_code(4) + .with_message("Other error") + .with_label( + Label::new((path, range.start..range.end)) + .with_message("other error") + .with_color(colors.next()), + ) + .finish() + } + }, + }, + }; + + report + .eprint((path, Source::from(source))) + .expect("failed to print to stderr"); } } } diff --git a/crates/concrete_parser/src/lib.rs b/crates/concrete_parser/src/lib.rs index 4c5875a..d0b4e1b 100644 --- a/crates/concrete_parser/src/lib.rs +++ b/crates/concrete_parser/src/lib.rs @@ -19,6 +19,8 @@ pub mod grammar { pub struct ProgramSource { #[return_ref] pub input: String, + #[return_ref] + pub path: String, } // Todo: better error handling diff --git a/crates/concrete_session/Cargo.toml b/crates/concrete_session/Cargo.toml index 80ee2dd..f69684e 100644 --- a/crates/concrete_session/Cargo.toml +++ b/crates/concrete_session/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ariadne = "0.4.0" diff --git a/crates/concrete_session/src/lib.rs b/crates/concrete_session/src/lib.rs index 975a392..0eccf96 100644 --- a/crates/concrete_session/src/lib.rs +++ b/crates/concrete_session/src/lib.rs @@ -1,3 +1,4 @@ +use ariadne::Source; use std::path::PathBuf; use config::{DebugInfo, OptLevel}; @@ -9,7 +10,7 @@ pub struct Session { pub file_path: PathBuf, pub debug_info: DebugInfo, pub optlevel: OptLevel, - pub source: String, // for debugging locations + pub source: Source, // for debugging locations /// True if it should be compiled as a library false for binary. pub library: bool, /// The directory where to store artifacts and intermediate files such as object files. @@ -17,12 +18,3 @@ pub struct Session { pub output_file: PathBuf, // todo: include target, host, etc } - -impl Session { - pub fn get_line_and_column(&self, offset: usize) -> (usize, usize) { - let sl = &self.source[0..offset]; - let line_count = sl.lines().count(); - let column = sl.rfind('\n').unwrap_or(0); - (line_count, column) - } -} diff --git a/examples/fib.con b/examples/factorial.con similarity index 100% rename from examples/fib.con rename to examples/factorial.con diff --git a/examples/factorial_if.con b/examples/factorial_if.con new file mode 100644 index 0000000..1daf463 --- /dev/null +++ b/examples/factorial_if.con @@ -0,0 +1,14 @@ +mod Simple { + fn main() -> i64 { + return factorial(4); + } + + fn factorial(n: i64) -> i64 { + let zero: i64 = 0; + if n == zero { + return 1; + } else { + return n * factorial(n - 1); + } + } +} diff --git a/examples/simple_if.con b/examples/simple_if.con deleted file mode 100644 index 2388ec9..0000000 --- a/examples/simple_if.con +++ /dev/null @@ -1,14 +0,0 @@ -mod Simple { - fn main() -> bool { - let y: bool = check(0); - return y; - } - - fn check(x: i64) -> bool { - let y: i64 = 0; - if x == y { - return true; - } - return false; - } -} From f9822c9da10d99458b9073a2b5c0fe89607beb9a Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 12:39:03 +0100 Subject: [PATCH 05/20] std --- crates/concrete_codegen_mlir/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/concrete_codegen_mlir/Cargo.toml b/crates/concrete_codegen_mlir/Cargo.toml index 8264ed6..91fb12d 100644 --- a/crates/concrete_codegen_mlir/Cargo.toml +++ b/crates/concrete_codegen_mlir/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -bumpalo = "3.14.0" +bumpalo = { version = "3.14.0", features = ["std"] } concrete_ast = { path = "../concrete_ast"} concrete_session = { path = "../concrete_session"} llvm-sys = "170.0.1" From f1314fd2dfc71c4f07dcc66425fbebaff971cd29 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 13:00:53 +0100 Subject: [PATCH 06/20] better naming --- crates/concrete_codegen_mlir/src/codegen.rs | 213 +++++++++----------- 1 file changed, 93 insertions(+), 120 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 9fed47a..eedda9c 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -64,32 +64,32 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { } #[derive(Debug, Clone)] -struct CompilerContext<'c, 'this: 'c> { - pub locals: HashMap>, +struct ScopeContext<'c, 'parent: 'c> { + pub locals: HashMap>, pub functions: HashMap, } -struct BlockHelper<'ctx, 'this: 'ctx> { - region: &'this Region<'ctx>, - blocks_arena: &'this Bump, +struct BlockHelper<'ctx, 'region: 'ctx> { + region: &'region Region<'ctx>, + blocks_arena: &'region Bump, } -impl<'ctx, 'this> BlockHelper<'ctx, 'this> { - pub fn append_block(&self, block: Block<'ctx>) -> &'this BlockRef<'ctx, 'this> { +impl<'ctx, 'region> BlockHelper<'ctx, 'region> { + pub fn append_block(&self, block: Block<'ctx>) -> &'region BlockRef<'ctx, 'region> { let block = self.region.append_block(block); - let block_ref: &'this mut BlockRef<'ctx, 'this> = self.blocks_arena.alloc(block); + let block_ref: &'region mut BlockRef<'ctx, 'region> = self.blocks_arena.alloc(block); block_ref } } -impl<'c, 'this> CompilerContext<'c, 'this> { +impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { fn resolve_type( &self, - context: &'c MeliorContext, + context: &'ctx MeliorContext, name: &str, - ) -> Result, Box> { + ) -> Result, Box> { Ok(match name { "u64" | "i64" => IntegerType::new(context, 64).into(), "u32" | "i32" => IntegerType::new(context, 32).into(), @@ -104,9 +104,9 @@ impl<'c, 'this> CompilerContext<'c, 'this> { fn resolve_type_spec( &self, - context: &'c MeliorContext, + context: &'ctx MeliorContext, spec: &TypeSpec, - ) -> Result, Box> { + ) -> Result, Box> { Ok(match spec { TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?, TypeSpec::Generic { @@ -127,7 +127,7 @@ fn compile_module( let body = mlir_module.body(); - let mut compiler_ctx: CompilerContext = CompilerContext { + let mut scope_ctx: ScopeContext = ScopeContext { functions: Default::default(), locals: Default::default(), }; @@ -135,7 +135,7 @@ fn compile_module( // save all function signatures for statement in &module.contents { if let ModuleDefItem::Function(info) = statement { - compiler_ctx + scope_ctx .functions .insert(info.decl.name.name.clone(), info.clone()); } @@ -145,7 +145,7 @@ fn compile_module( match statement { ModuleDefItem::Constant(_) => todo!(), ModuleDefItem::Function(info) => { - let op = compile_function_def(session, context, &mut compiler_ctx, info)?; + let op = compile_function_def(session, context, &scope_ctx, info)?; body.append_operation(op); } ModuleDefItem::Record(_) => todo!(), @@ -156,7 +156,11 @@ fn compile_module( Ok(()) } -fn get_location<'c>(context: &'c MeliorContext, session: &Session, offset: usize) -> Location<'c> { +fn get_location<'ctx>( + context: &'ctx MeliorContext, + session: &Session, + offset: usize, +) -> Location<'ctx> { let (_, line, col) = session.source.get_offset_line(offset).unwrap(); Location::new( context, @@ -166,16 +170,16 @@ fn get_location<'c>(context: &'c MeliorContext, session: &Session, offset: usize ) } -fn get_named_location<'c>(context: &'c MeliorContext, name: &str) -> Location<'c> { +fn get_named_location<'ctx>(context: &'ctx MeliorContext, name: &str) -> Location<'ctx> { Location::name(context, name, Location::unknown(context)) } -fn compile_function_def<'c, 'this: 'c>( +fn compile_function_def<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, + context: &'ctx MeliorContext, + scope_ctx: &ScopeContext<'ctx, 'parent>, info: &FunctionDef, -) -> Result, Box> { +) -> Result, Box> { tracing::debug!("compiling function {:?}", info.decl.name.name); let location = get_location(context, session, info.decl.name.span.from); @@ -184,7 +188,7 @@ fn compile_function_def<'c, 'this: 'c>( let mut fn_args_types = Vec::with_capacity(info.decl.params.len()); for param in &info.decl.params { - let param_type = compiler_ctx.resolve_type_spec(context, ¶m.r#type)?; + let param_type = scope_ctx.resolve_type_spec(context, ¶m.r#type)?; let loc = get_location(context, session, param.name.span.from); args.push((param_type, loc)); fn_args_types.push(param_type); @@ -194,7 +198,7 @@ fn compile_function_def<'c, 'this: 'c>( let region = Region::new(); let return_type = if let Some(ret_type) = &info.decl.ret_type { - vec![compiler_ctx.resolve_type_spec(context, ret_type)?] + vec![scope_ctx.resolve_type_spec(context, ret_type)?] } else { vec![] }; @@ -203,7 +207,7 @@ fn compile_function_def<'c, 'this: 'c>( TypeAttribute::new(FunctionType::new(context, &fn_args_types, &return_type).into()); { - let mut fn_compiler_ctx = compiler_ctx.clone(); + let mut scope_ctx = scope_ctx.clone(); let fn_block = ®ion.append_block(Block::new(&args)); let blocks_arena = Bump::new(); @@ -214,7 +218,7 @@ fn compile_function_def<'c, 'this: 'c>( // Push arguments into locals for (i, param) in info.decl.params.iter().enumerate() { - fn_compiler_ctx.locals.insert( + scope_ctx.locals.insert( param.name.name.clone(), LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), ); @@ -225,42 +229,27 @@ fn compile_function_def<'c, 'this: 'c>( for stmt in &info.body { if let Some(block) = fn_block { match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - block, - info, - )?, + Statement::Assign(info) => { + compile_assign_stmt(session, context, &mut scope_ctx, &helper, block, info)? + } Statement::Match(_) => todo!(), Statement::For(_) => todo!(), Statement::If(info) => { fn_block = compile_if_expr( session, context, - &mut fn_compiler_ctx, + &mut scope_ctx, &helper, block, info, )?; } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - block, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut fn_compiler_ctx, - &helper, - block, - info, - )?, + Statement::Let(info) => { + compile_let_stmt(session, context, &mut scope_ctx, &helper, block, info)? + } + Statement::Return(info) => { + compile_return_stmt(session, context, &mut scope_ctx, &helper, block, info)? + } Statement::While(_) => todo!(), Statement::FnCall(_) => todo!(), } @@ -281,7 +270,7 @@ fn compile_function_def<'c, 'this: 'c>( fn compile_if_expr<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, + scope_ctx: &mut ScopeContext<'c, 'this>, helper: &BlockHelper<'c, 'this>, block: &'this Block<'c>, info: &IfExpr, @@ -289,7 +278,7 @@ fn compile_if_expr<'c, 'this: 'c>( let condition = compile_expression( session, context, - compiler_ctx, + scope_ctx, helper, block, &info.value, @@ -318,14 +307,14 @@ fn compile_if_expr<'c, 'this: 'c>( let mut else_successor = Some(else_successor); { - let mut true_compiler_ctx = compiler_ctx.clone(); + let mut then_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { if let Some(then_successor_block) = then_successor { match stmt { Statement::Assign(info) => compile_assign_stmt( session, context, - &mut true_compiler_ctx, + &mut then_scope_ctx, helper, then_successor_block, info, @@ -336,7 +325,7 @@ fn compile_if_expr<'c, 'this: 'c>( then_successor = compile_if_expr( session, context, - &mut true_compiler_ctx, + &mut then_scope_ctx, helper, then_successor_block, info, @@ -345,7 +334,7 @@ fn compile_if_expr<'c, 'this: 'c>( Statement::Let(info) => compile_let_stmt( session, context, - &mut true_compiler_ctx, + &mut then_scope_ctx, helper, then_successor_block, info, @@ -353,7 +342,7 @@ fn compile_if_expr<'c, 'this: 'c>( Statement::Return(info) => compile_return_stmt( session, context, - &mut true_compiler_ctx, + &mut then_scope_ctx, helper, then_successor_block, info, @@ -366,14 +355,14 @@ fn compile_if_expr<'c, 'this: 'c>( } if let Some(else_contents) = info.r#else.as_ref() { - let mut else_compiler_ctx = compiler_ctx.clone(); + let mut else_scope_ctx = scope_ctx.clone(); for stmt in else_contents { if let Some(else_successor_block) = else_successor { match stmt { Statement::Assign(info) => compile_assign_stmt( session, context, - &mut else_compiler_ctx, + &mut else_scope_ctx, helper, else_successor_block, info, @@ -384,7 +373,7 @@ fn compile_if_expr<'c, 'this: 'c>( else_successor = compile_if_expr( session, context, - &mut else_compiler_ctx, + &mut else_scope_ctx, helper, else_successor_block, info, @@ -393,7 +382,7 @@ fn compile_if_expr<'c, 'this: 'c>( Statement::Let(info) => compile_let_stmt( session, context, - &mut else_compiler_ctx, + &mut else_scope_ctx, helper, else_successor_block, info, @@ -401,7 +390,7 @@ fn compile_if_expr<'c, 'this: 'c>( Statement::Return(info) => compile_return_stmt( session, context, - &mut else_compiler_ctx, + &mut else_scope_ctx, helper, else_successor_block, info, @@ -468,12 +457,12 @@ fn compile_if_expr<'c, 'this: 'c>( }) } -fn compile_let_stmt<'c, 'this: 'c>( +fn compile_let_stmt<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, - block: &'this Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &LetStmt, ) -> Result<(), Box> { match &info.target { @@ -481,7 +470,7 @@ fn compile_let_stmt<'c, 'this: 'c>( let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, helper, block, &info.value, @@ -513,7 +502,7 @@ fn compile_let_stmt<'c, 'this: 'c>( .into(); block.append_operation(memref::store(value, alloca, &[k0], location)); - compiler_ctx + scope_ctx .locals .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); @@ -523,17 +512,17 @@ fn compile_let_stmt<'c, 'this: 'c>( } } -fn compile_assign_stmt<'c, 'this: 'c>( +fn compile_assign_stmt<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, - block: &'this Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &AssignStmt, ) -> Result<(), Box> { // todo: implement properly for structs, right now only really works for simple variables. - let local = compiler_ctx + let local = scope_ctx .locals .get(&info.target.first.name) .expect("local should exist") @@ -546,7 +535,7 @@ fn compile_assign_stmt<'c, 'this: 'c>( let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, helper, block, &info.value, @@ -566,18 +555,18 @@ fn compile_assign_stmt<'c, 'this: 'c>( Ok(()) } -fn compile_return_stmt<'c, 'this: 'c>( +fn compile_return_stmt<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, - helper: &BlockHelper<'c, 'this>, - block: &'this Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &ReturnStmt, ) -> Result<(), Box> { let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, helper, block, &info.value, @@ -587,15 +576,15 @@ fn compile_return_stmt<'c, 'this: 'c>( Ok(()) } -fn compile_expression<'c, 'this: 'c>( +fn compile_expression<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, - _helper: &BlockHelper<'c, 'this>, - block: &'this Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &Expression, type_info: Option<&TypeSpec>, -) -> Result, Box> { +) -> Result, Box> { let location = Location::unknown(context); match info { Expression::Simple(simple) => match simple { @@ -617,7 +606,7 @@ fn compile_expression<'c, 'this: 'c>( } SimpleExpr::ConstInt(value) => { let int_type = if let Some(type_info) = type_info { - compiler_ctx.resolve_type_spec(context, type_info)? + scope_ctx.resolve_type_spec(context, type_info)? } else { IntegerType::new(context, 64).into() }; @@ -629,15 +618,13 @@ fn compile_expression<'c, 'this: 'c>( } SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), - SimpleExpr::Path(value) => { - compile_path_op(session, context, compiler_ctx, block, value) - } + SimpleExpr::Path(value) => compile_path_op(session, context, scope_ctx, block, value), }, Expression::FnCall(value) => { let mut args = Vec::with_capacity(value.args.len()); let location = get_location(context, session, value.target.span.from); - let target_fn = compiler_ctx + let target_fn = scope_ctx .functions .get(&value.target.name) .expect("function not found") @@ -653,7 +640,7 @@ fn compile_expression<'c, 'this: 'c>( let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, _helper, block, arg, @@ -663,7 +650,7 @@ fn compile_expression<'c, 'this: 'c>( } let return_type = if let Some(ret_type) = &target_fn.decl.ret_type { - vec![compiler_ctx.resolve_type_spec(context, ret_type)?] + vec![scope_ctx.resolve_type_spec(context, ret_type)?] } else { vec![] }; @@ -683,24 +670,10 @@ fn compile_expression<'c, 'this: 'c>( Expression::If(_) => todo!(), Expression::UnaryOp(_, _) => todo!(), Expression::BinaryOp(lhs, op, rhs) => { - let lhs = compile_expression( - session, - context, - compiler_ctx, - _helper, - block, - lhs, - type_info, - )?; - let rhs = compile_expression( - session, - context, - compiler_ctx, - _helper, - block, - rhs, - type_info, - )?; + let lhs = + compile_expression(session, context, scope_ctx, _helper, block, lhs, type_info)?; + let rhs = + compile_expression(session, context, scope_ctx, _helper, block, rhs, type_info)?; let op = match op { // todo: check signedness @@ -795,17 +768,17 @@ fn compile_expression<'c, 'this: 'c>( } } -fn compile_path_op<'c, 'this: 'c>( +fn compile_path_op<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'this>, - block: &'this Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + block: &'parent Block<'ctx>, path: &PathOp, -) -> Result, Box> { +) -> Result, Box> { // For now only simple variables work. // TODO: implement properly, this requires having structs implemented. - let local = compiler_ctx + let local = scope_ctx .locals .get(&path.first.name) .expect("local not found"); From 6ca3d1512f7bb3237e62b4d76e3223aba2857494 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 13:01:24 +0100 Subject: [PATCH 07/20] forgot this one --- crates/concrete_codegen_mlir/src/codegen.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index eedda9c..648d7ff 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -64,8 +64,8 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { } #[derive(Debug, Clone)] -struct ScopeContext<'c, 'parent: 'c> { - pub locals: HashMap>, +struct ScopeContext<'ctx, 'parent: 'ctx> { + pub locals: HashMap>, pub functions: HashMap, } From f42c92f0fba0a922a2b4bcf454da127f5ff44666 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 13:06:32 +0100 Subject: [PATCH 08/20] add docs on if --- crates/concrete_codegen_mlir/src/codegen.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 648d7ff..bd62647 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -267,6 +267,19 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( )) } +/// Compile a if expression / statement +/// +/// This returns a block if any branch doesn't have a function return terminator. +/// For example, if the if branch has a return and the else branch has a return, +/// it wouldn't make sense to add a merging block and MLIR would give a error saying there is a operation after a terminator. +/// +/// The returned block is the merger block, the one we jump after processing the if branches. +/// +/// ```text +/// - then block - +/// - if (prev block) - < > merge block -- +/// - else block - +/// ``` fn compile_if_expr<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, From 1c2298d8964c67610d86038c813d2d673ad8aeca Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 13:55:37 +0100 Subject: [PATCH 09/20] cleanup --- crates/concrete_codegen_mlir/src/codegen.rs | 258 ++++++++------------ 1 file changed, 107 insertions(+), 151 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index bd62647..1cb27c2 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -3,7 +3,9 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ common::{Ident, Span}, - expressions::{ArithOp, BinaryOp, CmpOp, Expression, IfExpr, LogicOp, PathOp, SimpleExpr}, + expressions::{ + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, SimpleExpr, + }, functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement}, @@ -228,31 +230,8 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( for stmt in &info.body { if let Some(block) = fn_block { - match stmt { - Statement::Assign(info) => { - compile_assign_stmt(session, context, &mut scope_ctx, &helper, block, info)? - } - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - fn_block = compile_if_expr( - session, - context, - &mut scope_ctx, - &helper, - block, - info, - )?; - } - Statement::Let(info) => { - compile_let_stmt(session, context, &mut scope_ctx, &helper, block, info)? - } - Statement::Return(info) => { - compile_return_stmt(session, context, &mut scope_ctx, &helper, block, info)? - } - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), - } + fn_block = + compile_statement(session, context, &mut scope_ctx, &helper, block, stmt)?; } } } @@ -267,6 +246,36 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( )) } +fn compile_statement<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &Statement, +) -> Result>, Box> { + match info { + Statement::Assign(info) => { + compile_assign_stmt(session, context, scope_ctx, helper, block, info)? + } + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + return compile_if_expr(session, context, scope_ctx, helper, block, info); + } + Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?, + Statement::Return(info) => { + compile_return_stmt(session, context, scope_ctx, helper, block, info)? + } + Statement::While(_) => todo!(), + Statement::FnCall(info) => { + compile_fn_call(session, context, scope_ctx, helper, block, info)?; + } + } + + Ok(Some(block)) +} + /// Compile a if expression / statement /// /// This returns a block if any branch doesn't have a function return terminator. @@ -285,7 +294,7 @@ fn compile_if_expr<'c, 'this: 'c>( context: &'c MeliorContext, scope_ctx: &mut ScopeContext<'c, 'this>, helper: &BlockHelper<'c, 'this>, - block: &'this Block<'c>, + block: &'this BlockRef<'c, 'this>, info: &IfExpr, ) -> Result>, Box> { let condition = compile_expression( @@ -323,46 +332,14 @@ fn compile_if_expr<'c, 'this: 'c>( let mut then_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { if let Some(then_successor_block) = then_successor { - match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - info, - )?, - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - then_successor = compile_if_expr( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - info, - )?; - } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - info, - )?, - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), - } + then_successor = compile_statement( + session, + context, + &mut then_scope_ctx, + helper, + then_successor_block, + stmt, + )?; } } } @@ -371,46 +348,14 @@ fn compile_if_expr<'c, 'this: 'c>( let mut else_scope_ctx = scope_ctx.clone(); for stmt in else_contents { if let Some(else_successor_block) = else_successor { - match stmt { - Statement::Assign(info) => compile_assign_stmt( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - info, - )?, - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(info) => { - else_successor = compile_if_expr( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - info, - )?; - } - Statement::Let(info) => compile_let_stmt( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - info, - )?, - Statement::Return(info) => compile_return_stmt( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - info, - )?, - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), - } + else_successor = compile_statement( + session, + context, + &mut else_scope_ctx, + helper, + else_successor_block, + stmt, + )?; } } } @@ -634,50 +579,7 @@ fn compile_expression<'ctx, 'parent: 'ctx>( SimpleExpr::Path(value) => compile_path_op(session, context, scope_ctx, block, value), }, Expression::FnCall(value) => { - let mut args = Vec::with_capacity(value.args.len()); - let location = get_location(context, session, value.target.span.from); - - let target_fn = scope_ctx - .functions - .get(&value.target.name) - .expect("function not found") - .clone(); - - assert_eq!( - value.args.len(), - target_fn.decl.params.len(), - "parameter length doesnt match" - ); - - for (arg, arg_info) in value.args.iter().zip(&target_fn.decl.params) { - let value = compile_expression( - session, - context, - scope_ctx, - _helper, - block, - arg, - Some(&arg_info.r#type), - )?; - args.push(value); - } - - let return_type = if let Some(ret_type) = &target_fn.decl.ret_type { - vec![scope_ctx.resolve_type_spec(context, ret_type)?] - } else { - vec![] - }; - - Ok(block - .append_operation(func::call( - context, - FlatSymbolRefAttribute::new(context, &value.target.name), - &args, - &return_type, - location, - )) - .result(0)? - .into()) + compile_fn_call(session, context, scope_ctx, _helper, block, value) } Expression::Match(_) => todo!(), Expression::If(_) => todo!(), @@ -781,6 +683,60 @@ fn compile_expression<'ctx, 'parent: 'ctx>( } } +fn compile_fn_call<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + info: &FnCallOp, +) -> Result, Box> { + let mut args = Vec::with_capacity(info.args.len()); + let location = get_location(context, session, info.target.span.from); + + let target_fn = scope_ctx + .functions + .get(&info.target.name) + .expect("function not found") + .clone(); + + assert_eq!( + info.args.len(), + target_fn.decl.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.args.iter().zip(&target_fn.decl.params) { + let value = compile_expression( + session, + context, + scope_ctx, + _helper, + block, + arg, + Some(&arg_info.r#type), + )?; + args.push(value); + } + + let return_type = if let Some(ret_type) = &target_fn.decl.ret_type { + vec![scope_ctx.resolve_type_spec(context, ret_type)?] + } else { + vec![] + }; + + Ok(block + .append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &info.target.name), + &args, + &return_type, + location, + )) + .result(0)? + .into()) +} + fn compile_path_op<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, From d6ff3ab3ef7d686db5285a20a6de8b501778bca2 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 14:34:42 +0100 Subject: [PATCH 10/20] implement while too --- crates/concrete_codegen_mlir/src/codegen.rs | 77 ++++++++++++++++++--- examples/while.con | 17 +++++ 2 files changed, 86 insertions(+), 8 deletions(-) create mode 100644 examples/while.con diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 1cb27c2..81716e9 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -8,7 +8,7 @@ use concrete_ast::{ }, functions::FunctionDef, modules::{Module, ModuleDefItem}, - statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement}, + statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement, WhileStmt}, types::TypeSpec, Program, }; @@ -267,7 +267,9 @@ fn compile_statement<'c, 'this: 'c>( Statement::Return(info) => { compile_return_stmt(session, context, scope_ctx, helper, block, info)? } - Statement::While(_) => todo!(), + Statement::While(info) => { + return compile_while(session, context, scope_ctx, helper, block, info); + } Statement::FnCall(info) => { compile_fn_call(session, context, scope_ctx, helper, block, info)?; } @@ -304,12 +306,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper, block, &info.value, - Some(&TypeSpec::Simple { - name: Ident { - name: "bool".to_string(), - span: Span::new(0, 0), - }, - }), + None, )?; let then_successor = helper.append_block(Block::new(&[])); @@ -415,6 +412,70 @@ fn compile_if_expr<'c, 'this: 'c>( }) } +fn compile_while<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &WhileStmt, +) -> Result>, Box> { + let location = Location::unknown(context); + + let check_block = helper.append_block(Block::new(&[])); + + block.append_operation(cf::br(check_block, &[], location)); + + let body_block = helper.append_block(Block::new(&[])); + let merge_block = helper.append_block(Block::new(&[])); + + let condition = compile_expression( + session, + context, + scope_ctx, + helper, + check_block, + &info.value, + None, + )?; + + check_block.append_operation(cf::cond_br( + context, + condition, + body_block, + merge_block, + &[], + &[], + location, + )); + + let mut body_block = Some(body_block); + + { + let mut body_scope_ctx = scope_ctx.clone(); + for stmt in &info.contents { + if let Some(then_successor_block) = body_block { + body_block = compile_statement( + session, + context, + &mut body_scope_ctx, + helper, + then_successor_block, + stmt, + )?; + } + } + } + + match body_block { + Some(body_block) => { + body_block.append_operation(cf::br(check_block, &[], location)); + Ok(Some(merge_block)) + } + None => Ok(Some(merge_block)), + } +} + fn compile_let_stmt<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, diff --git a/examples/while.con b/examples/while.con new file mode 100644 index 0000000..6ab6742 --- /dev/null +++ b/examples/while.con @@ -0,0 +1,17 @@ +mod Simple { + fn main() -> i64 { + return my_func(4); + } + + fn my_func(times: i64) -> i64 { + let mut n: i64 = times; + let mut result: i64 = 1; + + while n > 0 { + result = result + result; + n = n - 1; + } + + return result; + } +} From e95fadfc1a2353da8aaa73f0280c8fb6b0083098 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 15:05:43 +0100 Subject: [PATCH 11/20] add integration tests --- Cargo.lock | 20 ++++ crates/concrete_codegen_mlir/src/codegen.rs | 1 - crates/concrete_driver/Cargo.toml | 3 + crates/concrete_driver/src/db.rs | 2 +- crates/concrete_driver/src/lib.rs | 7 +- crates/concrete_driver/tests/common.rs | 102 ++++++++++++++++++++ crates/concrete_driver/tests/programs.rs | 83 ++++++++++++++++ crates/concrete_session/src/lib.rs | 12 +++ 8 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 crates/concrete_driver/tests/common.rs create mode 100644 crates/concrete_driver/tests/programs.rs diff --git a/Cargo.lock b/Cargo.lock index d04781c..5216d43 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,7 @@ dependencies = [ "concrete_parser", "concrete_session", "salsa-2022", + "tempfile", "tracing", "tracing-subscriber", ] @@ -669,6 +670,12 @@ dependencies = [ "regex", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + [[package]] name = "fixedbitset" version = "0.4.2" @@ -1535,6 +1542,19 @@ dependencies = [ "thiserror", ] +[[package]] +name = "tempfile" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "term" version = "0.7.0" diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 81716e9..dc0fb29 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ - common::{Ident, Span}, expressions::{ ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, SimpleExpr, }, diff --git a/crates/concrete_driver/Cargo.toml b/crates/concrete_driver/Cargo.toml index 321a170..35b9084 100644 --- a/crates/concrete_driver/Cargo.toml +++ b/crates/concrete_driver/Cargo.toml @@ -15,3 +15,6 @@ concrete_session = { path = "../concrete_session"} concrete_codegen_mlir = { path = "../concrete_codegen_mlir"} salsa = { git = "https://github.com/salsa-rs/salsa.git", package = "salsa-2022" } ariadne = { version = "0.4.0", features = ["auto-color"] } + +[dev-dependencies] +tempfile = "3.9.0" diff --git a/crates/concrete_driver/src/db.rs b/crates/concrete_driver/src/db.rs index 3ca1f02..8713b82 100644 --- a/crates/concrete_driver/src/db.rs +++ b/crates/concrete_driver/src/db.rs @@ -14,7 +14,7 @@ impl Db for T where T: ?Sized + salsa::DbWithJar + salsa::DbWithJar, } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 927097e..8c4735d 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -85,7 +85,12 @@ pub fn main() -> Result<(), Box> { let object_path = concrete_codegen_mlir::compile(&session, &program)?; if session.library { - link_shared_lib(&object_path, &session.output_file.with_extension("so"))?; + link_shared_lib( + &object_path, + &session + .output_file + .with_extension(Session::get_platform_library_ext()), + )?; } else { link_binary(&object_path, &session.output_file.with_extension(""))?; } diff --git a/crates/concrete_driver/tests/common.rs b/crates/concrete_driver/tests/common.rs new file mode 100644 index 0000000..c16c320 --- /dev/null +++ b/crates/concrete_driver/tests/common.rs @@ -0,0 +1,102 @@ +use std::{ + borrow::Cow, + fmt, + path::{Path, PathBuf}, + process::Output, +}; + +use ariadne::Source; +use concrete_codegen_mlir::linker::{link_binary, link_shared_lib}; +use concrete_parser::{error::Diagnostics, ProgramSource}; +use concrete_session::{ + config::{DebugInfo, OptLevel}, + Session, +}; +use tempfile::TempDir; + +#[derive(Debug, Clone)] +struct TestError(Cow<'static, str>); + +impl fmt::Display for TestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for TestError {} + +#[derive(Debug)] +pub struct CompileResult { + pub folder: TempDir, + pub object_file: PathBuf, + pub binary_file: PathBuf, +} + +pub fn compile_program( + source: &str, + name: &str, + library: bool, +) -> Result> { + let db = concrete_driver::db::Database::default(); + let source = ProgramSource::new(&db, source.to_string(), name.to_string()); + tracing::debug!("source code:\n{}", source.input(&db)); + let program = match concrete_parser::parse_ast(&db, source) { + Some(x) => x, + None => { + Diagnostics::dump( + &db, + source, + &concrete_parser::parse_ast::accumulated::( + &db, source, + ), + ); + return Err(Box::new(TestError("error compiling".into()))); + } + }; + + let test_dir = tempfile::tempdir()?; + let test_dir_path = test_dir.path(); + // todo: find a better name, "target" would clash with rust if running in the source tree. + let target_dir = test_dir_path.join("build_artifacts/"); + let output_file = target_dir.join(PathBuf::from(name)); + let output_file = if library { + output_file.with_extension(Session::get_platform_library_ext()) + } else { + output_file.with_extension("") + }; + + let session = Session { + file_path: PathBuf::from(name), + debug_info: DebugInfo::Full, + optlevel: OptLevel::None, + source: Source::from(source.input(&db).to_string()), + library, + target_dir, + output_file, + }; + + let object_path = concrete_codegen_mlir::compile(&session, &program)?; + + if library { + link_shared_lib( + &object_path, + &session + .output_file + .with_extension(Session::get_platform_library_ext()), + )?; + } else { + link_binary(&object_path, &session.output_file.with_extension(""))?; + } + + Ok(CompileResult { + folder: test_dir, + object_file: object_path, + binary_file: session.output_file, + }) +} + +pub fn run_program(program: &Path) -> Result { + std::process::Command::new(program) + .spawn()? + .wait_with_output() +} diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs new file mode 100644 index 0000000..2f25fa8 --- /dev/null +++ b/crates/concrete_driver/tests/programs.rs @@ -0,0 +1,83 @@ +use common::{compile_program, run_program}; + +mod common; + +#[test] +fn test_while() { + let source = r#" + mod Simple { + fn main() -> i64 { + return my_func(4); + } + + fn my_func(times: i64) -> i64 { + let mut n: i64 = times; + let mut result: i64 = 1; + + while n > 0 { + result = result + result; + n = n - 1; + } + + return result; + } + } + "#; + + let result = compile_program(source, "while", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 16); +} + +#[test] +fn test_factorial_with_if() { + let source = r#" + mod Simple { + fn main() -> i64 { + return factorial(4); + } + + fn factorial(n: i64) -> i64 { + let zero: i64 = 0; + if n == zero { + return 1; + } else { + return n * factorial(n - 1); + } + } + } + "#; + + let result = compile_program(source, "factorial", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 24); +} + +#[test] +fn test_simple_add() { + let source = r#" + mod Simple { + fn main() -> i32 { + let x: i32 = 2; + let y: i32 = 4; + return add_plus_two(x, y); + } + + fn add_plus_two(x: i32, y: i32) -> i32 { + let mut z: i32 = 1; + z = z + 1; + return x + y + z; + } + } + "#; + + let result = compile_program(source, "simple_add", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 8); +} diff --git a/crates/concrete_session/src/lib.rs b/crates/concrete_session/src/lib.rs index 0eccf96..9c715bf 100644 --- a/crates/concrete_session/src/lib.rs +++ b/crates/concrete_session/src/lib.rs @@ -18,3 +18,15 @@ pub struct Session { pub output_file: PathBuf, // todo: include target, host, etc } + +impl Session { + pub fn get_platform_library_ext() -> &'static str { + if cfg!(target_os = "macos") { + "dylib" + } else if cfg!(target_os = "windows") { + "dll" + } else { + "so" + } + } +} From 555317c9d8ee4010a5f283c7f2c659bc3068b6f7 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 15:25:00 +0100 Subject: [PATCH 12/20] try --- .github/workflows/ci.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd63ed6..086d866 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,8 @@ jobs: keys-asc: https://apt.llvm.org/llvm-snapshot.gpg.key - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools + - name: Install Link deps + run: sudo apt-get install libc-dev build-essentials - name: test run: make test @@ -100,6 +102,8 @@ jobs: keys-asc: https://apt.llvm.org/llvm-snapshot.gpg.key - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools + - name: Install Link deps + run: sudo apt-get install libc-dev build-essentials - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: test and generate coverage From ce64bed3378ea15f596074295a9cdd5abc2c140e Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 15:28:53 +0100 Subject: [PATCH 13/20] try to fix ci --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 086d866..b4d8db7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools - name: Install Link deps - run: sudo apt-get install libc-dev build-essentials + run: sudo apt-get install libc-dev - name: test run: make test @@ -103,7 +103,7 @@ jobs: - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools - name: Install Link deps - run: sudo apt-get install libc-dev build-essentials + run: sudo apt-get install libc-dev - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: test and generate coverage From 0a368c12ec0ff74a25073a06676e7fb0f968da7e Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 16:26:48 +0100 Subject: [PATCH 14/20] try to fix ci --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4d8db7..c48e794 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,7 +70,7 @@ jobs: - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools - name: Install Link deps - run: sudo apt-get install libc-dev + run: sudo apt-get install libc-dev build-essential - name: test run: make test @@ -103,7 +103,7 @@ jobs: - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools - name: Install Link deps - run: sudo apt-get install libc-dev + run: sudo apt-get install libc-dev build-essential - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: test and generate coverage From 1b8fc6379cb13cb2e728e2b575f35d1e0b9feaa4 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 16:36:29 +0100 Subject: [PATCH 15/20] ubuntu is special --- crates/concrete_codegen_mlir/src/linker.rs | 27 +++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/linker.rs b/crates/concrete_codegen_mlir/src/linker.rs index 6fb2f78..f4ade60 100644 --- a/crates/concrete_codegen_mlir/src/linker.rs +++ b/crates/concrete_codegen_mlir/src/linker.rs @@ -64,6 +64,22 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: } #[cfg(target_os = "linux")] { + let (scrt1, crti, crtn) = { + if file_exists("/usr/lib64/Scrt1.o") { + ( + "/usr/lib64/Scrt1.o", + "/usr/lib64/crti.o", + "/usr/lib64/crtn.o", + ) + } else { + ( + "/lib/x86_64-linux-gnu/Scrt1.o", + "/lib/x86_64-linux-gnu/crti.o", + "/lib/x86_64-linux-gnu/crtn.o", + ) + } + }; + &[ "-pie", "--hash-style=gnu", @@ -72,16 +88,17 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: "/lib64/ld-linux-x86-64.so.2", "-m", "elf_x86_64", - "/usr/lib64/Scrt1.o", - "/usr/lib64/crti.o", + scrt1, + crti, "-o", &output_filename.display().to_string(), "-L/lib64", "-L/usr/lib64", + "-L/lib/x86_64-linux-gnu", "-zrelro", "--no-as-needed", "-lc", - "/usr/lib64/crtn.o", + crtn, &input_path.display().to_string(), ] } @@ -96,3 +113,7 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: proc.wait_with_output()?; Ok(()) } + +fn file_exists(path: &str) -> bool { + Path::new(path).exists() +} From f79bcb4671502583e3a9c96dc9eba5521fa72dfd Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 16:50:44 +0100 Subject: [PATCH 16/20] check if its signed or float in arith ops --- crates/concrete_codegen_mlir/src/codegen.rs | 70 ++++++++++++++++++--- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index dc0fb29..53662cc 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -68,6 +68,7 @@ impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { struct ScopeContext<'ctx, 'parent: 'ctx> { pub locals: HashMap>, pub functions: HashMap, + pub function: Option, } struct BlockHelper<'ctx, 'region: 'ctx> { @@ -116,6 +117,22 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { } => self.resolve_type(context, &name.name)?, }) } + + fn is_type_signed(&self, type_info: &TypeSpec) -> bool { + let signed = ["i8", "i16", "i32", "i64", "i128"]; + match type_info { + TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + } + } + + fn is_float(&self, type_info: &TypeSpec) -> bool { + let signed = ["f32", "f64"]; + match type_info { + TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + } + } } fn compile_module( @@ -131,6 +148,7 @@ fn compile_module( let mut scope_ctx: ScopeContext = ScopeContext { functions: Default::default(), locals: Default::default(), + function: None, }; // save all function signatures @@ -146,6 +164,8 @@ fn compile_module( match statement { ModuleDefItem::Constant(_) => todo!(), ModuleDefItem::Function(info) => { + let mut scope_ctx = scope_ctx.clone(); + scope_ctx.function = Some(info.clone()); let op = compile_function_def(session, context, &scope_ctx, info)?; body.append_operation(op); } @@ -588,7 +608,14 @@ fn compile_return_stmt<'ctx, 'parent: 'ctx>( helper, block, &info.value, - None, + scope_ctx + .function + .as_ref() + .unwrap() + .decl + .ret_type + .clone() + .as_ref(), )?; block.append_operation(func::r#return(&[value], Location::unknown(context))); Ok(()) @@ -651,14 +678,39 @@ fn compile_expression<'ctx, 'parent: 'ctx>( compile_expression(session, context, scope_ctx, _helper, block, rhs, type_info)?; let op = match op { - // todo: check signedness - BinaryOp::Arith(arith_op) => match arith_op { - ArithOp::Add => arith::addi(lhs, rhs, location), - ArithOp::Sub => arith::subi(lhs, rhs, location), - ArithOp::Mul => arith::muli(lhs, rhs, location), - ArithOp::Div => arith::divsi(lhs, rhs, location), - ArithOp::Mod => arith::remsi(lhs, rhs, location), - }, + BinaryOp::Arith(arith_op) => { + let type_info = type_info.expect("type info missing"); + + if scope_ctx.is_float(type_info) { + match arith_op { + ArithOp::Add => arith::addf(lhs, rhs, location), + ArithOp::Sub => arith::subf(lhs, rhs, location), + ArithOp::Mul => arith::mulf(lhs, rhs, location), + ArithOp::Div => arith::divf(lhs, rhs, location), + ArithOp::Mod => arith::remf(lhs, rhs, location), + } + } else { + match arith_op { + ArithOp::Add => arith::addi(lhs, rhs, location), + ArithOp::Sub => arith::subi(lhs, rhs, location), + ArithOp::Mul => arith::muli(lhs, rhs, location), + ArithOp::Div => { + if scope_ctx.is_type_signed(type_info) { + arith::divsi(lhs, rhs, location) + } else { + arith::divui(lhs, rhs, location) + } + } + ArithOp::Mod => { + if scope_ctx.is_type_signed(type_info) { + arith::remsi(lhs, rhs, location) + } else { + arith::remui(lhs, rhs, location) + } + } + } + } + } BinaryOp::Logic(logic_op) => match logic_op { LogicOp::And => { let const_true = block From fffd75f88f9a5f46ef6d4d2162762f262e8a2bc8 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 17:03:43 +0100 Subject: [PATCH 17/20] better --- crates/concrete_driver/tests/programs.rs | 3 +-- examples/factorial_if.con | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index 2f25fa8..442b01f 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -40,8 +40,7 @@ fn test_factorial_with_if() { } fn factorial(n: i64) -> i64 { - let zero: i64 = 0; - if n == zero { + if n == 0 { return 1; } else { return n * factorial(n - 1); diff --git a/examples/factorial_if.con b/examples/factorial_if.con index 1daf463..cf441cf 100644 --- a/examples/factorial_if.con +++ b/examples/factorial_if.con @@ -4,8 +4,7 @@ mod Simple { } fn factorial(n: i64) -> i64 { - let zero: i64 = 0; - if n == zero { + if n == 0 { return 1; } else { return n * factorial(n - 1); From 9617498442e7961c4a0b62b296527be16782ec1b Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 17:08:01 +0100 Subject: [PATCH 18/20] add fib and update readme --- README.md | 26 +++++++++++++----------- crates/concrete_driver/tests/programs.rs | 25 +++++++++++++++++++++++ examples/fib_if.con | 13 ++++++++++++ 3 files changed, 52 insertions(+), 12 deletions(-) create mode 100644 examples/fib_if.con diff --git a/README.md b/README.md index 6791844..122a5c3 100644 --- a/README.md +++ b/README.md @@ -124,22 +124,24 @@ But we want to take a different path with respect to: - No marker traits like Send, Sync for concurrency. The runtime will take care of that. ## Syntax -``` -mod FibonacciModule { - - pub fib(x: u64) -> u64 { - match x { - // we can match literal values - 0 | 1 -> x, - n -> fib(n-1) + fib(n-2) - } - } +```rust +mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } } ``` -``` +```rust mod Option { - pub enum Option { None, Some(T), diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index 442b01f..e10a08a 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -56,6 +56,31 @@ fn test_factorial_with_if() { assert_eq!(code, 24); } +#[test] +fn test_fib_with_if() { + let source = r#" + mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } + } + "#; + + let result = compile_program(source, "fib", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 55); +} + #[test] fn test_simple_add() { let source = r#" diff --git a/examples/fib_if.con b/examples/fib_if.con new file mode 100644 index 0000000..4863382 --- /dev/null +++ b/examples/fib_if.con @@ -0,0 +1,13 @@ +mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } +} From 799df9bbad9bba8728c95f24f6d8e6145dcfd652 Mon Sep 17 00:00:00 2001 From: Edgar Luque Date: Mon, 15 Jan 2024 21:03:39 +0100 Subject: [PATCH 19/20] make code simpler --- crates/concrete_codegen_mlir/src/codegen.rs | 162 +++++++------------- 1 file changed, 55 insertions(+), 107 deletions(-) diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 53662cc..c3a9415 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -229,7 +229,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( { let mut scope_ctx = scope_ctx.clone(); - let fn_block = ®ion.append_block(Block::new(&args)); + let mut fn_block = ®ion.append_block(Block::new(&args)); let blocks_arena = Bump::new(); let helper = BlockHelper { @@ -245,13 +245,9 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( ); } - let mut fn_block = Some(fn_block); - for stmt in &info.body { - if let Some(block) = fn_block { - fn_block = - compile_statement(session, context, &mut scope_ctx, &helper, block, stmt)?; - } + fn_block = + compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?; } } @@ -270,9 +266,9 @@ fn compile_statement<'c, 'this: 'c>( context: &'c MeliorContext, scope_ctx: &mut ScopeContext<'c, 'this>, helper: &BlockHelper<'c, 'this>, - block: &'this BlockRef<'c, 'this>, + mut block: &'this BlockRef<'c, 'this>, info: &Statement, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { match info { Statement::Assign(info) => { compile_assign_stmt(session, context, scope_ctx, helper, block, info)? @@ -280,21 +276,21 @@ fn compile_statement<'c, 'this: 'c>( Statement::Match(_) => todo!(), Statement::For(_) => todo!(), Statement::If(info) => { - return compile_if_expr(session, context, scope_ctx, helper, block, info); + block = compile_if_expr(session, context, scope_ctx, helper, block, info)?; } Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?, Statement::Return(info) => { compile_return_stmt(session, context, scope_ctx, helper, block, info)? } Statement::While(info) => { - return compile_while(session, context, scope_ctx, helper, block, info); + block = compile_while(session, context, scope_ctx, helper, block, info)?; } Statement::FnCall(info) => { compile_fn_call(session, context, scope_ctx, helper, block, info)?; } } - Ok(Some(block)) + Ok(block) } /// Compile a if expression / statement @@ -317,7 +313,7 @@ fn compile_if_expr<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this BlockRef<'c, 'this>, info: &IfExpr, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { let condition = compile_expression( session, context, @@ -328,8 +324,8 @@ fn compile_if_expr<'c, 'this: 'c>( None, )?; - let then_successor = helper.append_block(Block::new(&[])); - let else_successor = helper.append_block(Block::new(&[])); + let mut then_successor = helper.append_block(Block::new(&[])); + let mut else_successor = helper.append_block(Block::new(&[])); block.append_operation(cf::cond_br( context, @@ -341,94 +337,50 @@ fn compile_if_expr<'c, 'this: 'c>( get_named_location(context, "if"), )); - let mut then_successor = Some(then_successor); - let mut else_successor = Some(else_successor); - { let mut then_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { - if let Some(then_successor_block) = then_successor { - then_successor = compile_statement( - session, - context, - &mut then_scope_ctx, - helper, - then_successor_block, - stmt, - )?; - } + then_successor = compile_statement( + session, + context, + &mut then_scope_ctx, + helper, + then_successor, + stmt, + )?; } } if let Some(else_contents) = info.r#else.as_ref() { let mut else_scope_ctx = scope_ctx.clone(); for stmt in else_contents { - if let Some(else_successor_block) = else_successor { - else_successor = compile_statement( - session, - context, - &mut else_scope_ctx, - helper, - else_successor_block, - stmt, - )?; - } + else_successor = compile_statement( + session, + context, + &mut else_scope_ctx, + helper, + else_successor, + stmt, + )?; } } - Ok(match (then_successor, else_successor) { - (None, None) => None, - (None, Some(else_successor)) => { - if else_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - else_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - Some(final_block) - } - } - (Some(then_successor), None) => { - if then_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - then_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - Some(final_block) - } - } - (Some(then_successor), Some(else_successor)) => { - if then_successor.terminator().is_some() && else_successor.terminator().is_some() { - None - } else { - let final_block = helper.append_block(Block::new(&[])); - if then_successor.terminator().is_none() { - then_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - } + // both branches return + if then_successor.terminator().is_some() && else_successor.terminator().is_some() { + return Ok(then_successor); + } - if else_successor.terminator().is_none() { - else_successor.append_operation(cf::br( - final_block, - &[], - Location::unknown(context), - )); - } + let merge_block = helper.append_block(Block::new(&[])); - Some(final_block) - } - } - }) + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + Ok(merge_block) } fn compile_while<'c, 'this: 'c>( @@ -438,7 +390,7 @@ fn compile_while<'c, 'this: 'c>( helper: &BlockHelper<'c, 'this>, block: &'this BlockRef<'c, 'this>, info: &WhileStmt, -) -> Result>, Box> { +) -> Result<&'this BlockRef<'c, 'this>, Box> { let location = Location::unknown(context); let check_block = helper.append_block(Block::new(&[])); @@ -468,31 +420,27 @@ fn compile_while<'c, 'this: 'c>( location, )); - let mut body_block = Some(body_block); + let mut body_block = body_block; { let mut body_scope_ctx = scope_ctx.clone(); for stmt in &info.contents { - if let Some(then_successor_block) = body_block { - body_block = compile_statement( - session, - context, - &mut body_scope_ctx, - helper, - then_successor_block, - stmt, - )?; - } + body_block = compile_statement( + session, + context, + &mut body_scope_ctx, + helper, + body_block, + stmt, + )?; } } - match body_block { - Some(body_block) => { - body_block.append_operation(cf::br(check_block, &[], location)); - Ok(Some(merge_block)) - } - None => Ok(Some(merge_block)), + if body_block.terminator().is_none() { + body_block.append_operation(cf::br(check_block, &[], location)); } + + Ok(merge_block) } fn compile_let_stmt<'ctx, 'parent: 'ctx>( From e67c3086b46470c0a04c41d009dda5a7b79542c0 Mon Sep 17 00:00:00 2001 From: Edgar Date: Tue, 16 Jan 2024 11:40:09 +0100 Subject: [PATCH 20/20] check on macos --- Cargo.lock | 48 +++++++++++----------- crates/concrete_codegen_mlir/src/linker.rs | 1 + 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5216d43..d92ca68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,9 +37,9 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "anstream" -version = "0.6.5" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d664a92ecae85fd0a7392615844904654d1d5f5514837f471ddef4a057aba1b6" +checksum = "4cd2405b3ac1faab2990b74d728624cd9fd115651fcecc7c2d8daf01376275ba" dependencies = [ "anstyle", "anstyle-parse", @@ -123,9 +123,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.6" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79fed4cdb43e993fcdadc7e58a09fd0e3e649c4436fa11da71c9f1f3ee7feb9" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "beef" @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.14" +version = "4.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e92c5c1a78c62968ec57dbc2440366a2d6e5a23faf829970ff1585dc6b18e2" +checksum = "80932e03c33999b9235edb8655bc9df3204adc9887c2f95b50cb1deb9fd54253" dependencies = [ "clap_builder", "clap_derive", @@ -268,9 +268,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.14" +version = "4.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4323769dc8a61e2c39ad7dc26f6f2800524691a44d74fe3d1071a5c24db6370" +checksum = "d6c0db58c659eef1c73e444d298c27322a1b52f6927d2ad470c0c0f96fa7b8fa" dependencies = [ "anstream", "anstyle", @@ -977,9 +977,9 @@ dependencies = [ [[package]] name = "melior" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634f33663d2bcac794409829caf83a08967249e5429f34ec20c92230a85c025a" +checksum = "758bbd4448db9e994578ab48a6da5210512378f70ac1632cc8c2ae0fbd6c21b5" dependencies = [ "dashmap", "melior-macro", @@ -1313,9 +1313,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.28" +version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ "bitflags 2.4.1", "errno", @@ -1463,9 +1463,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" [[package]] name = "string_cache" @@ -1783,9 +1783,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1793,9 +1793,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" dependencies = [ "bumpalo", "log", @@ -1808,9 +1808,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1818,9 +1818,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", @@ -1831,9 +1831,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "which" diff --git a/crates/concrete_codegen_mlir/src/linker.rs b/crates/concrete_codegen_mlir/src/linker.rs index f4ade60..d65b120 100644 --- a/crates/concrete_codegen_mlir/src/linker.rs +++ b/crates/concrete_codegen_mlir/src/linker.rs @@ -114,6 +114,7 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: Ok(()) } +#[cfg(target_os = "linux")] fn file_exists(path: &str) -> bool { Path::new(path).exists() }