diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index f316fc5..4c0a3cd 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -56,6 +56,7 @@ pub struct ReturnStmt { pub struct AssignStmt { pub target: PathOp, pub value: Expression, + pub is_deref: bool, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] diff --git a/crates/concrete_codegen_mlir/src/codegen.rs b/crates/concrete_codegen_mlir/src/codegen.rs index 2d20d5d..90b2151 100644 --- a/crates/concrete_codegen_mlir/src/codegen.rs +++ b/crates/concrete_codegen_mlir/src/codegen.rs @@ -9,7 +9,7 @@ use concrete_ast::{ functions::FunctionDef, modules::{Module, ModuleDefItem}, statements::{AssignStmt, LetStmt, LetStmtTarget, LetValue, ReturnStmt, Statement, WhileStmt}, - types::TypeSpec, + types::{RefType, TypeSpec}, Program, }; use concrete_session::Session; @@ -60,22 +60,25 @@ pub struct LocalVar<'ctx, 'parent: 'ctx> { // If it's none its on a register, otherwise allocated on the stack. pub alloca: bool, pub value: Value<'ctx, 'parent>, + pub is_mut: bool, } impl<'ctx, 'parent: 'ctx> LocalVar<'ctx, 'parent> { - pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn param(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: false, + is_mut, } } - pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec) -> Self { + pub fn alloca(value: Value<'ctx, 'parent>, type_spec: TypeSpec, is_mut: bool) -> Self { Self { value, type_spec, alloca: true, + is_mut, } } } @@ -221,7 +224,7 @@ fn compile_function_def<'ctx, 'parent: 'ctx>( for (i, param) in info.decl.params.iter().enumerate() { scope_ctx.locals.insert( param.name.name.clone(), - LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone()), + LocalVar::param(fn_block.argument(i)?.into(), param.r#type.clone(), false), ); } @@ -521,9 +524,10 @@ fn compile_let_stmt<'ctx, 'parent: 'ctx>( block.append_operation(memref::store(value, alloca, &[], location)); - scope_ctx - .locals - .insert(name.name.clone(), LocalVar::alloca(alloca, r#type.clone())); + scope_ctx.locals.insert( + name.name.clone(), + LocalVar::alloca(alloca, r#type.clone(), info.is_mutable), + ); Ok(()) } @@ -547,6 +551,7 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( .expect("local should exist") .clone(); + assert!(local.is_mut, "can only mutate mutable variables"); assert!(local.alloca, "can only mutate local stack variables"); let location = get_location(context, session, info.target.first.span.from); @@ -562,12 +567,17 @@ fn compile_assign_stmt<'ctx, 'parent: 'ctx>( Some(&local.type_spec), )?; - block.append_operation(memref::store(value, local.value, &[], location)); + match local.type_spec.is_ref() { + Some(RefType::MutBorrow) => {} + Some(RefType::Borrow) => {} + None => { + block.append_operation(memref::store(value, local.value, &[], location)); + } + } } else { let mut current_type_spec = &local.type_spec; - // todo: instead of loading, use memref.extract_aligned_pointer_as_index - + // get a ptr to the field let target_ptr = block .append_operation( melior::dialect::ods::memref::extract_aligned_pointer_as_index( diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index ae623e4..e8777b9 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -470,7 +470,7 @@ pub(crate) Statement: ast::statements::Statement = { ";" => ast::statements::Statement::Let(<>), ";" => ast::statements::Statement::Assign(<>), ";" => ast::statements::Statement::FnCall(<>), - ";"? => ast::statements::Statement::Return(<>), + ";" => ast::statements::Statement::Return(<>), } pub(crate) LetStmt: ast::statements::LetStmt = { @@ -508,9 +508,10 @@ pub(crate) FieldConstruct: ast::statements::FieldConstruct = { } pub(crate) AssignStmt: ast::statements::AssignStmt = { - "=" => ast::statements::AssignStmt { + "=" => ast::statements::AssignStmt { target, - value + value, + is_deref: is_deref.is_some(), }, } diff --git a/crates/concrete_parser/src/lib.rs b/crates/concrete_parser/src/lib.rs index d0b4e1b..61e1c89 100644 --- a/crates/concrete_parser/src/lib.rs +++ b/crates/concrete_parser/src/lib.rs @@ -59,7 +59,9 @@ mod ModuleName { x = x % 2; match x { - 0 -> return 2, + 0 -> { + return 2; + }, 1 -> { let y: u64 = x * 2; return y * 10; @@ -97,8 +99,12 @@ mod ModuleName { let source = r##"mod FactorialModule { pub fn factorial(x: u64) -> u64 { return match x { - 0 -> return 1, - n -> return n * factorial(n-1), + 0 -> { + return 1; + }, + n -> { + return n * factorial(n-1); + }, }; } }"##;