Skip to content

Commit

Permalink
Merge pull request #140 from kenarab/ast-refactor
Browse files Browse the repository at this point in the history
Ast refactor
  • Loading branch information
igaray authored Jun 7, 2024
2 parents 8f9e32e + 727ba89 commit 9575fbe
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 70 deletions.
8 changes: 4 additions & 4 deletions crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ pub enum BitwiseOp {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MatchExpr {
pub value: Box<Expression>,
pub expr: Box<Expression>,
pub variants: Vec<MatchVariant>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IfExpr {
pub value: Box<Expression>,
pub contents: Vec<Statement>,
pub r#else: Option<Vec<Statement>>,
pub cond: Box<Expression>,
pub block_stmts: Vec<Statement>,
pub else_stmts: Option<Vec<Statement>>,
pub span: Span,
}

Expand Down
14 changes: 7 additions & 7 deletions crates/concrete_ast/src/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum Statement {

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LetStmtTarget {
Simple { name: Ident, r#type: TypeSpec },
Simple { id: Ident, r#type: TypeSpec },
Destructure(Vec<Binding>),
}

Expand All @@ -38,15 +38,15 @@ pub struct ReturnStmt {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AssignStmt {
pub target: PathOp,
pub lvalue: PathOp,
pub derefs: usize,
pub value: Expression,
pub rvalue: Expression,
pub span: Span,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Binding {
pub name: Ident,
pub id: Ident,
pub rename: Option<Ident>,
pub r#type: TypeSpec,
}
Expand All @@ -56,12 +56,12 @@ pub struct ForStmt {
pub init: Option<LetStmt>,
pub condition: Option<Expression>,
pub post: Option<AssignStmt>,
pub contents: Vec<Statement>,
pub block_stmts: Vec<Statement>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WhileStmt {
pub value: Expression,
pub contents: Vec<Statement>,
pub condition: Expression,
pub block_stmts: Vec<Statement>,
}
42 changes: 22 additions & 20 deletions crates/concrete_check/src/linearity_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,10 @@ impl LinearityChecker {
}
Statement::If(if_stmt) => {
// Process all components of an if expression
let cond_apps = self.count_in_expression(name, &if_stmt.value);
let then_apps = self.count_in_statements(name, &if_stmt.contents);
let cond_apps = self.count_in_expression(name, &if_stmt.cond);
let then_apps = self.count_in_statements(name, &if_stmt.block_stmts);
let else_apps;
let else_statements = &if_stmt.r#else;
let else_statements = &if_stmt.else_stmts;
if let Some(else_statements) = else_statements {
else_apps = self.count_in_statements(name, else_statements);
} else {
Expand All @@ -308,8 +308,8 @@ impl LinearityChecker {
cond_apps.merge(&then_apps).merge(&else_apps)
}
Statement::While(while_expr) => {
let cond = &while_expr.value;
let block = &while_expr.contents;
let cond = &while_expr.condition;
let block = &while_expr.block_stmts;
// Handle while loops
self.count_in_expression(name, cond)
.merge(&self.count_in_statements(name, block))
Expand All @@ -320,7 +320,7 @@ impl LinearityChecker {
let init = &for_expr.init;
let cond = &for_expr.condition;
let post = &for_expr.post;
let block = &for_expr.contents;
let block = &for_expr.block_stmts;
let mut apps = Appearances::zero();
if let Some(init) = init {
if let Some(cond) = cond {
Expand Down Expand Up @@ -364,9 +364,9 @@ impl LinearityChecker {

fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances {
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
// Handle assignments
Expand Down Expand Up @@ -434,10 +434,10 @@ impl LinearityChecker {
Expression::If(if_expr) => {
// Process all components of an if expression
// TODO review this code. If expressions should be processed counting both branches and comparing them
let cond_apps = self.count_in_expression(name, &if_expr.value);
let then_apps = self.count_in_statements(name, &if_expr.contents);
let cond_apps = self.count_in_expression(name, &if_expr.cond);
let then_apps = self.count_in_statements(name, &if_expr.block_stmts);
cond_apps.merge(&then_apps);
if let Some(else_block) = &if_expr.r#else {
if let Some(else_block) = &if_expr.else_stmts {
let else_apps = self.count_in_statements(name, else_block);
cond_apps.merge(&then_apps).merge(&else_apps);
}
Expand Down Expand Up @@ -497,7 +497,7 @@ impl LinearityChecker {
span,
} = binding;
match target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
match r#type {
TypeSpec::Simple {
name: variable_type,
Expand Down Expand Up @@ -670,19 +670,20 @@ impl LinearityChecker {
//Statement::If(cond, then_block, else_block) => {
Statement::If(if_stmt) => {
// Handle conditional statements
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.value, context)?;
state_tbl = self.check_stmts(state_tbl, depth + 1, &if_stmt.contents, context)?;
if let Some(else_block) = &if_stmt.r#else {
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.cond, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &if_stmt.block_stmts, context)?;
if let Some(else_block) = &if_stmt.else_stmts {
state_tbl = self.check_stmts(state_tbl, depth + 1, else_block, context)?;
}
Ok(state_tbl)
}
//Statement::While(cond, block) => {
Statement::While(while_stmt) => {
// Handle while loops
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.value, context)?;
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.condition, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &while_stmt.contents, context)?;
self.check_stmts(state_tbl, depth + 1, &while_stmt.block_stmts, context)?;
Ok(state_tbl)
}
//Statement::For(init, cond, post, block) => {
Expand All @@ -698,15 +699,16 @@ impl LinearityChecker {
//TODO check assign statement
//self.check_stmt_assign(depth, post)?;
}
state_tbl = self.check_stmts(state_tbl, depth + 1, &for_stmt.contents, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &for_stmt.block_stmts, context)?;
Ok(state_tbl)
}
Statement::Assign(assign_stmt) => {
// Handle assignments
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
tracing::debug!("Checking assignment: {:?}", assign_stmt);
Expand Down
28 changes: 14 additions & 14 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ fn lower_func(
for stmt in &func.body {
if let statements::Statement::Let(info) = stmt {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand All @@ -391,7 +391,7 @@ fn lower_func(
} else if let statements::Statement::For(info) = stmt {
if let Some(info) = &info.init {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand Down Expand Up @@ -520,7 +520,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
});

let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.condition, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -548,7 +548,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -645,7 +645,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -687,7 +687,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering

fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), LoweringError> {
let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.cond, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -715,7 +715,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand All @@ -739,7 +739,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),

let first_else_block_idx = builder.body.basic_blocks.len();

if let Some(contents) = &info.r#else {
if let Some(contents) = &info.else_stmts {
for stmt in contents {
lower_statement(
builder,
Expand Down Expand Up @@ -783,7 +783,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),

fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), LoweringError> {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
Expand Down Expand Up @@ -819,7 +819,7 @@ fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), Lowering
}

fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), LoweringError> {
let (mut place, mut ty, _path_span) = lower_path(builder, &info.target)?;
let (mut place, mut ty, _path_span) = lower_path(builder, &info.lvalue)?;

if !builder.body.locals[place.local].is_mutable() {
return Err(LoweringError::NotMutable {
Expand All @@ -834,8 +834,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
TyKind::Ref(inner, is_mut) | TyKind::Ptr(inner, is_mut) => {
if matches!(is_mut, Mutability::Not) {
Err(LoweringError::BorrowNotMutable {
span: info.target.first.span,
name: info.target.first.name.clone(),
span: info.lvalue.first.span,
name: info.lvalue.first.name.clone(),
type_span: ty.span,
program_id: builder.local_module.program_id,
})?;
Expand All @@ -848,7 +848,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
lower_expression(builder, &info.rvalue, Some(ty.clone()))?;

if ty.kind != rvalue_ty.kind {
return Err(LoweringError::UnexpectedType {
Expand All @@ -860,7 +860,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

builder.statements.push(Statement {
span: Some(info.target.first.span),
span: Some(info.lvalue.first.span),
kind: StatementKind::Assign(place, rvalue),
});

Expand Down
Loading

0 comments on commit 9575fbe

Please sign in to comment.