diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index 818c666..92818e9 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -99,16 +99,16 @@ pub enum BitwiseOp { #[derive(Clone, Debug, Eq, PartialEq)] pub struct MatchExpr { - pub value: Box, + pub expr: Box, pub variants: Vec, pub span: Span, } #[derive(Clone, Debug, Eq, PartialEq)] pub struct IfExpr { - pub value: Box, - pub contents: Vec, - pub r#else: Option>, + pub cond: Box, + pub block_stmts: Vec, + pub else_stmts: Option>, pub span: Span, } diff --git a/crates/concrete_ast/src/statements.rs b/crates/concrete_ast/src/statements.rs index f836281..0d2f348 100644 --- a/crates/concrete_ast/src/statements.rs +++ b/crates/concrete_ast/src/statements.rs @@ -18,7 +18,7 @@ pub enum Statement { #[derive(Clone, Debug, Eq, PartialEq)] pub enum LetStmtTarget { - Simple { name: Ident, r#type: TypeSpec }, + Simple { id: Ident, r#type: TypeSpec }, Destructure(Vec), } @@ -38,15 +38,15 @@ pub struct ReturnStmt { #[derive(Clone, Debug, Eq, PartialEq)] pub struct AssignStmt { - pub target: PathOp, + pub lvalue: PathOp, pub derefs: usize, - pub value: Expression, + pub rvalue: Expression, pub span: Span, } #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct Binding { - pub name: Ident, + pub id: Ident, pub rename: Option, pub r#type: TypeSpec, } @@ -56,12 +56,12 @@ pub struct ForStmt { pub init: Option, pub condition: Option, pub post: Option, - pub contents: Vec, + pub block_stmts: Vec, pub span: Span, } #[derive(Clone, Debug, Eq, PartialEq)] pub struct WhileStmt { - pub value: Expression, - pub contents: Vec, + pub condition: Expression, + pub block_stmts: Vec, } diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 11bf651..fbe228a 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -296,10 +296,10 @@ impl LinearityChecker { } Statement::If(if_stmt) => { // Process all components of an if expression - let cond_apps = self.count_in_expression(name, &if_stmt.value); - let then_apps = self.count_in_statements(name, &if_stmt.contents); + let cond_apps = self.count_in_expression(name, &if_stmt.cond); + let then_apps = self.count_in_statements(name, &if_stmt.block_stmts); let else_apps; - let else_statements = &if_stmt.r#else; + let else_statements = &if_stmt.else_stmts; if let Some(else_statements) = else_statements { else_apps = self.count_in_statements(name, else_statements); } else { @@ -308,8 +308,8 @@ impl LinearityChecker { cond_apps.merge(&then_apps).merge(&else_apps) } Statement::While(while_expr) => { - let cond = &while_expr.value; - let block = &while_expr.contents; + let cond = &while_expr.condition; + let block = &while_expr.block_stmts; // Handle while loops self.count_in_expression(name, cond) .merge(&self.count_in_statements(name, block)) @@ -320,7 +320,7 @@ impl LinearityChecker { let init = &for_expr.init; let cond = &for_expr.condition; let post = &for_expr.post; - let block = &for_expr.contents; + let block = &for_expr.block_stmts; let mut apps = Appearances::zero(); if let Some(init) = init { if let Some(cond) = cond { @@ -364,9 +364,9 @@ impl LinearityChecker { fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { let AssignStmt { - target, + lvalue: target, derefs, - value, + rvalue: value, span, } = assign_stmt; // Handle assignments @@ -434,10 +434,10 @@ impl LinearityChecker { Expression::If(if_expr) => { // Process all components of an if expression // TODO review this code. If expressions should be processed counting both branches and comparing them - let cond_apps = self.count_in_expression(name, &if_expr.value); - let then_apps = self.count_in_statements(name, &if_expr.contents); + let cond_apps = self.count_in_expression(name, &if_expr.cond); + let then_apps = self.count_in_statements(name, &if_expr.block_stmts); cond_apps.merge(&then_apps); - if let Some(else_block) = &if_expr.r#else { + if let Some(else_block) = &if_expr.else_stmts { let else_apps = self.count_in_statements(name, else_block); cond_apps.merge(&then_apps).merge(&else_apps); } @@ -497,7 +497,7 @@ impl LinearityChecker { span, } = binding; match target { - LetStmtTarget::Simple { name, r#type } => { + LetStmtTarget::Simple { id: name, r#type } => { match r#type { TypeSpec::Simple { name: variable_type, @@ -670,9 +670,10 @@ impl LinearityChecker { //Statement::If(cond, then_block, else_block) => { Statement::If(if_stmt) => { // Handle conditional statements - state_tbl = self.check_expr(state_tbl, depth, &if_stmt.value, context)?; - state_tbl = self.check_stmts(state_tbl, depth + 1, &if_stmt.contents, context)?; - if let Some(else_block) = &if_stmt.r#else { + state_tbl = self.check_expr(state_tbl, depth, &if_stmt.cond, context)?; + state_tbl = + self.check_stmts(state_tbl, depth + 1, &if_stmt.block_stmts, context)?; + if let Some(else_block) = &if_stmt.else_stmts { state_tbl = self.check_stmts(state_tbl, depth + 1, else_block, context)?; } Ok(state_tbl) @@ -680,9 +681,9 @@ impl LinearityChecker { //Statement::While(cond, block) => { Statement::While(while_stmt) => { // Handle while loops - state_tbl = self.check_expr(state_tbl, depth, &while_stmt.value, context)?; + state_tbl = self.check_expr(state_tbl, depth, &while_stmt.condition, context)?; state_tbl = - self.check_stmts(state_tbl, depth + 1, &while_stmt.contents, context)?; + self.check_stmts(state_tbl, depth + 1, &while_stmt.block_stmts, context)?; Ok(state_tbl) } //Statement::For(init, cond, post, block) => { @@ -698,15 +699,16 @@ impl LinearityChecker { //TODO check assign statement //self.check_stmt_assign(depth, post)?; } - state_tbl = self.check_stmts(state_tbl, depth + 1, &for_stmt.contents, context)?; + state_tbl = + self.check_stmts(state_tbl, depth + 1, &for_stmt.block_stmts, context)?; Ok(state_tbl) } Statement::Assign(assign_stmt) => { // Handle assignments let AssignStmt { - target, + lvalue: target, derefs, - value, + rvalue: value, span, } = assign_stmt; tracing::debug!("Checking assignment: {:?}", assign_stmt); diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index 6dabc96..db16ba0 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -373,7 +373,7 @@ fn lower_func( for stmt in &func.body { if let statements::Statement::Let(info) = stmt { match &info.target { - LetStmtTarget::Simple { name, r#type } => { + LetStmtTarget::Simple { id: name, r#type } => { let ty = lower_type(&builder.ctx, r#type, builder.local_module)?; builder .name_to_local @@ -391,7 +391,7 @@ fn lower_func( } else if let statements::Statement::For(info) = stmt { if let Some(info) = &info.init { match &info.target { - LetStmtTarget::Simple { name, r#type } => { + LetStmtTarget::Simple { id: name, r#type } => { let ty = lower_type(&builder.ctx, r#type, builder.local_module)?; builder .name_to_local @@ -520,7 +520,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe }); let (discriminator, discriminator_type, _disc_span) = - lower_expression(builder, &info.value, None)?; + lower_expression(builder, &info.condition, None)?; let local = builder.add_temp_local(TyKind::Bool); let place = Place { @@ -548,7 +548,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe // keep idx for switch targets let first_then_block_idx = builder.body.basic_blocks.len(); - for stmt in &info.contents { + for stmt in &info.block_stmts { lower_statement( builder, stmt, @@ -645,7 +645,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering // keep idx for switch targets let first_then_block_idx = builder.body.basic_blocks.len(); - for stmt in &info.contents { + for stmt in &info.block_stmts { lower_statement( builder, stmt, @@ -687,7 +687,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), LoweringError> { let (discriminator, discriminator_type, _disc_span) = - lower_expression(builder, &info.value, None)?; + lower_expression(builder, &info.cond, None)?; let local = builder.add_temp_local(TyKind::Bool); let place = Place { @@ -715,7 +715,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), // keep idx for switch targets let first_then_block_idx = builder.body.basic_blocks.len(); - for stmt in &info.contents { + for stmt in &info.block_stmts { lower_statement( builder, stmt, @@ -739,7 +739,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), let first_else_block_idx = builder.body.basic_blocks.len(); - if let Some(contents) = &info.r#else { + if let Some(contents) = &info.else_stmts { for stmt in contents { lower_statement( builder, @@ -783,7 +783,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), LoweringError> { match &info.target { - LetStmtTarget::Simple { name, r#type } => { + LetStmtTarget::Simple { id: name, r#type } => { let ty = lower_type(&builder.ctx, r#type, builder.local_module)?; let (rvalue, rvalue_ty, rvalue_span) = lower_expression(builder, &info.value, Some(ty.clone()))?; @@ -819,7 +819,7 @@ fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), Lowering } fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), LoweringError> { - let (mut place, mut ty, _path_span) = lower_path(builder, &info.target)?; + let (mut place, mut ty, _path_span) = lower_path(builder, &info.lvalue)?; if !builder.body.locals[place.local].is_mutable() { return Err(LoweringError::NotMutable { @@ -834,8 +834,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo TyKind::Ref(inner, is_mut) | TyKind::Ptr(inner, is_mut) => { if matches!(is_mut, Mutability::Not) { Err(LoweringError::BorrowNotMutable { - span: info.target.first.span, - name: info.target.first.name.clone(), + span: info.lvalue.first.span, + name: info.lvalue.first.name.clone(), type_span: ty.span, program_id: builder.local_module.program_id, })?; @@ -848,7 +848,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo } let (rvalue, rvalue_ty, rvalue_span) = - lower_expression(builder, &info.value, Some(ty.clone()))?; + lower_expression(builder, &info.rvalue, Some(ty.clone()))?; if ty.kind != rvalue_ty.kind { return Err(LoweringError::UnexpectedType { @@ -860,7 +860,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo } builder.statements.push(Statement { - span: Some(info.target.first.span), + span: Some(info.lvalue.first.span), kind: StatementKind::Assign(place, rvalue), }); diff --git a/crates/concrete_parser/src/grammar.lalrpop b/crates/concrete_parser/src/grammar.lalrpop index 821d76d..e7a3ed6 100644 --- a/crates/concrete_parser/src/grammar.lalrpop +++ b/crates/concrete_parser/src/grammar.lalrpop @@ -503,21 +503,21 @@ pub(crate) ValueExpr: ast::expressions::ValueExpr = { } pub(crate) IfExpr: ast::expressions::IfExpr = { - "if" "{" "}" + "if" "{" "}" "}")?> => { ast::expressions::IfExpr { - value: Box::new(value), - contents, - r#else: else_stmts, + cond: Box::new(cond), + block_stmts, + else_stmts, span: Span::new(lo, hi) } } } pub(crate) MatchExpr: ast::expressions::MatchExpr = { - "match" "{" > "}" => { + "match" "{" > "}" => { ast::expressions::MatchExpr { - value: Box::new(value), + expr: Box::new(expr), variants, span: Span::new(lo, hi) } @@ -594,19 +594,19 @@ pub(crate) Statement: ast::statements::Statement = { } pub(crate) LetStmt: ast::statements::LetStmt = { - "let" ":" "=" => ast::statements::LetStmt { + "let" ":" "=" => ast::statements::LetStmt { is_mutable: is_mutable.is_some(), target: ast::statements::LetStmtTarget::Simple { - name, + id, r#type: target_type }, value, span: Span::new(lo, hi), }, - "let" ":" "=" => ast::statements::LetStmt { + "let" ":" "=" => ast::statements::LetStmt { is_mutable: is_mutable.is_some(), target: ast::statements::LetStmtTarget::Simple { - name, + id, r#type: target_type }, value: ast::expressions::Expression::StructInit(value), @@ -615,15 +615,15 @@ pub(crate) LetStmt: ast::statements::LetStmt = { } pub(crate) AssignStmt: ast::statements::AssignStmt = { - "=" => ast::statements::AssignStmt { - target, - value, + "=" => ast::statements::AssignStmt { + lvalue, + rvalue, derefs: derefs.len(), span: Span::new(lo, hi), }, - "=" => ast::statements::AssignStmt { - target, - value: ast::expressions::Expression::StructInit(value), + "=" => ast::statements::AssignStmt { + lvalue, + rvalue: ast::expressions::Expression::StructInit(rvalue), derefs: derefs.len(), span: Span::new(lo, hi), }, @@ -637,40 +637,40 @@ pub(crate) ReturnStmt: ast::statements::ReturnStmt = { } pub(crate) WhileStmt: ast::statements::WhileStmt = { - "while" "{" "}" => { + "while" "{" "}" => { ast::statements::WhileStmt { - value, - contents, + condition, + block_stmts, } } } pub(crate) ForStmt: ast::statements::ForStmt = { - "for" "(" ";" ";" ")" "{" "}" => { + "for" "(" ";" ";" ")" "{" "}" => { ast::statements::ForStmt { init, condition, post, - contents, + block_stmts, span: Span::new(lo, hi) } }, - "for" "(" ")" "{" "}" => { + "for" "(" ")" "{" "}" => { ast::statements::ForStmt { init: None, condition: Some(condition), post: None, - contents, + block_stmts, span: Span::new(lo, hi) } }, - "for" "{" "}" => { + "for" "{" "}" => { ast::statements::ForStmt { init: None, condition: None, post: None, - contents, + block_stmts, span: Span::new(lo, hi) } }