Skip to content

Commit

Permalink
[refactor] Proposed improved AST variables convention
Browse files Browse the repository at this point in the history
  • Loading branch information
kenarab committed May 31, 2024
1 parent b9568b8 commit 9c63fe8
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 92 deletions.
8 changes: 4 additions & 4 deletions crates/concrete_ast/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,16 @@ pub enum BitwiseOp {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MatchExpr {
pub value: Box<Expression>,
pub match_expr: Box<Expression>,
pub variants: Vec<MatchVariant>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IfExpr {
pub value: Box<Expression>,
pub contents: Vec<Statement>,
pub r#else: Option<Vec<Statement>>,
pub cond: Box<Expression>,
pub block_stmts: Vec<Statement>,
pub else_stmts: Option<Vec<Statement>>,
pub span: Span,
}

Expand Down
20 changes: 10 additions & 10 deletions crates/concrete_ast/src/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ pub enum Statement {

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LetStmtTarget {
Simple { name: Ident, r#type: TypeSpec },
Simple { id: Ident, r#type: TypeSpec },
Destructure(Vec<Binding>),
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LetStmt {
pub is_mutable: bool,
pub target: LetStmtTarget,
pub value: Expression,
pub lvalue: LetStmtTarget,
pub rvalue: Expression,
pub span: Span,
}

Expand All @@ -38,30 +38,30 @@ pub struct ReturnStmt {

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AssignStmt {
pub target: PathOp,
pub lvalue: PathOp,
pub derefs: usize,
pub value: Expression,
pub rvalue: Expression,
pub span: Span,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Binding {
pub name: Ident,
pub id: Ident,
pub rename: Option<Ident>,
pub r#type: TypeSpec,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ForStmt {
pub init: Option<LetStmt>,
pub condition: Option<Expression>,
pub cond: Option<Expression>,
pub post: Option<AssignStmt>,
pub contents: Vec<Statement>,
pub block_stmts: Vec<Statement>,
pub span: Span,
}

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WhileStmt {
pub value: Expression,
pub contents: Vec<Statement>,
pub cond: Expression,
pub block_stmts: Vec<Statement>,
}
54 changes: 27 additions & 27 deletions crates/concrete_check/src/linearity_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,14 @@ impl LinearityChecker {
match statement {
Statement::Let(binding) => {
// Handle let bindings, possibly involving pattern matching
self.count_in_expression(name, &binding.value)
self.count_in_expression(name, &binding.rvalue)
}
Statement::If(if_stmt) => {
// Process all components of an if expression
let cond_apps = self.count_in_expression(name, &if_stmt.value);
let then_apps = self.count_in_statements(name, &if_stmt.contents);
let cond_apps = self.count_in_expression(name, &if_stmt.cond);
let then_apps = self.count_in_statements(name, &if_stmt.block_stmts);
let else_apps;
let else_statements = &if_stmt.r#else;
let else_statements = &if_stmt.else_stmts;
if let Some(else_statements) = else_statements {
else_apps = self.count_in_statements(name, else_statements);
} else {
Expand All @@ -340,8 +340,8 @@ impl LinearityChecker {
cond_apps.merge(&then_apps).merge(&else_apps)
}
Statement::While(while_expr) => {
let cond = &while_expr.value;
let block = &while_expr.contents;
let cond = &while_expr.cond;
let block = &while_expr.block_stmts;
// Handle while loops
self.count_in_expression(name, cond)
.merge(&self.count_in_statements(name, block))
Expand All @@ -350,9 +350,9 @@ impl LinearityChecker {
// Handle for loops
//init, cond, post, block
let init = &for_expr.init;
let cond = &for_expr.condition;
let cond = &for_expr.cond;
let post = &for_expr.post;
let block = &for_expr.contents;
let block = &for_expr.block_stmts;
let mut apps = Appearances::zero();
if let Some(init) = init {
if let Some(cond) = cond {
Expand Down Expand Up @@ -396,9 +396,9 @@ impl LinearityChecker {

fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances {
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
// Handle assignments
Expand All @@ -422,8 +422,8 @@ impl LinearityChecker {
fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances {
let LetStmt {
is_mutable,
target,
value,
lvalue: target,
rvalue: value,
span,
} = let_stmt;
self.count_in_expression(name, value)
Expand Down Expand Up @@ -466,10 +466,10 @@ impl LinearityChecker {
Expression::If(if_expr) => {
// Process all components of an if expression
// TODO review this code. If expressions should be processed counting both branches and comparing them
let cond_apps = self.count_in_expression(name, &if_expr.value);
let then_apps = self.count_in_statements(name, &if_expr.contents);
let cond_apps = self.count_in_expression(name, &if_expr.cond);
let then_apps = self.count_in_statements(name, &if_expr.block_stmts);
cond_apps.merge(&then_apps);
if let Some(else_block) = &if_expr.r#else {
if let Some(else_block) = &if_expr.else_stmts {
let else_apps = self.count_in_statements(name, else_block);
cond_apps.merge(&then_apps).merge(&else_apps);
}
Expand Down Expand Up @@ -524,12 +524,12 @@ impl LinearityChecker {
// Handle let bindings, possibly involving pattern matching
let LetStmt {
is_mutable,
target,
value,
lvalue: target,
rvalue: value,
span,
} = binding;
match target {
LetStmtTarget::Simple { name, r#type } => {
LetStmtTarget::Simple { id: name, r#type } => {
match r#type {
TypeSpec::Simple {
name: variable_type,
Expand Down Expand Up @@ -702,19 +702,19 @@ impl LinearityChecker {
//Statement::If(cond, then_block, else_block) => {
Statement::If(if_stmt) => {
// Handle conditional statements
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.value, context)?;
state_tbl = self.check_stmts(state_tbl, depth + 1, &if_stmt.contents, context)?;
if let Some(else_block) = &if_stmt.r#else {
state_tbl = self.check_expr(state_tbl, depth, &if_stmt.cond, context)?;
state_tbl = self.check_stmts(state_tbl, depth + 1, &if_stmt.block_stmts, context)?;
if let Some(else_block) = &if_stmt.else_stmts {
state_tbl = self.check_stmts(state_tbl, depth + 1, else_block, context)?;
}
Ok(state_tbl)
}
//Statement::While(cond, block) => {
Statement::While(while_stmt) => {
// Handle while loops
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.value, context)?;
state_tbl = self.check_expr(state_tbl, depth, &while_stmt.cond, context)?;
state_tbl =
self.check_stmts(state_tbl, depth + 1, &while_stmt.contents, context)?;
self.check_stmts(state_tbl, depth + 1, &while_stmt.block_stmts, context)?;
Ok(state_tbl)
}
//Statement::For(init, cond, post, block) => {
Expand All @@ -723,22 +723,22 @@ impl LinearityChecker {
if let Some(init) = &for_stmt.init {
state_tbl = self.check_stmt_let(state_tbl, depth, init, context)?;
}
if let Some(condition) = &for_stmt.condition {
if let Some(condition) = &for_stmt.cond {
state_tbl = self.check_expr(state_tbl, depth, condition, context)?;
}
if let Some(post) = &for_stmt.post {
//TODO check assign statement
//self.check_stmt_assign(depth, post)?;
}
state_tbl = self.check_stmts(state_tbl, depth + 1, &for_stmt.contents, context)?;
state_tbl = self.check_stmts(state_tbl, depth + 1, &for_stmt.block_stmts, context)?;
Ok(state_tbl)
}
Statement::Assign(assign_stmt) => {
// Handle assignments
let AssignStmt {
target,
lvalue: target,
derefs,
value,
rvalue: value,
span,
} = assign_stmt;
tracing::debug!("Checking assignment: {:?}", assign_stmt);
Expand Down
38 changes: 19 additions & 19 deletions crates/concrete_ir/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ fn lower_func(
// Get all locals
for stmt in &func.body {
if let statements::Statement::Let(info) = stmt {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand All @@ -305,8 +305,8 @@ fn lower_func(
}
} else if let statements::Statement::For(info) = stmt {
if let Some(info) = &info.init {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
builder
.name_to_local
Expand Down Expand Up @@ -435,7 +435,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
});

let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.cond, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -463,7 +463,7 @@ fn lower_while(builder: &mut FnBodyBuilder, info: &WhileStmt) -> Result<(), Lowe
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -515,7 +515,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering
}),
});

let (discriminator, discriminator_type, _disc_span) = if let Some(condition) = &info.condition {
let (discriminator, discriminator_type, _disc_span) = if let Some(condition) = &info.cond {
let (discriminator, discriminator_type, span) = lower_expression(builder, condition, None)?;

(discriminator, discriminator_type, Some(span))
Expand Down Expand Up @@ -560,7 +560,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand Down Expand Up @@ -602,7 +602,7 @@ fn lower_for(builder: &mut FnBodyBuilder, info: &ForStmt) -> Result<(), Lowering

fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(), LoweringError> {
let (discriminator, discriminator_type, _disc_span) =
lower_expression(builder, &info.value, None)?;
lower_expression(builder, &info.cond, None)?;

let local = builder.add_temp_local(TyKind::Bool);
let place = Place {
Expand Down Expand Up @@ -630,7 +630,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),
// keep idx for switch targets
let first_then_block_idx = builder.body.basic_blocks.len();

for stmt in &info.contents {
for stmt in &info.block_stmts {
lower_statement(
builder,
stmt,
Expand All @@ -654,7 +654,7 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),

let first_else_block_idx = builder.body.basic_blocks.len();

if let Some(contents) = &info.r#else {
if let Some(contents) = &info.else_stmts {
for stmt in contents {
lower_statement(
builder,
Expand Down Expand Up @@ -697,11 +697,11 @@ fn lower_if_statement(builder: &mut FnBodyBuilder, info: &IfExpr) -> Result<(),
}

fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), LoweringError> {
match &info.target {
LetStmtTarget::Simple { name, r#type } => {
match &info.lvalue {
LetStmtTarget::Simple { id: name, r#type } => {
let ty = lower_type(&builder.ctx, r#type, builder.local_module)?;
let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
lower_expression(builder, &info.rvalue, Some(ty.clone()))?;

if ty.kind != rvalue_ty.kind {
return Err(LoweringError::UnexpectedType {
Expand Down Expand Up @@ -734,7 +734,7 @@ fn lower_let(builder: &mut FnBodyBuilder, info: &LetStmt) -> Result<(), Lowering
}

fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), LoweringError> {
let (mut place, mut ty, _path_span) = lower_path(builder, &info.target)?;
let (mut place, mut ty, _path_span) = lower_path(builder, &info.lvalue)?;

if !builder.body.locals[place.local].is_mutable() {
return Err(LoweringError::NotMutable {
Expand All @@ -749,8 +749,8 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
TyKind::Ref(inner, is_mut) | TyKind::Ptr(inner, is_mut) => {
if matches!(is_mut, Mutability::Not) {
Err(LoweringError::BorrowNotMutable {
span: info.target.first.span,
name: info.target.first.name.clone(),
span: info.lvalue.first.span,
name: info.lvalue.first.name.clone(),
type_span: ty.span,
program_id: builder.local_module.program_id,
})?;
Expand All @@ -763,7 +763,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

let (rvalue, rvalue_ty, rvalue_span) =
lower_expression(builder, &info.value, Some(ty.clone()))?;
lower_expression(builder, &info.rvalue, Some(ty.clone()))?;

if ty.kind != rvalue_ty.kind {
return Err(LoweringError::UnexpectedType {
Expand All @@ -775,7 +775,7 @@ fn lower_assign(builder: &mut FnBodyBuilder, info: &AssignStmt) -> Result<(), Lo
}

builder.statements.push(Statement {
span: Some(info.target.first.span),
span: Some(info.lvalue.first.span),
kind: StatementKind::Assign(place, rvalue),
});

Expand Down
Loading

0 comments on commit 9c63fe8

Please sign in to comment.