diff --git a/Cargo.lock b/Cargo.lock index 4523aa7..25c2bd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -978,9 +978,9 @@ dependencies = [ [[package]] name = "melior" -version = "0.15.1" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758bbd4448db9e994578ab48a6da5210512378f70ac1632cc8c2ae0fbd6c21b5" +checksum = "878012ddccd6fdd099a4d98cebdecbaed9bc5eb325d0778ab9d4f4a52c67c18e" dependencies = [ "dashmap", "melior-macro", diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index 0abd876..12f8752 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -1,8 +1,12 @@ -use crate::{common::Ident, statements::Statement, types::TypeSpec}; +use crate::{ + common::Ident, + statements::Statement, + types::{RefType, TypeSpec}, +}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum Expression { - Simple(SimpleExpr), + Value(ValueExpr), FnCall(FnCallOp), Match(MatchExpr), If(IfExpr), @@ -10,15 +14,16 @@ pub enum Expression { BinaryOp(Box, BinaryOp, Box), } -// needed for match variants and array accesses #[derive(Clone, Debug, Eq, PartialEq)] -pub enum SimpleExpr { +pub enum ValueExpr { ConstBool(bool), ConstChar(char), ConstInt(u64), ConstFloat(()), ConstStr(String), Path(PathOp), + Deref(PathOp), + AsRef { path: PathOp, ref_type: RefType }, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] @@ -83,14 +88,14 @@ pub struct IfExpr { #[derive(Clone, Debug, Eq, PartialEq)] pub struct MatchVariant { - pub case: SimpleExpr, + pub case: ValueExpr, pub block: Vec, } #[derive(Clone, Debug, Eq, PartialEq)] pub enum PathSegment { FieldAccess(Ident), - ArrayIndex(SimpleExpr), + ArrayIndex(ValueExpr), } #[derive(Clone, Debug, Eq, PartialEq)] diff --git a/crates/concrete_ast/src/types.rs b/crates/concrete_ast/src/types.rs index 450f8ee..2ae7c87 100644 --- a/crates/concrete_ast/src/types.rs +++ b/crates/concrete_ast/src/types.rs @@ -1,15 +1,40 @@ use crate::common::{DocString, Ident, Span}; +#[derive(Clone, Debug, Eq, Hash, PartialEq, Copy)] +pub enum RefType { + Borrow, + MutBorrow, +} + #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum TypeSpec { Simple { name: Ident, + is_ref: Option, + span: Span, }, Generic { name: Ident, + is_ref: Option, type_params: Vec, span: Span, }, + Array { + of_type: Box, + size: Option, + is_ref: Option, + span: Span, + }, +} + +impl TypeSpec { + pub fn is_ref(&self) -> Option { + match self { + TypeSpec::Simple { is_ref, .. } => *is_ref, + TypeSpec::Generic { is_ref, .. } => *is_ref, + TypeSpec::Array { is_ref, .. } => *is_ref, + } + } } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_codegen_mlir/Cargo.toml b/crates/concrete_codegen_mlir/Cargo.toml index 19ce13e..9df9c6f 100644 --- a/crates/concrete_codegen_mlir/Cargo.toml +++ b/crates/concrete_codegen_mlir/Cargo.toml @@ -11,7 +11,7 @@ concrete_ast = { path = "../concrete_ast"} concrete_session = { path = "../concrete_session"} itertools = "0.12.0" llvm-sys = "170.0.1" -melior = { version = "0.15.0", features = ["ods-dialects"] } +melior = { version = "0.15.2", features = ["ods-dialects"] } mlir-sys = "0.2.1" tracing.workspace = true diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index a4f10bd..0db203f 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ expressions::{ - ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, SimpleExpr, + ArithOp, BinaryOp, CmpOp, Expression, FnCallOp, IfExpr, LogicOp, PathOp, ValueExpr, }, functions::FunctionDef, modules::{Module, ModuleDefItem}, @@ -144,26 +144,48 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { &self, context: &'ctx MeliorContext, spec: &TypeSpec, + ) -> Result, Box> { + match spec.is_ref() { + Some(_) => { + Ok( + MemRefType::new(self.resolve_type_spec_ref(context, spec)?, &[], None, None) + .into(), + ) + } + None => self.resolve_type_spec_ref(context, spec), + } + } + + /// Resolves the type this ref points to. + fn resolve_type_spec_ref( + &self, + context: &'ctx MeliorContext, + spec: &TypeSpec, ) -> Result, Box> { Ok(match spec { - TypeSpec::Simple { name } => self.resolve_type(context, &name.name)?, + TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, + TypeSpec::Array { .. } => { + todo!("implement arrays") + } }) } fn is_type_signed(&self, type_info: &TypeSpec) -> bool { let signed = ["i8", "i16", "i32", "i64", "i128"]; match type_info { - TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), } } fn is_float(&self, type_info: &TypeSpec) -> bool { let signed = ["f32", "f64"]; match type_info { - TypeSpec::Simple { name } => signed.contains(&name.name.as_str()), + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), } } } @@ -508,6 +530,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( ) -> Result<(), Box> { match &info.target { LetStmtTarget::Simple { name, r#type } => { + let location = get_location(context, session, name.span.from); + let value = compile_expression( session, context, @@ -517,10 +541,7 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( &info.value, Some(r#type), )?; - - let location = get_location(context, session, name.span.from); - - let memref_type = MemRefType::new(value.r#type(), &[1], None, None); + let memref_type = MemRefType::new(value.r#type(), &[], None, None); let alloca: Value = block .append_operation(memref::alloca( @@ -533,15 +554,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( )) .result(0)? .into(); - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) - .result(0)? - .into(); - block.append_operation(memref::store(value, alloca, &[k0], location)); + + block.append_operation(memref::store(value, alloca, &[], location)); scope_ctx .locals @@ -583,15 +597,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( Some(&local.type_spec), )?; - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) - .result(0)? - .into(); - block.append_operation(memref::store(value, local.value, &[k0], location)); + block.append_operation(memref::store(value, local.value, &[], location)); Ok(()) } @@ -635,39 +641,9 @@ fn compile_expression<'ctx, 'parent: 'ctx>( ) -> Result, Box> { let location = Location::unknown(context); match info { - Expression::Simple(simple) => match simple { - SimpleExpr::ConstBool(value) => { - let value = - IntegerAttribute::new((*value).into(), IntegerType::new(context, 1).into()); - Ok(block - .append_operation(arith::constant(context, value.into(), location)) - .result(0)? - .into()) - } - SimpleExpr::ConstChar(value) => { - let value = - IntegerAttribute::new((*value) as i64, IntegerType::new(context, 32).into()); - Ok(block - .append_operation(arith::constant(context, value.into(), location)) - .result(0)? - .into()) - } - SimpleExpr::ConstInt(value) => { - let int_type = if let Some(type_info) = type_info { - scope_ctx.resolve_type_spec(context, type_info)? - } else { - IntegerType::new(context, 64).into() - }; - let value = IntegerAttribute::new((*value) as i64, int_type); - Ok(block - .append_operation(arith::constant(context, value.into(), location)) - .result(0)? - .into()) - } - SimpleExpr::ConstFloat(_) => todo!(), - SimpleExpr::ConstStr(_) => todo!(), - SimpleExpr::Path(value) => compile_path_op(session, context, scope_ctx, block, value), - }, + Expression::Value(value) => compile_value_expr( + session, context, scope_ctx, _helper, block, value, type_info, + ), Expression::FnCall(value) => { compile_fn_call(session, context, scope_ctx, _helper, block, value) } @@ -798,6 +774,59 @@ fn compile_expression<'ctx, 'parent: 'ctx>( } } +fn compile_value_expr<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + value: &ValueExpr, + type_info: Option<&TypeSpec>, +) -> Result, Box> { + tracing::debug!("compiling value_expr for {:?}", value); + let location = Location::unknown(context); + match value { + ValueExpr::ConstBool(value) => { + let value = IntegerAttribute::new((*value).into(), IntegerType::new(context, 1).into()); + Ok(block + .append_operation(arith::constant(context, value.into(), location)) + .result(0)? + .into()) + } + ValueExpr::ConstChar(value) => { + let value = + IntegerAttribute::new((*value) as i64, IntegerType::new(context, 32).into()); + Ok(block + .append_operation(arith::constant(context, value.into(), location)) + .result(0)? + .into()) + } + ValueExpr::ConstInt(value) => { + let int_type = if let Some(type_info) = type_info { + scope_ctx.resolve_type_spec(context, type_info)? + } else { + IntegerType::new(context, 64).into() + }; + let value = IntegerAttribute::new((*value) as i64, int_type); + Ok(block + .append_operation(arith::constant(context, value.into(), location)) + .result(0)? + .into()) + } + ValueExpr::ConstFloat(_) => todo!(), + ValueExpr::ConstStr(_) => todo!(), + ValueExpr::Path(value) => { + compile_path_op(session, context, scope_ctx, _helper, block, value) + } + ValueExpr::Deref(value) => { + compile_deref(session, context, scope_ctx, _helper, block, value) + } + ValueExpr::AsRef { path, ref_type: _ } => { + compile_asref(session, context, scope_ctx, _helper, block, path) + } + } +} + fn compile_fn_call<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, @@ -806,6 +835,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( block: &'parent Block<'ctx>, info: &FnCallOp, ) -> Result, Box> { + tracing::debug!("compiling fncall: {:?}", info); let mut args = Vec::with_capacity(info.args.len()); let location = get_location(context, session, info.target.span.from); @@ -857,34 +887,84 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, block: &'parent Block<'ctx>, path: &PathOp, ) -> Result, Box> { - // For now only simple variables work. + tracing::debug!("compiling pathop {:?}", path); + // For now only simple and array variables work. // TODO: implement properly, this requires having structs implemented. let local = scope_ctx .locals .get(&path.first.name) - .expect("local not found"); + .unwrap_or_else(|| panic!("local {} not found", path.first.name)) + .clone(); let location = get_location(context, session, path.first.span.from); - if local.alloca { - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) + let value = if local.alloca { + block + .append_operation(memref::load(local.value, &[], location)) .result(0)? - .into(); - let value = block - .append_operation(memref::load(local.value, &[k0], location)) + .into() + } else { + local.value + }; + + Ok(value) +} + +fn compile_deref<'ctx, 'parent: 'ctx>( + session: &Session, + context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + block: &'parent Block<'ctx>, + path: &PathOp, +) -> Result, Box> { + tracing::debug!("compiling deref for {:?}", path); + let local = scope_ctx + .locals + .get(&path.first.name) + .expect("local not found") + .clone(); + + let location = get_location(context, session, path.first.span.from); + + let mut value = block + .append_operation(memref::load(local.value, &[], location)) + .result(0)? + .into(); + + if local.alloca { + value = block + .append_operation(memref::load(value, &[], location)) .result(0)? .into(); - Ok(value) - } else { - Ok(local.value) } + + Ok(value) +} + +fn compile_asref<'ctx, 'parent: 'ctx>( + _session: &Session, + _context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + _block: &'parent Block<'ctx>, + path: &PathOp, +) -> Result, Box> { + tracing::debug!("compiling asref for {:?}", path); + let local = scope_ctx + .locals + .get(&path.first.name) + .expect("local not found") + .clone(); + + if !local.alloca { + panic!("can only take refs to non register values"); + } + + Ok(local.value) } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 8c4735d..e293af0 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -23,6 +23,10 @@ pub struct CompilerArgs { /// Build as a library. #[arg(short, long, default_value_t = false)] library: bool, + + /// Prints the ast. + #[arg(long, default_value_t = false)] + print_ast: bool, } pub fn main() -> Result<(), Box> { @@ -53,6 +57,10 @@ pub fn main() -> Result<(), Box> { } }; + if args.print_ast { + println!("{:#?}", program); + } + let cwd = std::env::current_dir()?; // todo: find a better name, "target" would clash with rust if running in the source tree. let target_dir = cwd.join("build_artifacts/"); diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index 14cee5e..fecc2f3 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -130,3 +130,31 @@ fn test_import() { let code = output.status.code().unwrap(); assert_eq!(code, 8); } + +#[test] +fn test_reference() { + let source = r#" + mod Simple { + fn main(argc: i64) -> i64 { + let x: i64 = argc; + return references(x) + dereference(&x); + } + + fn dereference(a: &i64) -> i64 { + return *a; + } + + fn references(a: i64) -> i64 { + let x: i64 = a; + let y: &i64 = &x; + return *y; + } + } + "#; + + let result = compile_program(source, "references", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 2); +} diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index d8b1cbe..7bcda21 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -66,7 +66,7 @@ extern { "!" => Token::OperatorNot, "~" => Token::OperatorBitwiseNot, "^" => Token::OperatorBitwiseXor, - "&" => Token::OperatorBitwiseAnd, + "&" => Token::Ampersand, "|" => Token::OperatorBitwiseOr, } } @@ -131,13 +131,27 @@ pub(crate) Ident: ast::common::Ident = { } } +pub(crate) RefType: ast::types::RefType = { + "&" => ast::types::RefType::Borrow, + "&" "mut" => ast::types::RefType::MutBorrow, +} + pub(crate) TypeSpec: ast::types::TypeSpec = { - => ast::types::TypeSpec::Simple { - name + => ast::types::TypeSpec::Simple { + name, + is_ref, + span: Span::new(lo, hi), }, - "<" > ">" => ast::types::TypeSpec::Generic { + "<" > ">" => ast::types::TypeSpec::Generic { name, type_params, + is_ref, + span: Span::new(lo, hi), + }, + "[" )?> "]" => ast::types::TypeSpec::Array { + of_type: Box::new(of_type), + size, + is_ref, span: Span::new(lo, hi), } } @@ -272,7 +286,7 @@ pub(crate) FunctionDef: ast::functions::FunctionDef = { // Expressions pub(crate) Term: ast::expressions::Expression = { - => ast::expressions::Expression::Simple(<>), + => ast::expressions::Expression::Value(<>), => ast::expressions::Expression::FnCall(<>), => ast::expressions::Expression::Match(<>), => ast::expressions::Expression::If(<>), @@ -339,14 +353,18 @@ pub UnaryOp: ast::expressions::UnaryOp = { "~" => ast::expressions::UnaryOp::BitwiseNot, } -pub(crate) SimpleExpr: ast::expressions::SimpleExpr = { - <"integer"> => ast::expressions::SimpleExpr::ConstInt(<>), - <"boolean"> => ast::expressions::SimpleExpr::ConstBool(<>), - <"string"> => ast::expressions::SimpleExpr::ConstStr(<>), - => ast::expressions::SimpleExpr::Path(<>), +pub(crate) ValueExpr: ast::expressions::ValueExpr = { + <"integer"> => ast::expressions::ValueExpr::ConstInt(<>), + <"boolean"> => ast::expressions::ValueExpr::ConstBool(<>), + <"string"> => ast::expressions::ValueExpr::ConstStr(<>), + => ast::expressions::ValueExpr::Path(<>), + "*" => ast::expressions::ValueExpr::Deref(<>), + => ast::expressions::ValueExpr::AsRef { + path, + ref_type + }, } - pub(crate) IfExpr: ast::expressions::IfExpr = { "if" "{" "}" "}")?> => { @@ -369,14 +387,14 @@ pub(crate) MatchExpr: ast::expressions::MatchExpr = { pub(crate) MatchVariant: ast::expressions::MatchVariant = { // 0 -> 1 - "->" => { + "->" => { ast::expressions::MatchVariant { case, block: vec![stmt] } }, // x -> { ... } - "->" "{" "}" => { + "->" "{" "}" => { ast::expressions::MatchVariant { case, block: stmts @@ -393,7 +411,7 @@ pub(crate) PathOp: ast::expressions::PathOp = { pub(crate) PathSegment: ast::expressions::PathSegment = { "." => ast::expressions::PathSegment::FieldAccess(<>), - "[" "]" => ast::expressions::PathSegment::ArrayIndex(e), + "[" "]" => ast::expressions::PathSegment::ArrayIndex(e), } pub PathSegments: Vec = { diff --git a/crates/concrete_parser/src/tokens.rs b/crates/concrete_parser/src/tokens.rs index 9e06b3f..316a74d 100644 --- a/crates/concrete_parser/src/tokens.rs +++ b/crates/concrete_parser/src/tokens.rs @@ -122,7 +122,7 @@ pub enum Token { #[token("^")] OperatorBitwiseXor, #[token("&")] - OperatorBitwiseAnd, + Ampersand, #[token("|")] OperatorBitwiseOr, } diff --git a/examples/borrow.con b/examples/borrow.con new file mode 100644 index 0000000..1f99c91 --- /dev/null +++ b/examples/borrow.con @@ -0,0 +1,16 @@ +mod Simple { + fn main(argc: i64) -> i64 { + let x: i64 = argc; + return references(x) + dereference(&x); + } + + fn dereference(a: &i64) -> i64 { + return *a; + } + + fn references(a: i64) -> i64 { + let x: i64 = a; + let y: &i64 = &x; + return *y; + } +}