diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cd63ed6..c48e794 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -69,6 +69,8 @@ jobs: keys-asc: https://apt.llvm.org/llvm-snapshot.gpg.key - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools + - name: Install Link deps + run: sudo apt-get install libc-dev build-essential - name: test run: make test @@ -100,6 +102,8 @@ jobs: keys-asc: https://apt.llvm.org/llvm-snapshot.gpg.key - name: Install LLVM run: sudo apt-get install llvm-17 llvm-17-dev llvm-17-runtime clang-17 clang-tools-17 lld-17 libpolly-17-dev libmlir-17-dev mlir-17-tools + - name: Install Link deps + run: sudo apt-get install libc-dev build-essential - name: Install cargo-llvm-cov uses: taiki-e/install-action@cargo-llvm-cov - name: test and generate coverage diff --git a/Cargo.lock b/Cargo.lock index 1a27727..d92ca68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,9 +37,9 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "anstream" -version = "0.6.5" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d664a92ecae85fd0a7392615844904654d1d5f5514837f471ddef4a057aba1b6" +checksum = "4cd2405b3ac1faab2990b74d728624cd9fd115651fcecc7c2d8daf01376275ba" dependencies = [ "anstyle", "anstyle-parse", @@ -95,6 +95,17 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" +[[package]] +name = "ariadne" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd002a6223f12c7a95cdd4b1cb3a0149d22d37f7a9ecdb2cb691a071fe236c29" +dependencies = [ + "concolor", + "unicode-width", + "yansi", +] + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -112,9 +123,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.6" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c79fed4cdb43e993fcdadc7e58a09fd0e3e649c4436fa11da71c9f1f3ee7feb9" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "beef" @@ -247,9 +258,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.14" +version = "4.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33e92c5c1a78c62968ec57dbc2440366a2d6e5a23faf829970ff1585dc6b18e2" +checksum = "80932e03c33999b9235edb8655bc9df3204adc9887c2f95b50cb1deb9fd54253" dependencies = [ "clap_builder", "clap_derive", @@ -257,9 +268,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.14" +version = "4.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4323769dc8a61e2c39ad7dc26f6f2800524691a44d74fe3d1071a5c24db6370" +checksum = "d6c0db58c659eef1c73e444d298c27322a1b52f6927d2ad470c0c0f96fa7b8fa" dependencies = [ "anstream", "anstyle", @@ -312,6 +323,26 @@ dependencies = [ "xdg", ] +[[package]] +name = "concolor" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b946244a988c390a94667ae0e3958411fa40cc46ea496a929b263d883f5f9c3" +dependencies = [ + "bitflags 1.3.2", + "concolor-query", + "is-terminal", +] + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "concrete" version = "0.1.0" @@ -330,6 +361,7 @@ dependencies = [ name = "concrete_codegen_mlir" version = "0.1.0" dependencies = [ + "bumpalo", "cc", "concrete_ast", "concrete_session", @@ -343,12 +375,14 @@ dependencies = [ name = "concrete_driver" version = "0.1.0" dependencies = [ + "ariadne", "clap", "concrete_ast", "concrete_codegen_mlir", "concrete_parser", "concrete_session", "salsa-2022", + "tempfile", "tracing", "tracing-subscriber", ] @@ -357,11 +391,12 @@ dependencies = [ name = "concrete_parser" version = "0.1.0" dependencies = [ + "ariadne", "concrete_ast", + "itertools 0.12.0", "lalrpop", "lalrpop-util", "logos", - "owo-colors", "salsa-2022", "tracing", ] @@ -369,6 +404,9 @@ dependencies = [ [[package]] name = "concrete_session" version = "0.1.0" +dependencies = [ + "ariadne", +] [[package]] name = "concrete_type_checker" @@ -632,6 +670,12 @@ dependencies = [ "regex", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + [[package]] name = "fixedbitset" version = "0.4.2" @@ -753,6 +797,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -770,7 +823,7 @@ dependencies = [ "diff", "ena", "is-terminal", - "itertools", + "itertools 0.10.5", "lalrpop-util", "petgraph", "pico-args", @@ -924,9 +977,9 @@ dependencies = [ [[package]] name = "melior" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634f33663d2bcac794409829caf83a08967249e5429f34ec20c92230a85c025a" +checksum = "758bbd4448db9e994578ab48a6da5210512378f70ac1632cc8c2ae0fbd6c21b5" dependencies = [ "dashmap", "melior-macro", @@ -1041,12 +1094,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "owo-colors" -version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caff54706df99d2a78a5a4e3455ff45448d81ef1bb63c22cd14052ca0e993a3f" - [[package]] name = "parking_lot" version = "0.12.1" @@ -1266,9 +1313,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.28" +version = "0.38.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" dependencies = [ "bitflags 2.4.1", "errno", @@ -1416,9 +1463,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.2" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +checksum = "2593d31f82ead8df961d8bd23a64c2ccf2eb5dd34b0a34bfb4dd54011c72009e" [[package]] name = "string_cache" @@ -1495,6 +1542,19 @@ dependencies = [ "thiserror", ] +[[package]] +name = "tempfile" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys 0.52.0", +] + [[package]] name = "term" version = "0.7.0" @@ -1663,6 +1723,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "unicode-xid" version = "0.2.4" @@ -1717,9 +1783,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1727,9 +1793,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" dependencies = [ "bumpalo", "log", @@ -1742,9 +1808,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1752,9 +1818,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" dependencies = [ "proc-macro2", "quote", @@ -1765,9 +1831,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.89" +version = "0.2.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" +checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "which" @@ -1812,6 +1878,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -1830,6 +1905,21 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -1860,6 +1950,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -1872,6 +1968,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -1884,6 +1986,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -1896,6 +2004,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -1908,6 +2022,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -1920,6 +2040,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -1932,6 +2058,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -1959,6 +2091,12 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "yansi" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" + [[package]] name = "zerocopy" version = "0.7.32" diff --git a/README.md b/README.md index 6791844..122a5c3 100644 --- a/README.md +++ b/README.md @@ -124,22 +124,24 @@ But we want to take a different path with respect to: - No marker traits like Send, Sync for concurrency. The runtime will take care of that. ## Syntax -``` -mod FibonacciModule { - - pub fib(x: u64) -> u64 { - match x { - // we can match literal values - 0 | 1 -> x, - n -> fib(n-1) + fib(n-2) - } - } +```rust +mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } } ``` -``` +```rust mod Option { - pub enum Option { None, Some(T), diff --git a/crates/concrete_codegen_mlir/Cargo.toml b/crates/concrete_codegen_mlir/Cargo.toml index 39f7262..91fb12d 100644 --- a/crates/concrete_codegen_mlir/Cargo.toml +++ b/crates/concrete_codegen_mlir/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bumpalo = { version = "3.14.0", features = ["std"] } concrete_ast = { path = "../concrete_ast"} concrete_session = { path = "../concrete_session"} llvm-sys = "170.0.1" diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 21a0d12..c3a9415 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,11 +1,13 @@ use std::{collections::HashMap, error::Error}; +use bumpalo::Bump; use concrete_ast::{ - common::Span, - expressions::{ArithOp, BinaryOp, CmpOp, Expression, LogicOp, PathOp, SimpleExpr}, + expressions::{ + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, SimpleExpr, + }, functions::FunctionDef, modules::{Module, ModuleDefItem}, - statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement}, + statements::{AssignStmt, LetStmt, LetStmtTarget, ReturnStmt, Statement, WhileStmt}, types::TypeSpec, Program, }; @@ -13,12 +15,13 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - func, memref, + cf, func, memref, }, ir::{ attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Block, Location, Module as MeliorModule, Region, Type, Value, ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, + ValueLike, }, Context as MeliorContext, }; @@ -36,30 +39,66 @@ pub fn compile_program( } #[derive(Debug, Clone)] -pub struct LocalVar<'c, 'op> { +pub struct LocalVar<'ctx, 'parent: 'ctx> { pub type_spec: TypeSpec, // If it's none its on a register, otherwise allocated on the stack. - pub memref_type: Option>, - pub value: Value<'c, 'op>, + pub alloca: bool, + pub value: Value<'ctx, 'parent>, } -#[derive(Debug, Clone, Default)] -struct CompilerContext<'c, 'op> { - pub locals: HashMap>, +impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: false, + } + } + + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + Self { + value, + type_spec, + alloca: true, + } + } +} + +#[derive(Debug, Clone)] +struct ScopeContext<'ctx, 'parent: 'ctx> { + pub locals: HashMap>, pub functions: HashMap, + pub function: Option, +} + +struct BlockHelper<'ctx, 'region: 'ctx> { + region: &'region Region<'ctx>, + blocks_arena: &'region Bump, } -impl<'c, 'op> CompilerContext<'c, 'op> { +impl<'ctx, 'region> BlockHelper<'ctx, 'region> { + pub fn append_block(&self, block: Block<'ctx>) -> &'region BlockRef<'ctx, 'region> { + let block = self.region.append_block(block); + + let block_ref: &'region mut BlockRef<'ctx, 'region> = self.blocks_arena.alloc(block); + + block_ref + } +} + +impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { fn resolve_type( &self, - context: &'c MeliorContext, + context: &'ctx MeliorContext, name: &str, - ) -> Result, Box> { + ) -> Result, Box> { Ok(match name { "u64" | "i64" => IntegerType::new(context, 64).into(), "u32" | "i32" => IntegerType::new(context, 32).into(), "u16" | "i16" => IntegerType::new(context, 16).into(), "u8" | "i8" => IntegerType::new(context, 8).into(), + "f32" => Type::float32(context), + "f64" => Type::float64(context), "bool" => IntegerType::new(context, 1).into(), _ => todo!("custom type lookup"), }) @@ -67,9 +106,9 @@ impl<'c, 'op> CompilerContext<'c, 'op> { fn resolve_type_spec( &self, - context: &'c MeliorContext, + context: &'ctx MeliorContext, spec: &TypeSpec, - ) -> Result, Box> { + ) -> Result, Box> { Ok(match spec { TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?, TypeSpec::Generic { @@ -78,6 +117,22 @@ impl<'c, 'op> CompilerContext<'c, 'op> { } => self.resolve_type(context, &name.name)?, }) } + + fn is_type_signed(&self, type_info: &TypeSpec) -> bool { + let signed = ["i8", "i16", "i32", "i64", "i128"]; + match type_info { + TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + } + } + + fn is_float(&self, type_info: &TypeSpec) -> bool { + let signed = ["f32", "f64"]; + match type_info { + TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + } + } } fn compile_module( @@ -88,14 +143,18 @@ fn compile_module( ) -> Result<(), Box> { // todo: handle imports - let mut compiler_ctx: CompilerContext = Default::default(); - let body = mlir_module.body(); + let mut scope_ctx: ScopeContext = ScopeContext { + functions: Default::default(), + locals: Default::default(), + function: None, + }; + // save all function signatures for statement in &module.contents { if let ModuleDefItem::Function(info) = statement { - compiler_ctx + scope_ctx .functions .insert(info.decl.name.name.clone(), info.clone()); } @@ -105,7 +164,10 @@ fn compile_module( match statement { ModuleDefItem::Constant(_) => todo!(), ModuleDefItem::Function(info) => { - compile_function_def(session, context, &mut compiler_ctx, &body, info)?; + let mut scope_ctx = scope_ctx.clone(); + scope_ctx.function = Some(info.clone()); + let op = compile_function_def(session, context, &scope_ctx, info)?; + body.append_operation(op); } ModuleDefItem::Record(_) => todo!(), ModuleDefItem::Type(_) => todo!(), @@ -115,52 +177,49 @@ fn compile_module( Ok(()) } -fn get_location<'c>(context: &'c MeliorContext, session: &Session, span: &Span) -> Location<'c> { - let (line, col) = session.get_line_and_column(span.from); - Location::new(context, &session.file_path.display().to_string(), line, col) +fn get_location<'ctx>( + context: &'ctx MeliorContext, + session: &Session, + offset: usize, +) -> Location<'ctx> { + let (_, line, col) = session.source.get_offset_line(offset).unwrap(); + Location::new( + context, + &session.file_path.display().to_string(), + line + 1, + col + 1, + ) +} + +fn get_named_location<'ctx>(context: &'ctx MeliorContext, name: &str) -> Location<'ctx> { + Location::name(context, name, Location::unknown(context)) } -fn compile_function_def<'c, 'op>( +fn compile_function_def<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &ScopeContext<'ctx, 'parent>, info: &FunctionDef, -) -> Result<(), Box> { - let region = Region::new(); - - let location = get_location(context, session, &info.decl.name.span); +) -> Result, Box> { + tracing::debug!("compiling function {:?}", info.decl.name.name); + let location = get_location(context, session, info.decl.name.span.from); // Setup function arguments let mut args = Vec::with_capacity(info.decl.params.len()); let mut fn_args_types = Vec::with_capacity(info.decl.params.len()); for param in &info.decl.params { - let param_type = compiler_ctx.resolve_type_spec(context, ¶m.r#type)?; - let loc = get_location(context, session, ¶m.name.span); + let param_type = scope_ctx.resolve_type_spec(context, ¶m.r#type)?; + let loc = get_location(context, session, param.name.span.from); args.push((param_type, loc)); fn_args_types.push(param_type); } - let fn_block = Block::new(&args); - // Create the function context - let mut fn_compiler_ctx = compiler_ctx.clone(); - - // Push arguments into locals - for (i, param) in info.decl.params.iter().enumerate() { - fn_compiler_ctx.locals.insert( - param.name.name.clone(), - LocalVar { - type_spec: param.r#type.clone(), - value: fn_block.argument(i)?.into(), - memref_type: None, - }, - ); - } + let region = Region::new(); let return_type = if let Some(ret_type) = &info.decl.ret_type { - vec![fn_compiler_ctx.resolve_type_spec(context, ret_type)?] + vec![scope_ctx.resolve_type_spec(context, ret_type)?] } else { vec![] }; @@ -168,44 +227,228 @@ fn compile_function_def<'c, 'op>( let func_type = TypeAttribute::new(FunctionType::new(context, &fn_args_types, &return_type).into()); - for stmt in &info.body { - match stmt { - Statement::Assign(info) => { - compile_assign_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? - } - Statement::Match(_) => todo!(), - Statement::For(_) => todo!(), - Statement::If(_) => todo!(), - Statement::Let(info) => { - compile_let_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? - } - Statement::Return(info) => { - compile_return_stmt(session, context, &mut fn_compiler_ctx, &fn_block, info)? - } - Statement::While(_) => todo!(), - Statement::FnCall(_) => todo!(), + { + let mut scope_ctx = scope_ctx.clone(); + let mut fn_block = ®ion.append_block(Block::new(&args)); + + let blocks_arena = Bump::new(); + let helper = BlockHelper { + region: ®ion, + blocks_arena: &blocks_arena, + }; + + // Push arguments into locals + for (i, param) in info.decl.params.iter().enumerate() { + scope_ctx.locals.insert( + param.name.name.clone(), + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), + ); } - } - region.append_block(fn_block); + for stmt in &info.body { + fn_block = + compile_statement(session, context, &mut scope_ctx, &helper, fn_block, stmt)?; + } + } - block.append_operation(func::func( + Ok(func::func( context, StringAttribute::new(context, &info.decl.name.name), func_type, region, &[], location, + )) +} + +fn compile_statement<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + mut block: &'this BlockRef<'c, 'this>, + info: &Statement, +) -> Result<&'this BlockRef<'c, 'this>, Box> { + match info { + Statement::Assign(info) => { + compile_assign_stmt(session, context, scope_ctx, helper, block, info)? + } + Statement::Match(_) => todo!(), + Statement::For(_) => todo!(), + Statement::If(info) => { + block = compile_if_expr(session, context, scope_ctx, helper, block, info)?; + } + Statement::Let(info) => compile_let_stmt(session, context, scope_ctx, helper, block, info)?, + Statement::Return(info) => { + compile_return_stmt(session, context, scope_ctx, helper, block, info)? + } + Statement::While(info) => { + block = compile_while(session, context, scope_ctx, helper, block, info)?; + } + Statement::FnCall(info) => { + compile_fn_call(session, context, scope_ctx, helper, block, info)?; + } + } + + Ok(block) +} + +/// Compile a if expression / statement +/// +/// This returns a block if any branch doesn't have a function return terminator. +/// For example, if the if branch has a return and the else branch has a return, +/// it wouldn't make sense to add a merging block and MLIR would give a error saying there is a operation after a terminator. +/// +/// The returned block is the merger block, the one we jump after processing the if branches. +/// +/// ```text +/// - then block - +/// - if (prev block) - < > merge block -- +/// - else block - +/// ``` +fn compile_if_expr<'c, 'this: 'c>( + session: &Session, + context: &'c MeliorContext, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &IfExpr, +) -> Result<&'this BlockRef<'c, 'this>, Box> { + let condition = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + None, + )?; + + let mut then_successor = helper.append_block(Block::new(&[])); + let mut else_successor = helper.append_block(Block::new(&[])); + + block.append_operation(cf::cond_br( + context, + condition, + then_successor, + else_successor, + &[], + &[], + get_named_location(context, "if"), )); - Ok(()) + { + let mut then_scope_ctx = scope_ctx.clone(); + for stmt in &info.contents { + then_successor = compile_statement( + session, + context, + &mut then_scope_ctx, + helper, + then_successor, + stmt, + )?; + } + } + + if let Some(else_contents) = info.r#else.as_ref() { + let mut else_scope_ctx = scope_ctx.clone(); + for stmt in else_contents { + else_successor = compile_statement( + session, + context, + &mut else_scope_ctx, + helper, + else_successor, + stmt, + )?; + } + } + + // both branches return + if then_successor.terminator().is_some() && else_successor.terminator().is_some() { + return Ok(then_successor); + } + + let merge_block = helper.append_block(Block::new(&[])); + + if then_successor.terminator().is_none() { + then_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + if else_successor.terminator().is_none() { + else_successor.append_operation(cf::br(merge_block, &[], Location::unknown(context))); + } + + Ok(merge_block) } -fn compile_let_stmt<'c, 'op>( +fn compile_while<'c, 'this: 'c>( session: &Session, context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + scope_ctx: &mut ScopeContext<'c, 'this>, + helper: &BlockHelper<'c, 'this>, + block: &'this BlockRef<'c, 'this>, + info: &WhileStmt, +) -> Result<&'this BlockRef<'c, 'this>, Box> { + let location = Location::unknown(context); + + let check_block = helper.append_block(Block::new(&[])); + + block.append_operation(cf::br(check_block, &[], location)); + + let body_block = helper.append_block(Block::new(&[])); + let merge_block = helper.append_block(Block::new(&[])); + + let condition = compile_expression( + session, + context, + scope_ctx, + helper, + check_block, + &info.value, + None, + )?; + + check_block.append_operation(cf::cond_br( + context, + condition, + body_block, + merge_block, + &[], + &[], + location, + )); + + let mut body_block = body_block; + + { + let mut body_scope_ctx = scope_ctx.clone(); + for stmt in &info.contents { + body_block = compile_statement( + session, + context, + &mut body_scope_ctx, + helper, + body_block, + stmt, + )?; + } + } + + if body_block.terminator().is_none() { + body_block.append_operation(cf::br(check_block, &[], location)); + } + + Ok(merge_block) +} + +fn compile_let_stmt<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &LetStmt, ) -> Result<(), Box> { match &info.target { @@ -213,13 +456,14 @@ fn compile_let_stmt<'c, 'op>( let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, + helper, block, &info.value, Some(r#type), )?; - let location = get_location(context, session, &name.span); + let location = get_location(context, session, name.span.from); let memref_type = MemRefType::new(value.r#type(), &[1], None, None); @@ -244,14 +488,9 @@ fn compile_let_stmt<'c, 'op>( .into(); block.append_operation(memref::store(value, alloca, &[k0], location)); - compiler_ctx.locals.insert( - name.name.clone(), - LocalVar { - type_spec: r#type.clone(), - memref_type: Some(memref_type), - value: alloca, - }, - ); + scope_ctx + .locals + .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); Ok(()) } @@ -259,32 +498,31 @@ fn compile_let_stmt<'c, 'op>( } } -fn compile_assign_stmt<'c, 'op>( +fn compile_assign_stmt<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &AssignStmt, ) -> Result<(), Box> { // todo: implement properly for structs, right now only really works for simple variables. - let local = compiler_ctx + let local = scope_ctx .locals .get(&info.target.first.name) .expect("local should exist") .clone(); - assert!( - local.memref_type.is_some(), - "can only mutate local stack variables" - ); + assert!(local.alloca, "can only mutate local stack variables"); - let location = get_location(context, session, &info.target.first.span); + let location = get_location(context, session, info.target.first.span.from); let value = compile_expression( session, context, - compiler_ctx, + scope_ctx, + helper, block, &info.value, Some(&local.type_spec), @@ -303,26 +541,43 @@ fn compile_assign_stmt<'c, 'op>( Ok(()) } -fn compile_return_stmt<'c, 'op>( +fn compile_return_stmt<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &ReturnStmt, ) -> Result<(), Box> { - let value = compile_expression(session, context, compiler_ctx, block, &info.value, None)?; + let value = compile_expression( + session, + context, + scope_ctx, + helper, + block, + &info.value, + scope_ctx + .function + .as_ref() + .unwrap() + .decl + .ret_type + .clone() + .as_ref(), + )?; block.append_operation(func::r#return(&[value], Location::unknown(context))); Ok(()) } -fn compile_expression<'c, 'op>( +fn compile_expression<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, info: &Expression, type_info: Option<&TypeSpec>, -) -> Result, Box> { +) -> Result, Box> { let location = Location::unknown(context); match info { Expression::Simple(simple) => match simple { @@ -344,7 +599,7 @@ fn compile_expression<'c, 'op>( } SimpleExpr::ConstInt(value) => { let int_type = if let Some(type_info) = type_info { - compiler_ctx.resolve_type_spec(context, type_info)? + scope_ctx.resolve_type_spec(context, type_info)? } else { IntegerType::new(context, 64).into() }; @@ -356,71 +611,54 @@ fn compile_expression<'c, 'op>( } SimpleExpr::ConstFloat(_) => todo!(), SimpleExpr::ConstStr(_) => todo!(), - SimpleExpr::Path(value) => { - compile_path_op(session, context, compiler_ctx, block, value) - } + SimpleExpr::Path(value) => compile_path_op(session, context, scope_ctx, block, value), }, Expression::FnCall(value) => { - let mut args = Vec::with_capacity(value.args.len()); - let location = get_location(context, session, &value.target.span); - - let target_fn = compiler_ctx - .functions - .get(&value.target.name) - .expect("function not found") - .clone(); - - assert_eq!( - value.args.len(), - target_fn.decl.params.len(), - "parameter length doesnt match" - ); - - for (arg, arg_info) in value.args.iter().zip(&target_fn.decl.params) { - let value = compile_expression( - session, - context, - compiler_ctx, - block, - arg, - Some(&arg_info.r#type), - )?; - args.push(value); - } - - let return_type = if let Some(ret_type) = &target_fn.decl.ret_type { - vec![compiler_ctx.resolve_type_spec(context, ret_type)?] - } else { - vec![] - }; - - Ok(block - .append_operation(func::call( - context, - FlatSymbolRefAttribute::new(context, &value.target.name), - &args, - &return_type, - location, - )) - .result(0)? - .into()) + compile_fn_call(session, context, scope_ctx, _helper, block, value) } Expression::Match(_) => todo!(), Expression::If(_) => todo!(), Expression::UnaryOp(_, _) => todo!(), Expression::BinaryOp(lhs, op, rhs) => { - let lhs = compile_expression(session, context, compiler_ctx, block, lhs, type_info)?; - let rhs = compile_expression(session, context, compiler_ctx, block, rhs, type_info)?; + let lhs = + compile_expression(session, context, scope_ctx, _helper, block, lhs, type_info)?; + let rhs = + compile_expression(session, context, scope_ctx, _helper, block, rhs, type_info)?; let op = match op { - // todo: check signedness - BinaryOp::Arith(arith_op) => match arith_op { - ArithOp::Add => arith::addi(lhs, rhs, location), - ArithOp::Sub => arith::subi(lhs, rhs, location), - ArithOp::Mul => arith::muli(lhs, rhs, location), - ArithOp::Div => arith::divsi(lhs, rhs, location), - ArithOp::Mod => arith::remsi(lhs, rhs, location), - }, + BinaryOp::Arith(arith_op) => { + let type_info = type_info.expect("type info missing"); + + if scope_ctx.is_float(type_info) { + match arith_op { + ArithOp::Add => arith::addf(lhs, rhs, location), + ArithOp::Sub => arith::subf(lhs, rhs, location), + ArithOp::Mul => arith::mulf(lhs, rhs, location), + ArithOp::Div => arith::divf(lhs, rhs, location), + ArithOp::Mod => arith::remf(lhs, rhs, location), + } + } else { + match arith_op { + ArithOp::Add => arith::addi(lhs, rhs, location), + ArithOp::Sub => arith::subi(lhs, rhs, location), + ArithOp::Mul => arith::muli(lhs, rhs, location), + ArithOp::Div => { + if scope_ctx.is_type_signed(type_info) { + arith::divsi(lhs, rhs, location) + } else { + arith::divui(lhs, rhs, location) + } + } + ArithOp::Mod => { + if scope_ctx.is_type_signed(type_info) { + arith::remsi(lhs, rhs, location) + } else { + arith::remui(lhs, rhs, location) + } + } + } + } + } BinaryOp::Logic(logic_op) => match logic_op { LogicOp::And => { let const_true = block @@ -505,24 +743,78 @@ fn compile_expression<'c, 'op>( } } -fn compile_path_op<'c, 'op>( +fn compile_fn_call<'ctx, 'parent: 'ctx>( session: &Session, - context: &'c MeliorContext, - compiler_ctx: &mut CompilerContext<'c, 'op>, - block: &'op Block<'c>, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + info: &FnCallOp, +) -> Result, Box> { + let mut args = Vec::with_capacity(info.args.len()); + let location = get_location(context, session, info.target.span.from); + + let target_fn = scope_ctx + .functions + .get(&info.target.name) + .expect("function not found") + .clone(); + + assert_eq!( + info.args.len(), + target_fn.decl.params.len(), + "parameter length doesnt match" + ); + + for (arg, arg_info) in info.args.iter().zip(&target_fn.decl.params) { + let value = compile_expression( + session, + context, + scope_ctx, + _helper, + block, + arg, + Some(&arg_info.r#type), + )?; + args.push(value); + } + + let return_type = if let Some(ret_type) = &target_fn.decl.ret_type { + vec![scope_ctx.resolve_type_spec(context, ret_type)?] + } else { + vec![] + }; + + Ok(block + .append_operation(func::call( + context, + FlatSymbolRefAttribute::new(context, &info.target.name), + &args, + &return_type, + location, + )) + .result(0)? + .into()) +} + +fn compile_path_op<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + block: &'parent Block<'ctx>, path: &PathOp, -) -> Result, Box> { +) -> Result, Box> { // For now only simple variables work. // TODO: implement properly, this requires having structs implemented. - let local = compiler_ctx + let local = scope_ctx .locals .get(&path.first.name) .expect("local not found"); - let location = get_location(context, session, &path.first.span); + let location = get_location(context, session, path.first.span.from); - if let Some(_memref_type) = local.memref_type { + if local.alloca { let k0 = block .append_operation(arith::constant( context, diff --git a/crates/concrete_codegen_mlir/src/context.rs b/crates/concrete_codegen_mlir/src/context.rs index 63e97a4..70961fd 100644 --- a/crates/concrete_codegen_mlir/src/context.rs +++ b/crates/concrete_codegen_mlir/src/context.rs @@ -4,7 +4,7 @@ use concrete_ast::Program; use concrete_session::Session; use melior::{ dialect::DialectRegistry, - ir::{Location, Module as MeliorModule}, + ir::{operation::OperationPrintingFlags, Location, Module as MeliorModule}, utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}, Context as MeliorContext, }; @@ -43,11 +43,16 @@ impl Context { super::codegen::compile_program(session, &self.melior_context, &melior_module, program)?; + let print_flags = OperationPrintingFlags::new().enable_debug_info(true, true); tracing::debug!( - "MLIR Code before passes:\n{:#?}", - melior_module.as_operation() + "MLIR Code before passes:\n{}", + melior_module + .as_operation() + .to_string_with_flags(print_flags)? ); + assert!(melior_module.as_operation().verify()); + // TODO: Add proper error handling. run_pass_manager(&self.melior_context, &mut melior_module).unwrap(); diff --git a/crates/concrete_codegen_mlir/src/lib.rs b/crates/concrete_codegen_mlir/src/lib.rs index e30fa58..959807b 100644 --- a/crates/concrete_codegen_mlir/src/lib.rs +++ b/crates/concrete_codegen_mlir/src/lib.rs @@ -40,6 +40,7 @@ mod pass_manager; pub fn compile(session: &Session, program: &Program) -> Result> { let context = Context::new(); let mlir_module = context.compile(session, program)?; + assert!(mlir_module.melior_module.as_operation().verify()); let object_path = compile_to_object(session, &mlir_module)?; diff --git a/crates/concrete_codegen_mlir/src/linker.rs b/crates/concrete_codegen_mlir/src/linker.rs index 6fb2f78..d65b120 100644 --- a/crates/concrete_codegen_mlir/src/linker.rs +++ b/crates/concrete_codegen_mlir/src/linker.rs @@ -64,6 +64,22 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: } #[cfg(target_os = "linux")] { + let (scrt1, crti, crtn) = { + if file_exists("/usr/lib64/Scrt1.o") { + ( + "/usr/lib64/Scrt1.o", + "/usr/lib64/crti.o", + "/usr/lib64/crtn.o", + ) + } else { + ( + "/lib/x86_64-linux-gnu/Scrt1.o", + "/lib/x86_64-linux-gnu/crti.o", + "/lib/x86_64-linux-gnu/crtn.o", + ) + } + }; + &[ "-pie", "--hash-style=gnu", @@ -72,16 +88,17 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: "/lib64/ld-linux-x86-64.so.2", "-m", "elf_x86_64", - "/usr/lib64/Scrt1.o", - "/usr/lib64/crti.o", + scrt1, + crti, "-o", &output_filename.display().to_string(), "-L/lib64", "-L/usr/lib64", + "-L/lib/x86_64-linux-gnu", "-zrelro", "--no-as-needed", "-lc", - "/usr/lib64/crtn.o", + crtn, &input_path.display().to_string(), ] } @@ -96,3 +113,8 @@ pub fn link_binary(input_path: &Path, output_filename: &Path) -> Result<(), std: proc.wait_with_output()?; Ok(()) } + +#[cfg(target_os = "linux")] +fn file_exists(path: &str) -> bool { + Path::new(path).exists() +} diff --git a/crates/concrete_driver/Cargo.toml b/crates/concrete_driver/Cargo.toml index d329044..35b9084 100644 --- a/crates/concrete_driver/Cargo.toml +++ b/crates/concrete_driver/Cargo.toml @@ -14,3 +14,7 @@ concrete_parser = { path = "../concrete_parser"} concrete_session = { path = "../concrete_session"} concrete_codegen_mlir = { path = "../concrete_codegen_mlir"} salsa = { git = "https://github.com/salsa-rs/salsa.git", package = "salsa-2022" } +ariadne = { version = "0.4.0", features = ["auto-color"] } + +[dev-dependencies] +tempfile = "3.9.0" diff --git a/crates/concrete_driver/src/db.rs b/crates/concrete_driver/src/db.rs index 3ca1f02..8713b82 100644 --- a/crates/concrete_driver/src/db.rs +++ b/crates/concrete_driver/src/db.rs @@ -14,7 +14,7 @@ impl Db for T where T: ?Sized + salsa::DbWithJar + salsa::DbWithJar, } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 437bc91..8c4735d 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -1,3 +1,4 @@ +use ariadne::Source; use clap::Parser; use concrete_codegen_mlir::linker::{link_binary, link_shared_lib}; use concrete_parser::{error::Diagnostics, ProgramSource}; @@ -32,7 +33,11 @@ pub fn main() -> Result<(), Box> { let args = CompilerArgs::parse(); let db = crate::db::Database::default(); - let source = ProgramSource::new(&db, std::fs::read_to_string(args.input.clone())?); + let source = ProgramSource::new( + &db, + std::fs::read_to_string(&args.input)?, + args.input.display().to_string(), + ); tracing::debug!("source code:\n{}", source.input(&db)); let program = match concrete_parser::parse_ast(&db, source) { Some(x) => x, @@ -44,7 +49,7 @@ pub fn main() -> Result<(), Box> { &db, source, ), ); - panic!(); + std::process::exit(1); } }; @@ -70,7 +75,7 @@ pub fn main() -> Result<(), Box> { } else { OptLevel::None }, - source: source.input(&db).to_string(), + source: Source::from(source.input(&db).to_string()), library: args.library, target_dir, output_file, @@ -80,7 +85,12 @@ pub fn main() -> Result<(), Box> { let object_path = concrete_codegen_mlir::compile(&session, &program)?; if session.library { - link_shared_lib(&object_path, &session.output_file.with_extension("so"))?; + link_shared_lib( + &object_path, + &session + .output_file + .with_extension(Session::get_platform_library_ext()), + )?; } else { link_binary(&object_path, &session.output_file.with_extension(""))?; } diff --git a/crates/concrete_driver/tests/common.rs b/crates/concrete_driver/tests/common.rs new file mode 100644 index 0000000..c16c320 --- /dev/null +++ b/crates/concrete_driver/tests/common.rs @@ -0,0 +1,102 @@ +use std::{ + borrow::Cow, + fmt, + path::{Path, PathBuf}, + process::Output, +}; + +use ariadne::Source; +use concrete_codegen_mlir::linker::{link_binary, link_shared_lib}; +use concrete_parser::{error::Diagnostics, ProgramSource}; +use concrete_session::{ + config::{DebugInfo, OptLevel}, + Session, +}; +use tempfile::TempDir; + +#[derive(Debug, Clone)] +struct TestError(Cow<'static, str>); + +impl fmt::Display for TestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for TestError {} + +#[derive(Debug)] +pub struct CompileResult { + pub folder: TempDir, + pub object_file: PathBuf, + pub binary_file: PathBuf, +} + +pub fn compile_program( + source: &str, + name: &str, + library: bool, +) -> Result> { + let db = concrete_driver::db::Database::default(); + let source = ProgramSource::new(&db, source.to_string(), name.to_string()); + tracing::debug!("source code:\n{}", source.input(&db)); + let program = match concrete_parser::parse_ast(&db, source) { + Some(x) => x, + None => { + Diagnostics::dump( + &db, + source, + &concrete_parser::parse_ast::accumulated::( + &db, source, + ), + ); + return Err(Box::new(TestError("error compiling".into()))); + } + }; + + let test_dir = tempfile::tempdir()?; + let test_dir_path = test_dir.path(); + // todo: find a better name, "target" would clash with rust if running in the source tree. + let target_dir = test_dir_path.join("build_artifacts/"); + let output_file = target_dir.join(PathBuf::from(name)); + let output_file = if library { + output_file.with_extension(Session::get_platform_library_ext()) + } else { + output_file.with_extension("") + }; + + let session = Session { + file_path: PathBuf::from(name), + debug_info: DebugInfo::Full, + optlevel: OptLevel::None, + source: Source::from(source.input(&db).to_string()), + library, + target_dir, + output_file, + }; + + let object_path = concrete_codegen_mlir::compile(&session, &program)?; + + if library { + link_shared_lib( + &object_path, + &session + .output_file + .with_extension(Session::get_platform_library_ext()), + )?; + } else { + link_binary(&object_path, &session.output_file.with_extension(""))?; + } + + Ok(CompileResult { + folder: test_dir, + object_file: object_path, + binary_file: session.output_file, + }) +} + +pub fn run_program(program: &Path) -> Result { + std::process::Command::new(program) + .spawn()? + .wait_with_output() +} diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs new file mode 100644 index 0000000..e10a08a --- /dev/null +++ b/crates/concrete_driver/tests/programs.rs @@ -0,0 +1,107 @@ +use common::{compile_program, run_program}; + +mod common; + +#[test] +fn test_while() { + let source = r#" + mod Simple { + fn main() -> i64 { + return my_func(4); + } + + fn my_func(times: i64) -> i64 { + let mut n: i64 = times; + let mut result: i64 = 1; + + while n > 0 { + result = result + result; + n = n - 1; + } + + return result; + } + } + "#; + + let result = compile_program(source, "while", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 16); +} + +#[test] +fn test_factorial_with_if() { + let source = r#" + mod Simple { + fn main() -> i64 { + return factorial(4); + } + + fn factorial(n: i64) -> i64 { + if n == 0 { + return 1; + } else { + return n * factorial(n - 1); + } + } + } + "#; + + let result = compile_program(source, "factorial", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 24); +} + +#[test] +fn test_fib_with_if() { + let source = r#" + mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } + } + "#; + + let result = compile_program(source, "fib", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 55); +} + +#[test] +fn test_simple_add() { + let source = r#" + mod Simple { + fn main() -> i32 { + let x: i32 = 2; + let y: i32 = 4; + return add_plus_two(x, y); + } + + fn add_plus_two(x: i32, y: i32) -> i32 { + let mut z: i32 = 1; + z = z + 1; + return x + y + z; + } + } + "#; + + let result = compile_program(source, "simple_add", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 8); +} diff --git a/crates/concrete_parser/Cargo.toml b/crates/concrete_parser/Cargo.toml index 7f03898..239eba6 100644 --- a/crates/concrete_parser/Cargo.toml +++ b/crates/concrete_parser/Cargo.toml @@ -10,8 +10,9 @@ lalrpop-util = { version = "0.20.0", features = ["unicode"] } logos = "0.13.0" tracing = { workspace = true } concrete_ast = { path = "../concrete_ast"} -owo-colors = "4.0.0" salsa = { git = "https://github.com/salsa-rs/salsa.git", package = "salsa-2022" } +ariadne = { version = "0.4.0", features = ["auto-color"] } +itertools = "0.12.0" [build-dependencies] lalrpop = "0.20.0" diff --git a/crates/concrete_parser/src/error.rs b/crates/concrete_parser/src/error.rs index 81c1e95..2d2665e 100644 --- a/crates/concrete_parser/src/error.rs +++ b/crates/concrete_parser/src/error.rs @@ -1,6 +1,12 @@ -use crate::{db::Db, lexer::LexicalError, tokens::Token, ProgramSource}; +use crate::{ + db::Db, + lexer::LexicalError, + tokens::{self, Token}, + ProgramSource, +}; +use ariadne::{ColorGenerator, Label, Report, ReportKind, Source}; +use itertools::Itertools; use lalrpop_util::ParseError; -use owo_colors::OwoColorize; pub type Error = ParseError; @@ -9,36 +15,103 @@ pub struct Diagnostics(Error); impl Diagnostics { pub fn dump(db: &dyn Db, source: ProgramSource, errors: &[Error]) { + let path = source.path(db); let source = source.input(db); - for err in errors { - match &err { - ParseError::InvalidToken { .. } => todo!(), + for error in errors { + let mut colors = ColorGenerator::new(); + let report = match error { + ParseError::InvalidToken { location } => { + let loc = *location; + Report::build(ReportKind::Error, path, loc) + .with_code("P1") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_color(colors.next()) + .with_message("invalid token"), + ) + .with_label( + Label::new((path, (loc.saturating_sub(10))..(loc + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() + } ParseError::UnrecognizedEof { location, expected } => { - let location = *location; - let before = &source[0..location]; - let after = &source[location..]; - - print!("{}", before); - print!("$Got EOF, expected {:?}$", expected.green().bold()); - print!("{}", after); + let loc = *location; + Report::build(ReportKind::Error, path, loc) + .with_code("P2") + .with_label( + Label::new((path, loc..(loc + 1))) + .with_message("unrecognized eof") + .with_color(colors.next()), + ) + .with_note(format!( + "expected one of the following: {}", + expected.iter().join(", ") + )) + .with_label( + Label::new((path, (loc.saturating_sub(10))..(loc + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() } ParseError::UnrecognizedToken { token, expected } => { - let (l, ref tok, r) = *token; - let before = &source[0..l]; - let after = &source[r..]; - - print!("{}", before); - print!( - "$Got {:?}, expected {:?}$", - tok.bold().red(), - expected.green().bold() - ); - print!("{}", after); + Report::build(ReportKind::Error, path, token.0) + .with_code(3) + .with_label( + Label::new((path, token.0..token.2)) + .with_message(format!("unrecognized token {:?}", token.1)) + .with_color(colors.next()), + ) + .with_note(format!( + "expected one of the following: {}", + expected.iter().join(", ") + )) + .with_label( + Label::new((path, (token.0.saturating_sub(10))..(token.2 + 10))) + .with_message("There was a problem parsing part of this code."), + ) + .finish() } - ParseError::ExtraToken { .. } => todo!(), - ParseError::User { .. } => todo!(), - } + ParseError::ExtraToken { token } => Report::build(ReportKind::Error, path, token.0) + .with_code("P3") + .with_message("Extra token") + .with_label( + Label::new((path, token.0..token.2)) + .with_message(format!("unexpected extra token {:?}", token.1)), + ) + .finish(), + ParseError::User { error } => match error { + LexicalError::InvalidToken(err, range) => match err { + tokens::LexingError::NumberParseError => { + Report::build(ReportKind::Error, path, range.start) + .with_code(4) + .with_message("Error parsing literal number") + .with_label( + Label::new((path, range.start..range.end)) + .with_message("error parsing literal number") + .with_color(colors.next()), + ) + .finish() + } + tokens::LexingError::Other => { + Report::build(ReportKind::Error, path, range.start) + .with_code(4) + .with_message("Other error") + .with_label( + Label::new((path, range.start..range.end)) + .with_message("other error") + .with_color(colors.next()), + ) + .finish() + } + }, + }, + }; + + report + .eprint((path, Source::from(source))) + .expect("failed to print to stderr"); } } } diff --git a/crates/concrete_parser/src/lib.rs b/crates/concrete_parser/src/lib.rs index 4c5875a..d0b4e1b 100644 --- a/crates/concrete_parser/src/lib.rs +++ b/crates/concrete_parser/src/lib.rs @@ -19,6 +19,8 @@ pub mod grammar { pub struct ProgramSource { #[return_ref] pub input: String, + #[return_ref] + pub path: String, } // Todo: better error handling diff --git a/crates/concrete_session/Cargo.toml b/crates/concrete_session/Cargo.toml index 80ee2dd..f69684e 100644 --- a/crates/concrete_session/Cargo.toml +++ b/crates/concrete_session/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ariadne = "0.4.0" diff --git a/crates/concrete_session/src/lib.rs b/crates/concrete_session/src/lib.rs index 975a392..9c715bf 100644 --- a/crates/concrete_session/src/lib.rs +++ b/crates/concrete_session/src/lib.rs @@ -1,3 +1,4 @@ +use ariadne::Source; use std::path::PathBuf; use config::{DebugInfo, OptLevel}; @@ -9,7 +10,7 @@ pub struct Session { pub file_path: PathBuf, pub debug_info: DebugInfo, pub optlevel: OptLevel, - pub source: String, // for debugging locations + pub source: Source, // for debugging locations /// True if it should be compiled as a library false for binary. pub library: bool, /// The directory where to store artifacts and intermediate files such as object files. @@ -19,10 +20,13 @@ pub struct Session { } impl Session { - pub fn get_line_and_column(&self, offset: usize) -> (usize, usize) { - let sl = &self.source[0..offset]; - let line_count = sl.lines().count(); - let column = sl.rfind('\n').unwrap_or(0); - (line_count, column) + pub fn get_platform_library_ext() -> &'static str { + if cfg!(target_os = "macos") { + "dylib" + } else if cfg!(target_os = "windows") { + "dll" + } else { + "so" + } } } diff --git a/examples/fib.con b/examples/factorial.con similarity index 100% rename from examples/fib.con rename to examples/factorial.con diff --git a/examples/factorial_if.con b/examples/factorial_if.con new file mode 100644 index 0000000..cf441cf --- /dev/null +++ b/examples/factorial_if.con @@ -0,0 +1,13 @@ +mod Simple { + fn main() -> i64 { + return factorial(4); + } + + fn factorial(n: i64) -> i64 { + if n == 0 { + return 1; + } else { + return n * factorial(n - 1); + } + } +} diff --git a/examples/fib_if.con b/examples/fib_if.con new file mode 100644 index 0000000..4863382 --- /dev/null +++ b/examples/fib_if.con @@ -0,0 +1,13 @@ +mod Fibonacci { + fn main() -> i64 { + return fib(10); + } + + pub fn fib(n: u64) -> u64 { + if n < 2 { + return n; + } + + return fib(n - 1) + fib(n - 2); + } +} diff --git a/examples/while.con b/examples/while.con new file mode 100644 index 0000000..6ab6742 --- /dev/null +++ b/examples/while.con @@ -0,0 +1,17 @@ +mod Simple { + fn main() -> i64 { + return my_func(4); + } + + fn my_func(times: i64) -> i64 { + let mut n: i64 = times; + let mut result: i64 = 1; + + while n > 0 { + result = result + result; + n = n - 1; + } + + return result; + } +}