diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index 7ce2361..12f8752 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -1,4 +1,8 @@ -use crate::{common::Ident, statements::Statement, types::TypeSpec}; +use crate::{ + common::Ident, + statements::Statement, + types::{RefType, TypeSpec}, +}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum Expression { @@ -19,6 +23,7 @@ pub enum ValueExpr { ConstStr(String), Path(PathOp), Deref(PathOp), + AsRef { path: PathOp, ref_type: RefType }, } #[derive(Clone, Copy, Debug, Eq, PartialEq)] diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 26ff84f..b4bf8bd 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -1,4 +1,4 @@ -use std::{char::MAX, collections::HashMap, error::Error}; +use std::{collections::HashMap, error::Error}; use bumpalo::Bump; use concrete_ast::{ @@ -16,19 +16,13 @@ use concrete_session::Session; use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, func, - llvm::{self, r#type::opaque_pointer, LoadStoreOptions}, - memref, + cf, func, memref, }, ir::{ - attribute::{ - DenseI64ArrayAttribute, FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, - TypeAttribute, - }, - operation::OperationBuilder, + attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute}, r#type::{FunctionType, IntegerType, MemRefType}, - Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, TypeLike, - Value, ValueLike, + Block, BlockRef, Location, Module as MeliorModule, Operation, Region, Type, Value, + ValueLike, }, Context as MeliorContext, }; @@ -172,65 +166,27 @@ impl<'ctx, 'parent> ScopeContext<'ctx, 'parent> { Ok(match spec { TypeSpec::Simple { name, .. } => self.resolve_type(context, &name.name)?, TypeSpec::Generic { name, .. } => self.resolve_type(context, &name.name)?, - TypeSpec::Array { - of_type, - span: _, - is_ref: _, - size, - } => match size { - Some(size) => { - let inner_type = self.resolve_type_spec(context, of_type)?; - MemRefType::new(inner_type, &[*size], None, None).into() - /* - llvm::r#type::array( - self.resolve_type_spec(context, &of_type)?, - (*size).try_into().expect("size was above u32"), - ) - */ - } - None => { - // - let inner_type = self.resolve_type_spec(context, of_type)?; - // Type::parse(context, &format!("memref")).unwrap() - - llvm::r#type::r#struct( - context, - &[ - Type::parse(context, &format!("memref")).unwrap(), - IntegerType::new(context, 64).into(), - ], - false, - ) - } - }, + TypeSpec::Array { .. } => { + todo!("implement arrays") + } }) } fn is_type_signed(&self, type_info: &TypeSpec) -> bool { let signed = ["i8", "i16", "i32", "i64", "i128"]; match type_info { - TypeSpec::Simple { name, is_ref, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, is_ref, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { - of_type, - span, - is_ref, - size, - } => unreachable!(), + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), } } fn is_float(&self, type_info: &TypeSpec) -> bool { let signed = ["f32", "f64"]; match type_info { - TypeSpec::Simple { name, is_ref, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Generic { name, is_ref, .. } => signed.contains(&name.name.as_str()), - TypeSpec::Array { - of_type, - span, - is_ref, - size, - } => unreachable!(), + TypeSpec::Simple { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Generic { name, .. } => signed.contains(&name.name.as_str()), + TypeSpec::Array { .. } => unreachable!(), } } } @@ -565,17 +521,6 @@ fn compile_while<'c, 'this: 'c>( Ok(merge_block) } -fn is_local_copy(a: &Expression) -> Option<(&PathOp, bool)> { - match a { - Expression::Value(value) => match value { - ValueExpr::Path(path) => Some((path, false)), - ValueExpr::Deref(path) => Some((path, true)), - _ => None, - }, - _ => None, - } -} - fn compile_let_stmt<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, @@ -610,16 +555,8 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( )) .result(0)? .into(); - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(0, Type::index(context)).into(), - location, - )) - .result(0)? - .into(); - block.append_operation(memref::store(value, alloca, &[k0], location)); + block.append_operation(memref::store(value, alloca, &[], location)); scope_ctx .locals @@ -907,6 +844,7 @@ fn compile_value_expr<'ctx, 'parent: 'ctx>( value: &ValueExpr, type_info: Option<&TypeSpec>, ) -> Result, Box> { + tracing::debug!("compiling value_expr for {:?}", value); let location = Location::unknown(context); match value { ValueExpr::ConstBool(value) => { @@ -944,6 +882,9 @@ fn compile_value_expr<'ctx, 'parent: 'ctx>( ValueExpr::Deref(value) => { compile_deref(session, context, scope_ctx, _helper, block, value) } + ValueExpr::AsRef { path, ref_type: _ } => { + compile_asref(session, context, scope_ctx, _helper, block, path) + } } } @@ -955,6 +896,7 @@ fn compile_fn_call<'ctx, 'parent: 'ctx>( block: &'parent Block<'ctx>, info: &FnCallOp, ) -> Result, Box> { + tracing::debug!("compiling fncall: {:?}", info); let mut args = Vec::with_capacity(info.args.len()); let location = get_location(context, session, info.target.span.from); @@ -1006,10 +948,11 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, scope_ctx: &mut ScopeContext<'ctx, 'parent>, - helper: &BlockHelper<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, block: &'parent Block<'ctx>, path: &PathOp, ) -> Result, Box> { + tracing::debug!("compiling pathop {:?}", path); // For now only simple and array variables work. // TODO: implement properly, this requires having structs implemented. @@ -1021,7 +964,7 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( let location = get_location(context, session, path.first.span.from); - let mut value = if local.alloca { + let value = if local.alloca { block .append_operation(memref::load(local.value, &[], location)) .result(0)? @@ -1030,46 +973,6 @@ fn compile_path_op<'ctx, 'parent: 'ctx>( local.value }; - for segment in &path.extra { - match segment { - PathSegment::FieldAccess(_) => todo!(), - PathSegment::ArrayIndex(index) => { - let index = - compile_value_expr(session, context, scope_ctx, helper, block, index, None)?; - let index_ty = Type::index(context); - let index = block - .append_operation(melior::dialect::index::castu(index, index_ty, location)) - .result(0)? - .into(); - - if let TypeSpec::Array { - of_type: _, - size, - is_ref: _, - span, - } = &local.type_spec - { - let location = get_location(context, session, span.from); - #[allow(clippy::if_same_then_else)] - if size.is_some() { - // todo: check inbounds? - value = block - .append_operation(memref::load(value, &[index], location)) - .result(0)? - .into(); - } else { - value = block - .append_operation(memref::load(value, &[index], location)) - .result(0)? - .into(); - } - } else { - panic!("type should be a array when indexing a value"); - } - } - } - } - Ok(value) } @@ -1077,10 +980,11 @@ fn compile_deref<'ctx, 'parent: 'ctx>( session: &Session, context: &'ctx MeliorContext, scope_ctx: &mut ScopeContext<'ctx, 'parent>, - helper: &BlockHelper<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, block: &'parent Block<'ctx>, path: &PathOp, ) -> Result, Box> { + tracing::debug!("compiling deref for {:?}", path); let local = scope_ctx .locals .get(&path.first.name) @@ -1088,51 +992,40 @@ fn compile_deref<'ctx, 'parent: 'ctx>( .clone(); let location = get_location(context, session, path.first.span.from); - let inner_type = scope_ctx.resolve_type_spec_ref(context, &local.type_spec)?; let mut value = block - .append_operation(memref::load(local.value, &[], location)).result(0)?.into(); - - - for segment in &path.extra { - match segment { - PathSegment::FieldAccess(_) => todo!(), - PathSegment::ArrayIndex(index) => { - let index = - compile_value_expr(session, context, scope_ctx, helper, block, index, None)?; - let index_ty = Type::index(context); - let index = block - .append_operation(melior::dialect::index::castu(index, index_ty, location)) - .result(0)? - .into(); - - if let TypeSpec::Array { - of_type: _, - size, - is_ref: _, - span, - } = &local.type_spec - { - let location = get_location(context, session, span.from); - #[allow(clippy::if_same_then_else)] - if size.is_some() { - // todo: check inbounds? - value = block - .append_operation(memref::load(value, &[index], location)) - .result(0)? - .into(); - } else { - value = block - .append_operation(memref::load(value, &[index], location)) - .result(0)? - .into(); - } - } else { - panic!("type should be a array when indexing a value"); - } - } - } + .append_operation(memref::load(local.value, &[], location)) + .result(0)? + .into(); + + if local.alloca { + value = block + .append_operation(memref::load(value, &[], location)) + .result(0)? + .into(); } Ok(value) } + +fn compile_asref<'ctx, 'parent: 'ctx>( + _session: &Session, + _context: &'ctx MeliorContext, + scope_ctx: &mut ScopeContext<'ctx, 'parent>, + _helper: &BlockHelper<'ctx, 'parent>, + _block: &'parent Block<'ctx>, + path: &PathOp, +) -> Result, Box> { + tracing::debug!("compiling asref for {:?}", path); + let local = scope_ctx + .locals + .get(&path.first.name) + .expect("local not found") + .clone(); + + if !local.alloca { + panic!("can only take refs to non register values"); + } + + Ok(local.value) +} diff --git a/crates/concrete_driver/tests/programs.rs b/crates/concrete_driver/tests/programs.rs index 14cee5e..fecc2f3 100644 --- a/crates/concrete_driver/tests/programs.rs +++ b/crates/concrete_driver/tests/programs.rs @@ -130,3 +130,31 @@ fn test_import() { let code = output.status.code().unwrap(); assert_eq!(code, 8); } + +#[test] +fn test_reference() { + let source = r#" + mod Simple { + fn main(argc: i64) -> i64 { + let x: i64 = argc; + return references(x) + dereference(&x); + } + + fn dereference(a: &i64) -> i64 { + return *a; + } + + fn references(a: i64) -> i64 { + let x: i64 = a; + let y: &i64 = &x; + return *y; + } + } + "#; + + let result = compile_program(source, "references", false).expect("failed to compile"); + + let output = run_program(&result.binary_file).expect("failed to run"); + let code = output.status.code().unwrap(); + assert_eq!(code, 2); +} diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index 2aac08e..7bcda21 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -359,6 +359,10 @@ pub(crate) ValueExpr: ast::expressions::ValueExpr = { <"string"> => ast::expressions::ValueExpr::ConstStr(<>), => ast::expressions::ValueExpr::Path(<>), "*" => ast::expressions::ValueExpr::Deref(<>), + => ast::expressions::ValueExpr::AsRef { + path, + ref_type + }, } pub(crate) IfExpr: ast::expressions::IfExpr = { diff --git a/examples/borrow.con b/examples/borrow.con index b6dfca8..1f99c91 100644 --- a/examples/borrow.con +++ b/examples/borrow.con @@ -1,9 +1,16 @@ mod Simple { fn main(argc: i64) -> i64 { - return argc; + let x: i64 = argc; + return references(x) + dereference(&x); } fn dereference(a: &i64) -> i64 { return *a; } + + fn references(a: i64) -> i64 { + let x: i64 = a; + let y: &i64 = &x; + return *y; + } }