From 36a88792d5bf7342322f3f4b55de42f8080b135d Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Wed, 8 May 2024 16:56:21 -0300 Subject: [PATCH 01/31] linear types example for devloping linear type checker --- examples/linearExample01.con | 21 +++++++++++++++++++++ examples/linearExample02.con | 26 ++++++++++++++++++++++++++ examples/linearExample03if.con | 31 +++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 examples/linearExample01.con create mode 100644 examples/linearExample02.con create mode 100644 examples/linearExample03if.con diff --git a/examples/linearExample01.con b/examples/linearExample01.con new file mode 100644 index 0000000..a992d25 --- /dev/null +++ b/examples/linearExample01.con @@ -0,0 +1,21 @@ +mod LinearExampleStub { + + struct Linear { + x: i32, + y: i32, + } + + fn main() -> i32 { + let xy: Linear = + Linear { + x: 0, + y: 1, + }; + // FIXME Prefered initialization but not yet implemented + // [1, 0]; + let sum: i32 = xy.x + xy.y; + // linear value is written/consumed + xy.x = xy.x + 1; + return xy.x; + } +} \ No newline at end of file diff --git a/examples/linearExample02.con b/examples/linearExample02.con new file mode 100644 index 0000000..876e6a3 --- /dev/null +++ b/examples/linearExample02.con @@ -0,0 +1,26 @@ +mod LinearExampleStub { + + struct Linear { + x: i32, + y: i32, + } + + fn main() -> i32 { + let xy: Linear = + Linear { + x: 0, + y: 1, + }; + // FIXME Prefered initialization but not yet implemented + // [1, 0]; + let sum: i32 = xy.x + xy.y; + // linear value is written/consumed + consume_x(&xy); + return xy.x; + } + + fn consume_x(value: &Linear) { + value.x = value.x + 1; + } + +} \ No newline at end of file diff --git a/examples/linearExample03if.con b/examples/linearExample03if.con new file mode 100644 index 0000000..7801275 --- /dev/null +++ b/examples/linearExample03if.con @@ -0,0 +1,31 @@ +mod LinearExampleIfStub { + + struct Linear { + x: i32, + y: i32, + } + + fn main() -> i32 { + let xy: Linear = + Linear { + x: 0, + y: 1, + }; + // FIXME Prefered initialization but not yet implemented + // [1, 0]; + let sum: i32 = xy.x + xy.y; + let mut consumed: i32 = 0; + if xy.x > xy.y{ + consume_x(&xy, 1); + } + else { + consume_x(&xy, 0-1); + } + return xy.x; + } + + fn consume_x(value: &Linear, i: i32) { + value.x = value.x + i; + } + +} \ No newline at end of file From c863194e6e6144840ba4542667f60109de2b54ce Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 14:48:22 -0300 Subject: [PATCH 02/31] linearExamples with mutable types --- crates/concrete_check/src/linearity_check.rs | 48 ++++++++++++++++++++ examples/linearExample01.con | 3 +- examples/linearExample02.con | 5 +- examples/linearExample03if.con | 3 +- 4 files changed, 52 insertions(+), 7 deletions(-) create mode 100644 crates/concrete_check/src/linearity_check.rs diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs new file mode 100644 index 0000000..ff063c5 --- /dev/null +++ b/crates/concrete_check/src/linearity_check.rs @@ -0,0 +1,48 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum VarState { + Available, + Consumed, + Borrowed, + BorrowedMut, +} + +#[derive(Debug, Clone)] +struct StateTbl { + vars: HashMap, +} + +impl StateTbl { + // Initialize with an empty state table + fn new() -> Self { + Self { + vars: HashMap::new(), + } + } + + // Example of updating the state table + fn update_state(&mut self, var: String, state: VarState) { + self.vars.insert(var, state); + } + + // Remove a variable from the state table + fn remove_entry(&mut self, var: &str) { + self.vars.remove(var); + } + + // Retrieve a variable's state + fn get_state(&self, var: &str) -> Option<&VarState> { + self.vars.get(var) + } +} + +// Placeholder function signatures (implementation required) +fn check_expr(expr: &str, state_tbl: &mut StateTbl) { + // Implementation needed based on OCaml logic +} + +fn count(vars: &[String], state_tbl: &StateTbl) -> usize { + // Implementation needed based on OCaml logic + 0 +} diff --git a/examples/linearExample01.con b/examples/linearExample01.con index a992d25..62fce18 100644 --- a/examples/linearExample01.con +++ b/examples/linearExample01.con @@ -6,14 +6,13 @@ mod LinearExampleStub { } fn main() -> i32 { - let xy: Linear = + let mut xy: Linear = Linear { x: 0, y: 1, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; - let sum: i32 = xy.x + xy.y; // linear value is written/consumed xy.x = xy.x + 1; return xy.x; diff --git a/examples/linearExample02.con b/examples/linearExample02.con index 876e6a3..24705be 100644 --- a/examples/linearExample02.con +++ b/examples/linearExample02.con @@ -6,20 +6,19 @@ mod LinearExampleStub { } fn main() -> i32 { - let xy: Linear = + let mut xy: Linear = Linear { x: 0, y: 1, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; - let sum: i32 = xy.x + xy.y; // linear value is written/consumed consume_x(&xy); return xy.x; } - fn consume_x(value: &Linear) { + fn consume_x(value: & mut Linear) { value.x = value.x + 1; } diff --git a/examples/linearExample03if.con b/examples/linearExample03if.con index 7801275..144b2f7 100644 --- a/examples/linearExample03if.con +++ b/examples/linearExample03if.con @@ -13,7 +13,6 @@ mod LinearExampleIfStub { }; // FIXME Prefered initialization but not yet implemented // [1, 0]; - let sum: i32 = xy.x + xy.y; let mut consumed: i32 = 0; if xy.x > xy.y{ consume_x(&xy, 1); @@ -24,7 +23,7 @@ mod LinearExampleIfStub { return xy.x; } - fn consume_x(value: &Linear, i: i32) { + fn consume_x(value: & mut Linear, i: i32) { value.x = value.x + i; } From a70127e02125bf1404470834cd6fc5ba919a29d3 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 16:14:57 -0300 Subject: [PATCH 03/31] linearExamples with mutable types compiles ok --- examples/linearExample02.con | 2 +- examples/linearExample03if.con | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/linearExample02.con b/examples/linearExample02.con index 24705be..f839385 100644 --- a/examples/linearExample02.con +++ b/examples/linearExample02.con @@ -14,7 +14,7 @@ mod LinearExampleStub { // FIXME Prefered initialization but not yet implemented // [1, 0]; // linear value is written/consumed - consume_x(&xy); + consume_x(&mut xy); return xy.x; } diff --git a/examples/linearExample03if.con b/examples/linearExample03if.con index 144b2f7..a6aed8e 100644 --- a/examples/linearExample03if.con +++ b/examples/linearExample03if.con @@ -6,19 +6,18 @@ mod LinearExampleIfStub { } fn main() -> i32 { - let xy: Linear = + let mut xy: Linear = Linear { x: 0, y: 1, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; - let mut consumed: i32 = 0; if xy.x > xy.y{ - consume_x(&xy, 1); + consume_x(&mut xy, 1); } else { - consume_x(&xy, 0-1); + consume_x(&mut xy, 0-1); } return xy.x; } From ed2b6bb02750fbf2ca5cd5ce21071335dce653a5 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 19:07:01 -0300 Subject: [PATCH 04/31] Connected a stub for linearityCheck --- crates/concrete_check/src/lib.rs | 3 ++ crates/concrete_check/src/linearity_check.rs | 28 +++++++++++++++++++ .../src/linearity_check/errors.rs | 14 ++++++++++ crates/concrete_driver/src/lib.rs | 10 +++++++ 4 files changed, 55 insertions(+) create mode 100644 crates/concrete_check/src/linearity_check/errors.rs diff --git a/crates/concrete_check/src/lib.rs b/crates/concrete_check/src/lib.rs index 25a3d44..11062e9 100644 --- a/crates/concrete_check/src/lib.rs +++ b/crates/concrete_check/src/lib.rs @@ -4,6 +4,9 @@ use ariadne::{ColorGenerator, Label, Report, ReportKind}; use concrete_ir::lowering::errors::LoweringError; use concrete_session::Session; +pub mod linearity_check; + + /// Creates a report from a lowering error. pub fn lowering_error_to_report( error: LoweringError, diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index ff063c5..f789c38 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -1,5 +1,15 @@ use std::collections::HashMap; +use concrete_session::Session; +//pub mod linearityError; + + +use self::errors::LinearityError; +pub mod errors; + + +use concrete_ir::ProgramBody; + #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { Available, @@ -37,6 +47,7 @@ impl StateTbl { } } +/* // Placeholder function signatures (implementation required) fn check_expr(expr: &str, state_tbl: &mut StateTbl) { // Implementation needed based on OCaml logic @@ -46,3 +57,20 @@ fn count(vars: &[String], state_tbl: &StateTbl) -> usize { // Implementation needed based on OCaml logic 0 } + +*/ + +// Do nothing implementation of linearity check +#[allow(unused_variables)] +pub fn linearity_check_program(program_ir: &ProgramBody, session: &Session) -> Result { + let mut linearity_table = StateTbl::new(); + linearity_table.update_state("x".to_string(), VarState::Available); + linearity_table.update_state("y".to_string(), VarState::Consumed); + linearity_table.update_state("z".to_string(), VarState::Borrowed); + linearity_table.update_state("w".to_string(), VarState::BorrowedMut); + + linearity_table.remove_entry("x"); + let state = linearity_table.get_state("y"); + Ok("OK".to_string()) +} + diff --git a/crates/concrete_check/src/linearity_check/errors.rs b/crates/concrete_check/src/linearity_check/errors.rs new file mode 100644 index 0000000..317d126 --- /dev/null +++ b/crates/concrete_check/src/linearity_check/errors.rs @@ -0,0 +1,14 @@ +use concrete_ir::Span; + +use thiserror::Error; + +#[derive(Debug, Error, Clone)] +pub enum LinearityError { + #[error("Variable {variable} not consumed at module {module:?}")] + LinearNotConsumed { + span: Span, + module: String, + program_id: usize, + variable: String + }, +} \ No newline at end of file diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 48284ff..d185da6 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -592,6 +592,16 @@ pub fn compile(args: &CompilerArgs) -> Result { } }; + #[allow(unused_variables)] + let linearity_result = match concrete_check::linearity_check::linearity_check_program(&program_ir, &session) { + Ok(ir) => ir, + Err(error) => { + println!("TODO error message when linearity fails"); + std::process::exit(1); + } + }; + + if args.ir { std::fs::write( session.output_file.with_extension("ir"), From 354840a044c406d325b0074a3c59558cb647c31b Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 19:07:54 -0300 Subject: [PATCH 05/31] Connected a stub for linearityCheck --- crates/concrete_check/src/linearity_check.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index f789c38..9304f70 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use concrete_session::Session; -//pub mod linearityError; use self::errors::LinearityError; From 211fd0526e011aebac3033cd0bf6a4dd7a5977d4 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 19:17:45 -0300 Subject: [PATCH 06/31] Included linearChecker first testcases --- crates/concrete_driver/tests/examples.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crates/concrete_driver/tests/examples.rs b/crates/concrete_driver/tests/examples.rs index 4215fef..1c097cb 100644 --- a/crates/concrete_driver/tests/examples.rs +++ b/crates/concrete_driver/tests/examples.rs @@ -21,6 +21,9 @@ mod common; #[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")] #[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")] #[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")] +#[test_case(include_str!("../../../examples/linearExample01.con"), "linearity", false, 1 ; "linearExample01.con")] +#[test_case(include_str!("../../../examples/linearExample02.con"), "linearity", false, 1 ; "linearExample02.con")] +#[test_case(include_str!("../../../examples/linearExample03if.con"), "linearity", false, 255 ; "linearExample03if.con")] fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) { assert_eq!( status_code, From a144cf30cde3e34b116b805b746a9aeae0974ab3 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 19:32:51 -0300 Subject: [PATCH 07/31] Updated linearityCheck testcases --- crates/concrete_driver/tests/examples.rs | 6 +++--- examples/linearExample01.con | 4 ++-- examples/linearExample02.con | 4 ++-- examples/linearExample03if.con | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/concrete_driver/tests/examples.rs b/crates/concrete_driver/tests/examples.rs index 1c097cb..6a4ed1b 100644 --- a/crates/concrete_driver/tests/examples.rs +++ b/crates/concrete_driver/tests/examples.rs @@ -21,9 +21,9 @@ mod common; #[test_case(include_str!("../../../examples/for.con"), "for", false, 10 ; "for.con")] #[test_case(include_str!("../../../examples/for_while.con"), "for_while", false, 10 ; "for_while.con")] #[test_case(include_str!("../../../examples/arrays.con"), "arrays", false, 5 ; "arrays.con")] -#[test_case(include_str!("../../../examples/linearExample01.con"), "linearity", false, 1 ; "linearExample01.con")] -#[test_case(include_str!("../../../examples/linearExample02.con"), "linearity", false, 1 ; "linearExample02.con")] -#[test_case(include_str!("../../../examples/linearExample03if.con"), "linearity", false, 255 ; "linearExample03if.con")] +#[test_case(include_str!("../../../examples/linearExample01.con"), "linearity", false, 2 ; "linearExample01.con")] +#[test_case(include_str!("../../../examples/linearExample02.con"), "linearity", false, 2 ; "linearExample02.con")] +#[test_case(include_str!("../../../examples/linearExample03if.con"), "linearity", false, 0 ; "linearExample03if.con")] fn example_tests(source: &str, name: &str, is_library: bool, status_code: i32) { assert_eq!( status_code, diff --git a/examples/linearExample01.con b/examples/linearExample01.con index 62fce18..3c226e4 100644 --- a/examples/linearExample01.con +++ b/examples/linearExample01.con @@ -8,8 +8,8 @@ mod LinearExampleStub { fn main() -> i32 { let mut xy: Linear = Linear { - x: 0, - y: 1, + x: 1, + y: 0, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; diff --git a/examples/linearExample02.con b/examples/linearExample02.con index f839385..81d6fce 100644 --- a/examples/linearExample02.con +++ b/examples/linearExample02.con @@ -8,8 +8,8 @@ mod LinearExampleStub { fn main() -> i32 { let mut xy: Linear = Linear { - x: 0, - y: 1, + x: 1, + y: 0, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; diff --git a/examples/linearExample03if.con b/examples/linearExample03if.con index a6aed8e..25ba01f 100644 --- a/examples/linearExample03if.con +++ b/examples/linearExample03if.con @@ -8,12 +8,12 @@ mod LinearExampleIfStub { fn main() -> i32 { let mut xy: Linear = Linear { - x: 0, - y: 1, + x: 1, + y: 0, }; // FIXME Prefered initialization but not yet implemented // [1, 0]; - if xy.x > xy.y{ + if xy.x < xy.y{ consume_x(&mut xy, 1); } else { From fa577cab540d1934cbaf2ed6bd333c492eee195b Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 10 May 2024 19:49:28 -0300 Subject: [PATCH 08/31] Implemented check_var_in_expr --- crates/concrete_check/src/linearity_check.rs | 80 +++++++++++++++++++- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 9304f70..40a7307 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -9,14 +9,42 @@ pub mod errors; use concrete_ir::ProgramBody; +#[derive(Debug, Clone, Copy)] +struct Appearances { + consumed: u32, + write: u32, + read: u32, + path: u32, +} + + #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { - Available, + Unconsumed, Consumed, Borrowed, BorrowedMut, } +#[derive(Debug, Clone, PartialEq)] +enum CountResult { + Zero, + One, + MoreThanOne, +} + +#[allow(dead_code)] +impl Appearances { + fn partition(count: u32) -> CountResult { + match count { + 0 => CountResult::Zero, + 1 => CountResult::One, + _ => CountResult::MoreThanOne, + } + } +} + + #[derive(Debug, Clone)] struct StateTbl { vars: HashMap, @@ -49,21 +77,65 @@ impl StateTbl { /* // Placeholder function signatures (implementation required) fn check_expr(expr: &str, state_tbl: &mut StateTbl) { - // Implementation needed based on OCaml logic + // Implementation needed } fn count(vars: &[String], state_tbl: &StateTbl) -> usize { - // Implementation needed based on OCaml logic + // Implementation needed 0 } */ + +#[allow(dead_code)] +#[allow(unused_variables)] +fn count(name: &str, expr: &str) -> Appearances { + // TODO implement + Appearances { consumed: 0, write: 0, read: 0, path: 0 } +} + + +#[allow(dead_code)] +#[allow(unreachable_patterns)] +fn check_var_in_expr(state_tbl: &mut StateTbl, depth: u32, name: &str, expr: &str) -> Result { + let apps = count(name, expr); // Assume count function implementation + let Appearances { consumed, write, read, path } = apps; + + let state = state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state + + match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), + (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err("Error: Borrowed mutably and used".to_string()), + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err("Error: Borrowed mutably more than once".to_string()), + (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => consume_once(state_tbl, depth, name), + (VarState::Unconsumed, CountResult::One, _, _, _) => Err("Error: Consumed and something else".to_string()), + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err("Error: Consumed more than once".to_string()), + (VarState::Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(state_tbl.clone()), + (VarState::Borrowed, _, _, _, _) => Err("Error: Read borrowed and something else".to_string()), + (VarState::BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + (VarState::BorrowedMut, _, _, _, _) => Err("Error: Write borrowed and used".to_string()), + (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + (VarState::Consumed, _, _, _, _) => Err("Error: Already consumed".to_string()), + _ => Err("Unhandled state or appearance count".to_string()), + } +} + + +#[allow(unused_variables)] +fn consume_once(state_tbl: &mut StateTbl, depth: u32, name: &str) -> Result { + // TODO Implement the logic to consume a variable once, updating the state table and handling depth + state_tbl.update_state(name.to_string(), VarState::Consumed); + Ok(state_tbl.clone()) +} + + // Do nothing implementation of linearity check #[allow(unused_variables)] pub fn linearity_check_program(program_ir: &ProgramBody, session: &Session) -> Result { let mut linearity_table = StateTbl::new(); - linearity_table.update_state("x".to_string(), VarState::Available); + linearity_table.update_state("x".to_string(), VarState::Unconsumed); linearity_table.update_state("y".to_string(), VarState::Consumed); linearity_table.update_state("z".to_string(), VarState::Borrowed); linearity_table.update_state("w".to_string(), VarState::BorrowedMut); From 53a79f7ead48296fd46cf3ac6d8b2459b90b8246 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 13 May 2024 11:30:33 -0300 Subject: [PATCH 09/31] LinearityErrors coded --- crates/concrete_check/src/linearity_check.rs | 48 +++++++++-------- .../src/linearity_check/errors.rs | 52 +++++++++++++++---- 2 files changed, 70 insertions(+), 30 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 40a7307..fc96ec2 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -59,8 +59,8 @@ impl StateTbl { } // Example of updating the state table - fn update_state(&mut self, var: String, state: VarState) { - self.vars.insert(var, state); + fn update_state(&mut self, var: & str, state: VarState) { + self.vars.insert(var.to_string(), state); } // Remove a variable from the state table @@ -96,37 +96,43 @@ fn count(name: &str, expr: &str) -> Appearances { } + #[allow(dead_code)] #[allow(unreachable_patterns)] -fn check_var_in_expr(state_tbl: &mut StateTbl, depth: u32, name: &str, expr: &str) -> Result { +//fn check_var_in_expr(state_tbl: &mut StateTbl, depth: u32, name: &str, expr: &str) -> Result { +fn check_var_in_expr(state_tbl: &mut StateTbl, _depth: u32, name: &str, expr: &str) -> Result { let apps = count(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; let state = state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { - (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), + /*( State Consumed WBorrow RBorrow Path ) + (* ------------------|-------------------|-----------------|------------------|----------------)*/ + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), - (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err("Error: Borrowed mutably and used".to_string()), - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err("Error: Borrowed mutably more than once".to_string()), - (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => consume_once(state_tbl, depth, name), - (VarState::Unconsumed, CountResult::One, _, _, _) => Err("Error: Consumed and something else".to_string()), - (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err("Error: Consumed more than once".to_string()), - (VarState::Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(state_tbl.clone()), - (VarState::Borrowed, _, _, _, _) => Err("Error: Read borrowed and something else".to_string()), - (VarState::BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), - (VarState::BorrowedMut, _, _, _, _) => Err("Error: Write borrowed and used".to_string()), - (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), - (VarState::Consumed, _, _, _, _) => Err("Error: Already consumed".to_string()), - _ => Err("Unhandled state or appearance count".to_string()), + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => + Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::One, _, _, _) => + Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => + Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + (VarState::Borrowed, _, _, _, _) => + Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + (VarState::BorrowedMut, _, _, _, _) => + Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + (VarState::Consumed, _, _, _, _) => + Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + _ => Err(LinearityError::UnhandledStateOrCount { variable: name.to_string() }), } } +#[allow(dead_code)] #[allow(unused_variables)] fn consume_once(state_tbl: &mut StateTbl, depth: u32, name: &str) -> Result { // TODO Implement the logic to consume a variable once, updating the state table and handling depth - state_tbl.update_state(name.to_string(), VarState::Consumed); + state_tbl.update_state(name, VarState::Consumed); Ok(state_tbl.clone()) } @@ -135,10 +141,10 @@ fn consume_once(state_tbl: &mut StateTbl, depth: u32, name: &str) -> Result Result { let mut linearity_table = StateTbl::new(); - linearity_table.update_state("x".to_string(), VarState::Unconsumed); - linearity_table.update_state("y".to_string(), VarState::Consumed); - linearity_table.update_state("z".to_string(), VarState::Borrowed); - linearity_table.update_state("w".to_string(), VarState::BorrowedMut); + linearity_table.update_state("x", VarState::Unconsumed); + linearity_table.update_state("y", VarState::Consumed); + linearity_table.update_state("z", VarState::Borrowed); + linearity_table.update_state("w", VarState::BorrowedMut); linearity_table.remove_entry("x"); let state = linearity_table.get_state("y"); diff --git a/crates/concrete_check/src/linearity_check/errors.rs b/crates/concrete_check/src/linearity_check/errors.rs index 317d126..e2de2d7 100644 --- a/crates/concrete_check/src/linearity_check/errors.rs +++ b/crates/concrete_check/src/linearity_check/errors.rs @@ -1,14 +1,48 @@ -use concrete_ir::Span; +//use concrete_ir::Span; use thiserror::Error; #[derive(Debug, Error, Clone)] pub enum LinearityError { - #[error("Variable {variable} not consumed at module {module:?}")] - LinearNotConsumed { - span: Span, - module: String, - program_id: usize, - variable: String - }, -} \ No newline at end of file + #[error("Variable {variable} not consumed")] + NotConsumed { + variable: String, + }, + #[error("Borrowed mutably and used for Variable {variable}")] + BorrowedMutUsed { + variable: String, + }, + #[error("Variable {variable} borrowed mutably more than once")] + BorrowedMutMoreThanOnce { + variable: String, + }, + #[error("Variable {variable} consumed once and then used again")] + ConsumedAndUsed { + variable: String, + }, + #[error("Variable {variable} consumed more than once")] + ConsumedMoreThanOnce { + variable: String, + }, + #[error("Variable {variable} read borrowed and used in other ways")] + ReadBorrowedAndUsed { + variable: String, + }, + #[error("Variable {variable} write borrowed and used")] + WriteBorrowedAndUsed { + variable: String, + }, + #[error("Variable {variable} already consumed and used again")] + AlreadyConsumedAndUsed { + variable: String, + }, + #[error("Unhandled state or appearance count for Variable {variable}")] + UnhandledStateOrCount { + variable: String, + }, + #[error("Linearity error. Variable {variable} generated {message}")] + Unspecified { + variable: String, + message: String, + }, +} From f0596a977b40eb23fbe7f9b5066d6198b9ec1322 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 13 May 2024 11:48:30 -0300 Subject: [PATCH 10/31] Appareances and count function --- crates/concrete_check/src/linearity_check.rs | 180 +++++++++++++++++-- 1 file changed, 169 insertions(+), 11 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index fc96ec2..c67e23f 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -9,13 +9,6 @@ pub mod errors; use concrete_ir::ProgramBody; -#[derive(Debug, Clone, Copy)] -struct Appearances { - consumed: u32, - write: u32, - read: u32, - path: u32, -} #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -33,8 +26,57 @@ enum CountResult { MoreThanOne, } +#[derive(Debug, Clone, Copy)] +struct Appearances { + consumed: u32, + write: u32, + read: u32, + path: u32, +} + +#[allow(dead_code)] +enum Expr { + NilConstant, + BoolConstant(bool), + IntConstant(i32), + FloatConstant(f64), + StringConstant(String), + ConstVar, + ParamVar(String), + LocalVar(String), + FunVar, + Funcall(Box, Vec), + MethodCall(Box, Vec), + VarMethodCall(Vec), + FptrCall(Box, Vec), + Cast(Box, String), + Comparison(Box, Box), + Conjunction(Box, Box), + Disjunction(Box, Box), + Negation(Box), + IfExpression(Box, Box, Box), + RecordConstructor(Vec), + UnionConstructor(Vec), + Path { head: Box, elems: Vec }, + Embed(Vec), + Deref(Box), + SizeOf, + BorrowExpr(BorrowMode, String), + ArrayIndex(Box), +} + +#[allow(dead_code)] +enum BorrowMode { + ReadBorrow, + WriteBorrow, +} + #[allow(dead_code)] impl Appearances { + fn new(consumed: u32, write: u32, read: u32, path: u32) -> Self { + Appearances { consumed, write, read, path } + } + fn partition(count: u32) -> CountResult { match count { 0 => CountResult::Zero, @@ -42,6 +84,39 @@ impl Appearances { _ => CountResult::MoreThanOne, } } + + fn zero() -> Self { + Self::new(0, 0, 0, 0) + } + + fn consumed_once() -> Self { + Self::new(1, 0, 0, 0) + } + + fn read_once() -> Self { + Self::new(0, 0, 1, 0) + } + + fn write_once() -> Self { + Self::new(0, 1, 0, 0) + } + + fn path_once() -> Self { + Self::new(0, 0, 0, 1) + } + + fn merge(&self, other: &Appearances) -> Self { + Appearances { + consumed: self.consumed + other.consumed, + write: self.write + other.write, + read: self.read + other.read, + path: self.path + other.path, + } + } + + fn merge_list(appearances: Vec) -> Self { + appearances.into_iter().fold(Self::zero(), |acc, x| acc.merge(&x)) + } } @@ -88,19 +163,102 @@ fn count(vars: &[String], state_tbl: &StateTbl) -> usize { */ + #[allow(dead_code)] #[allow(unused_variables)] -fn count(name: &str, expr: &str) -> Appearances { - // TODO implement - Appearances { consumed: 0, write: 0, read: 0, path: 0 } +/* +fn count(name: &str, expr: &Expr) -> Appearances { + match expr { + Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => + Appearances::zero(), + Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => + if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, + Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::VarMethodCall(args) | Expr::FptrCall(func, args) | Expr::Embed(args) => + args.iter().map(|arg| count(name, arg)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)), + Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => + count(name, e), + Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => + count(name, lhs).merge(&count(name, rhs)), + Expr::IfExpression(cond, then_expr, else_expr) => + count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), + Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => + args.iter().map(|arg| count(name, arg)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)), + Expr::Path { head, elems } => { + let head_apps = count(name, head); + let elems_apps = elems.iter().map(|elem| count(name, elem)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)); + head_apps.merge(&elems_apps) + }, + Expr::BorrowExpr(mode, var_name) => + if var_name == name { + match mode { + BorrowMode::ReadBorrow => Appearances::read_once(), + BorrowMode::WriteBorrow => Appearances::write_once(), + } + } else { + Appearances::zero() + }, + Expr::ArrayIndex(e) => + count(name, e), + } } +*/ +fn count(name: &str, expr: &Expr) -> Appearances { + match expr { + Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => + Appearances::zero(), + + Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => + if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, + + Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::FptrCall(func, args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::VarMethodCall(args) | Expr::Embed(args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => + count(name, e), + + Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => + count(name, lhs).merge(&count(name, rhs)), + + Expr::IfExpression(cond, then_expr, else_expr) => + count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), + + Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::Path { head, elems } => { + let head_apps = count(name, head); + let elems_apps = elems.iter().map(|elem| count(name, elem)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); + head_apps.merge(&elems_apps) + }, + + Expr::BorrowExpr(mode, var_name) => + if var_name == name { + match mode { + BorrowMode::ReadBorrow => Appearances::read_once(), + BorrowMode::WriteBorrow => Appearances::write_once(), + } + } else { + Appearances::zero() + }, + + Expr::ArrayIndex(e) => + count(name, e), + } +} + + + + #[allow(dead_code)] #[allow(unreachable_patterns)] //fn check_var_in_expr(state_tbl: &mut StateTbl, depth: u32, name: &str, expr: &str) -> Result { -fn check_var_in_expr(state_tbl: &mut StateTbl, _depth: u32, name: &str, expr: &str) -> Result { +fn check_var_in_expr(state_tbl: &mut StateTbl, _depth: u32, name: &str, expr: &Expr) -> Result { let apps = count(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; From d26836a383d5d5775764efed9880524c634c300b Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 13 May 2024 13:33:48 -0300 Subject: [PATCH 11/31] struct LinearityChecker --- crates/concrete_check/Cargo.toml | 4 + crates/concrete_check/src/linearity_check.rs | 166 ++++++++++--------- 2 files changed, 90 insertions(+), 80 deletions(-) diff --git a/crates/concrete_check/Cargo.toml b/crates/concrete_check/Cargo.toml index 8affcac..bcc5651 100644 --- a/crates/concrete_check/Cargo.toml +++ b/crates/concrete_check/Cargo.toml @@ -12,3 +12,7 @@ concrete_ir = { version = "0.1.0", path = "../concrete_ir" } concrete_session = { version = "0.1.0", path = "../concrete_session" } itertools = "0.12.0" thiserror = "1.0.56" + + +[features] +linearity = [] \ No newline at end of file diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index c67e23f..cb96c02 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -7,9 +7,8 @@ use self::errors::LinearityError; pub mod errors; -use concrete_ir::ProgramBody; - - +use concrete_ir::{ProgramBody, FnBody, Statement}; +use concrete_ast::expressions::Expression; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { @@ -149,59 +148,11 @@ impl StateTbl { } } -/* -// Placeholder function signatures (implementation required) -fn check_expr(expr: &str, state_tbl: &mut StateTbl) { - // Implementation needed -} - -fn count(vars: &[String], state_tbl: &StateTbl) -> usize { - // Implementation needed - 0 -} - -*/ #[allow(dead_code)] #[allow(unused_variables)] -/* -fn count(name: &str, expr: &Expr) -> Appearances { - match expr { - Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => - Appearances::zero(), - Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => - if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, - Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::VarMethodCall(args) | Expr::FptrCall(func, args) | Expr::Embed(args) => - args.iter().map(|arg| count(name, arg)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)), - Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => - count(name, e), - Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => - count(name, lhs).merge(&count(name, rhs)), - Expr::IfExpression(cond, then_expr, else_expr) => - count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), - Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => - args.iter().map(|arg| count(name, arg)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)), - Expr::Path { head, elems } => { - let head_apps = count(name, head); - let elems_apps = elems.iter().map(|elem| count(name, elem)).collect::>().into_iter().fold(Appearances::zero(), |acc, x| acc.merge(&x)); - head_apps.merge(&elems_apps) - }, - Expr::BorrowExpr(mode, var_name) => - if var_name == name { - match mode { - BorrowMode::ReadBorrow => Appearances::read_once(), - BorrowMode::WriteBorrow => Appearances::write_once(), - } - } else { - Appearances::zero() - }, - Expr::ArrayIndex(e) => - count(name, e), - } -} -*/ fn count(name: &str, expr: &Expr) -> Appearances { match expr { Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => @@ -255,35 +206,6 @@ fn count(name: &str, expr: &Expr) -> Appearances { -#[allow(dead_code)] -#[allow(unreachable_patterns)] -//fn check_var_in_expr(state_tbl: &mut StateTbl, depth: u32, name: &str, expr: &str) -> Result { -fn check_var_in_expr(state_tbl: &mut StateTbl, _depth: u32, name: &str, expr: &Expr) -> Result { - let apps = count(name, expr); // Assume count function implementation - let Appearances { consumed, write, read, path } = apps; - - let state = state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state - - match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { - /*( State Consumed WBorrow RBorrow Path ) - (* ------------------|-------------------|-----------------|------------------|----------------)*/ - (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), - (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => - Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), - (VarState::Unconsumed, CountResult::One, _, _, _) => - Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), - (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => - Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), - (VarState::Borrowed, _, _, _, _) => - Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), - (VarState::BorrowedMut, _, _, _, _) => - Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), - (VarState::Consumed, _, _, _, _) => - Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), - _ => Err(LinearityError::UnhandledStateOrCount { variable: name.to_string() }), - } -} #[allow(dead_code)] @@ -295,7 +217,91 @@ fn consume_once(state_tbl: &mut StateTbl, depth: u32, name: &str) -> Result Self { + LinearityChecker { + state_tbl: StateTbl::new(), + } + } + + fn linearity_check(&mut self, program: &ProgramBody) -> Result<(), LinearityError> { + // Assume Program is a struct that represents the entire program. + for function in &program.functions { + self.check_function(&function.1)?; + } + Ok(()) + } + + fn check_function(&mut self, function: &FnBody) -> Result<(), LinearityError> { + // Logic to check linearity within a function + // This may involve iterating over statements and expressions, similar to OCaml's recursion. + for basic_block in &function.basic_blocks{ + for statement in &basic_block.statements { + self.check_statement(&statement)?; + } + + } + Ok(()) + } + + fn check_statement(&mut self, statement: &Statement) -> Result<(), LinearityError> { + // TODO here we have to decide a unique Expression enum like declared above (translated from OCAML) for treating code + // + /* + match statement { + Statement::Expression(expr) => self.check_expr(expr), + statement. + // Add more statement types as needed + }*/ + Ok(()) + } + + fn check_expr(&self, expr: &Expression) -> Result<(), LinearityError> { + // Expression checking logic here + Ok(()) + } + + fn check_var_in_expr(&mut self, _depth: u32, name: &str, expr: &Expr) -> Result<(), LinearityError> { + let apps = count(name, expr); // Assume count function implementation + let Appearances { consumed, write, read, path } = apps; + + let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state + + match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { + /*( State Consumed WBorrow RBorrow Path ) + (* ------------------|-------------------|-----------------|------------------|----------------)*/ + //(VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), + //(VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => + Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::One, _, _, _) => + Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => + Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + (VarState::Borrowed, _, _, _, _) => + Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + (VarState::BorrowedMut, _, _, _, _) => + Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + (VarState::Consumed, _, _, _, _) => + Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + _ => Err(LinearityError::UnhandledStateOrCount { variable: name.to_string() }), + } + } + +} + + + // Do nothing implementation of linearity check +//#[cfg(feature = "linearity")] #[allow(unused_variables)] pub fn linearity_check_program(program_ir: &ProgramBody, session: &Session) -> Result { let mut linearity_table = StateTbl::new(); From 8687bd06c368404e40fb8fee4e3ba3e7d891e134 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 13 May 2024 21:05:37 -0300 Subject: [PATCH 12/31] Huge refactoring in progress. Does not build yet --- crates/concrete_check/src/linearity_check.rs | 351 ++++++++++++++---- .../src/linearity_check/errors.rs | 4 + 2 files changed, 285 insertions(+), 70 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index cb96c02..d5b6010 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -7,8 +7,12 @@ use self::errors::LinearityError; pub mod errors; -use concrete_ir::{ProgramBody, FnBody, Statement}; + +//use concrete_ir::{ProgramBody, FnBody}; +//use concrete_ast::Program{ ProgramBody, FnBody }; +use concrete_ast::functions::FunctionDef; use concrete_ast::expressions::Expression; +use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { @@ -151,54 +155,6 @@ impl StateTbl { -#[allow(dead_code)] -#[allow(unused_variables)] -fn count(name: &str, expr: &Expr) -> Appearances { - match expr { - Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => - Appearances::zero(), - - Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => - if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, - - Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::FptrCall(func, args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::VarMethodCall(args) | Expr::Embed(args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => - count(name, e), - - Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => - count(name, lhs).merge(&count(name, rhs)), - - Expr::IfExpression(cond, then_expr, else_expr) => - count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), - - Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::Path { head, elems } => { - let head_apps = count(name, head); - let elems_apps = elems.iter().map(|elem| count(name, elem)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); - head_apps.merge(&elems_apps) - }, - - Expr::BorrowExpr(mode, var_name) => - if var_name == name { - match mode { - BorrowMode::ReadBorrow => Appearances::read_once(), - BorrowMode::WriteBorrow => Appearances::write_once(), - } - } else { - Appearances::zero() - }, - - Expr::ArrayIndex(e) => - count(name, e), - } -} @@ -230,45 +186,296 @@ impl LinearityChecker { } } - fn linearity_check(&mut self, program: &ProgramBody) -> Result<(), LinearityError> { + fn linearity_check(&mut self, program: &FunctionDef) -> Result<(), LinearityError> { // Assume Program is a struct that represents the entire program. - for function in &program.functions { - self.check_function(&function.1)?; + for statement in &program.body { + self.check_stmt(0, &statement); } Ok(()) } - + /* fn check_function(&mut self, function: &FnBody) -> Result<(), LinearityError> { // Logic to check linearity within a function // This may involve iterating over statements and expressions, similar to OCaml's recursion. for basic_block in &function.basic_blocks{ for statement in &basic_block.statements { - self.check_statement(&statement)?; + self.check_stmt(0, &statement)?; } } Ok(()) + }*/ + + fn check_expr(&mut self, depth: u32, expr: &Expression) -> Result<(), LinearityError> { + // Assuming you have a method to get all variable names and types + //let vars = &mut self.state_tbl.vars; + let vars = self.state_tbl.vars.clone(); + for (name, ty) in vars.iter() { + self.check_var_in_expr(depth, &name, &ty, expr)?; + } + Ok(()) } - fn check_statement(&mut self, statement: &Statement) -> Result<(), LinearityError> { - // TODO here we have to decide a unique Expression enum like declared above (translated from OCAML) for treating code - // - /* + #[allow(dead_code)] + #[allow(unused_variables)] + /* + fn count(name: &str, expr: &Expr) -> Appearances { + match expr { + Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => + Appearances::zero(), + + Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => + if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, + + Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::FptrCall(func, args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::VarMethodCall(args) | Expr::Embed(args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => + count(name, e), + + Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => + count(name, lhs).merge(&count(name, rhs)), + + Expr::IfExpression(cond, then_expr, else_expr) => + count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), + + Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => + args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), + + Expr::Path { head, elems } => { + let head_apps = count(name, head); + let elems_apps = elems.iter().map(|elem| count(name, elem)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); + head_apps.merge(&elems_apps) + }, + + Expr::BorrowExpr(mode, var_name) => + if var_name == name { + match mode { + BorrowMode::ReadBorrow => Appearances::read_once(), + BorrowMode::WriteBorrow => Appearances::write_once(), + } + } else { + Appearances::zero() + }, + + Expr::ArrayIndex(e) => + count(name, e), + } + } + */ + + fn countInStatements(&self, name: &str, statements: &Vec) -> Appearances { + statements.iter().map(|stmt| self.countInStatement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + } + + fn countInStatement(&self, name: &str, statement: &Statement) -> Appearances { match statement { - Statement::Expression(expr) => self.check_expr(expr), - statement. - // Add more statement types as needed - }*/ + Statement::Let(binding) => { + // Handle let bindings, possibly involving pattern matching + self.countInExpression(name, &binding.value) + }, + Statement::If(if_stmt) => { + // Process all components of an if expression + let cond_apps = self.countInExpression(name, &if_stmt.value); + //let then_apps = self.countInStatement(name, &if_stmt.contents); + let then_apps = self.countInStatements(name, &if_stmt.contents); + let else_apps; + let else_statements = &if_stmt.r#else; + if let Some(else_statements) = else_statements { + else_apps = self.countInStatements(name, &else_statements); + } else { + else_apps = Appearances::zero(); + } + cond_apps.merge(&then_apps).merge(&else_apps) + }, + Statement::While(while_expr) => { + let cond= &while_expr.value; + let block = &while_expr.contents; + // Handle while loops + self.countInExpression(name, cond).merge(&&self.countInStatements(name, block)) + }, + Statement::For(for_expr) => { + // Handle for loops + //init, cond, post, block + let init = &for_expr.init; + let cond = &for_expr.condition; + let post = &for_expr.post; + let block = &for_expr.contents; + let mut apps = Appearances::zero(); + if let Some(init) = init{ + if let Some(cond) = cond{ + if let Some(post) = post{ + apps = self.countInLetStatement(name, init).merge(&self.countInExpression(name, cond)).merge(&self.countInAssignStatement(name, post)).merge(&self.countInStatements(name, block)) + } + } + } + apps + }, + /* + Statement::Block(statements) => { + // Handle blocks of statements + //statements.iter().map(|stmt| self.count(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + self.countInStatements(name, statements) + },*/ + _ => Appearances::zero(), + } + } + + fn countInAssignStatement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { + match assign_stmt { + AssignStmt { target, derefs, value, span } => { + // Handle assignments + self.countInExpression(name, value) + }, + } + } + + fn countInLetStatement(&self, name: &str, let_stmt: &LetStmt) -> Appearances { + match let_stmt { + LetStmt { is_mutable, target, value, span } => { + // Handle let bindings, possibly involving pattern matching + self.countInExpression(name, value) + }, + } + } + + fn countInExpression(&self, name: &str, expr: &Expression) -> Appearances { + match expr { + Expression::Value(value_expr, _) => { + // Handle value expressions, typically constant or simple values + Appearances::zero() + }, + Expression::FnCall(fn_call_op) => { + // Process function call arguments + fn_call_op.args.iter().map(|arg| self.countInExpression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + }, + Expression::Match(match_expr) => todo!(), + /* + Expression::Match(match_expr) => { + // Handle match arms + match_expr.variants.iter().map(|(_, expr)| self.count(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + },*/ + Expression::If(if_expr) => { + // Process all components of an if expression + let cond_apps = self.countInExpression(name, &if_expr.value); + let then_apps = self.countInStatements(name, &if_expr.contents); + //let else_apps = if_expr.else.as_ref().map(|e| self.count(name, e)).unwrap_or_default(); + let else_apps = if_expr.r#else.map(|e| self.countInStatement(name, e)).unwrap_or_default(); + cond_apps.merge(&then_apps).merge(&else_apps) + }, + Expression::UnaryOp(_, expr) => { + // Unary operations likely don't change the count but process the inner expression + self.countInExpression(name, expr) + }, + Expression::BinaryOp(left, _, right) => { + // Handle binary operations by processing both sides + self.countInExpression(name, left).merge(&&self.countInExpression(name, right)) + }, + Expression::StructInit(_) => todo!(), + /* + Expression::StructInit(struct_init_expr) => { + // Handle struct initialization + struct_init_expr.fields.iter().map(|(_, expr)| self.count(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + },*/ + Expression::ArrayInit(array_init_expr) => { + // Handle array initializations + array_init_expr.values.iter().map(|expr| self.countInExpression(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + }, + Expression::Deref(expr, _) | Expression::AsRef(expr, _, _) | Expression::Cast(expr, _, _) => { + // Deref, AsRef, and Cast are handled by just checking the inner expression + self.countInExpression(name, expr) + }, + // Add more cases as necessary based on the Expression types you expect + } + } + + + fn check_bindings(&mut self, depth: u32, binding: &LetStmt) -> Result<(), LinearityError> { + // Handle let bindings, possibly involving pattern matching + let LetStmt { is_mutable, target, value, span } = binding; + match target { + LetStmtTarget::Simple { name, r#type } => { + self.check_var_in_expr(depth, name, &VarState::Unconsumed, value) + }, + LetStmtTarget::Destructure(bindings) => { + for binding in bindings { + self.check_bindings(depth, binding)?; + } + Ok(()) + }, + } + } + + fn check_stmts(&mut self, depth: u32, stmts: &Vec) -> Result<(), LinearityError> { + for stmt in stmts { + self.check_stmt(depth, stmt)?; + } Ok(()) } - fn check_expr(&self, expr: &Expression) -> Result<(), LinearityError> { + fn check_stmt(&mut self, depth: u32, stmt: &Statement) -> Result<(), LinearityError> { + match stmt { + /* + Statement::Expression(expr) => { + // Handle expressions (e.g., variable assignments, function calls) + self.check_expr(depth, expr) + },*/ + Statement::Let(binding) => { + // Handle let bindings, possibly involving pattern matching + self.check_bindings(depth, binding) + }, + //Statement::If(cond, then_block, else_block) => { + Statement::If(if_stmt) => { + // Handle conditional statements + self.check_expr(depth, &if_stmt.value)?; + self.check_stmts(depth + 1, &if_stmt.contents)?; + if let Some(else_block) = if_stmt.else_block { + self.check_stmt(depth + 1, else_block)?; + } + Ok(()) + }, + //Statement::While(cond, block) => { + Statement::While(while_stmt) => { + // Handle while loops + self.check_expr(depth, &while_stmt.value)?; + self.check_stmts(depth + 1, &while_stmt.contents) + }, + //Statement::For(init, cond, post, block) => { + Statement::For(for_stmt) => { + // Handle for loops + if let Some(init) = for_stmt.init { + self.check_stmt(depth, init)?; + } + self.check_stmt(depth, &for_stmt.init)?; + self.check_expr(depth, &for_stmt.condition)?; + if let Some(post) = &for_stmt.post { + self.check_stmt(depth, post)?; + } + self.check_stmt(depth + 1, &for_stmt.block) + }, + Statement::Block(statements) => { + // Handle blocks of statements + for statement in statements { + self.check_stmt(depth + 1, statement)?; + } + Ok(()) + }, + _ => Err(LinearityError::UnhandledStatementType { r#type: format!("{:?}", stmt) }), + } + } + + /* + fn check_expr(&self, depth: usize, expr: &Expression) -> Result<(), LinearityError> { // Expression checking logic here Ok(()) - } + }*/ - fn check_var_in_expr(&mut self, _depth: u32, name: &str, expr: &Expr) -> Result<(), LinearityError> { - let apps = count(name, expr); // Assume count function implementation + //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expr) -> Result<(), LinearityError> { + fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expression) -> Result<(), LinearityError> { + let apps = self.countInExpression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state @@ -278,11 +485,11 @@ impl LinearityChecker { (* ------------------|-------------------|-----------------|------------------|----------------)*/ //(VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), //(VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), - (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), - (VarState::Unconsumed, CountResult::One, _, _, _) => + (VarState::Unconsumed, CountResult::One, _, _, _) => Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), @@ -303,7 +510,8 @@ impl LinearityChecker { // Do nothing implementation of linearity check //#[cfg(feature = "linearity")] #[allow(unused_variables)] -pub fn linearity_check_program(program_ir: &ProgramBody, session: &Session) -> Result { +pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> Result { + /* let mut linearity_table = StateTbl::new(); linearity_table.update_state("x", VarState::Unconsumed); linearity_table.update_state("y", VarState::Consumed); @@ -312,6 +520,9 @@ pub fn linearity_check_program(program_ir: &ProgramBody, session: &Session) -> linearity_table.remove_entry("x"); let state = linearity_table.get_state("y"); + */ + let mut checker = LinearityChecker::new(); + checker.linearity_check(program_ir)?; Ok("OK".to_string()) } diff --git a/crates/concrete_check/src/linearity_check/errors.rs b/crates/concrete_check/src/linearity_check/errors.rs index e2de2d7..4d261c1 100644 --- a/crates/concrete_check/src/linearity_check/errors.rs +++ b/crates/concrete_check/src/linearity_check/errors.rs @@ -45,4 +45,8 @@ pub enum LinearityError { variable: String, message: String, }, + #[error("Unhandled statement type {r#type}")] + UnhandledStatementType{ + r#type: String, + }, } From 9c7f450c92db8e78745b33302c254875e79ebb01 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Wed, 15 May 2024 17:03:59 -0300 Subject: [PATCH 13/31] linearity_check building with ast --- crates/concrete_check/src/linearity_check.rs | 129 ++++++++++++------- crates/concrete_driver/src/lib.rs | 4 +- 2 files changed, 87 insertions(+), 46 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index d5b6010..2696b37 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -7,12 +7,14 @@ use self::errors::LinearityError; pub mod errors; - +use std::path::PathBuf; //use concrete_ir::{ProgramBody, FnBody}; //use concrete_ast::Program{ ProgramBody, FnBody }; -use concrete_ast::functions::FunctionDef; +use concrete_ast::Program; +use concrete_ast::modules::ModuleDefItem; +//use concrete_ast::functions::FunctionDef; use concrete_ast::expressions::Expression; -use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget}; +use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { @@ -186,6 +188,8 @@ impl LinearityChecker { } } + //TODO remove + /* fn linearity_check(&mut self, program: &FunctionDef) -> Result<(), LinearityError> { // Assume Program is a struct that represents the entire program. for statement in &program.body { @@ -193,6 +197,7 @@ impl LinearityChecker { } Ok(()) } + */ /* fn check_function(&mut self, function: &FnBody) -> Result<(), LinearityError> { // Logic to check linearity within a function @@ -267,25 +272,25 @@ impl LinearityChecker { } */ - fn countInStatements(&self, name: &str, statements: &Vec) -> Appearances { - statements.iter().map(|stmt| self.countInStatement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + fn count_in_statements(&self, name: &str, statements: &Vec) -> Appearances { + statements.iter().map(|stmt| self.count_in_statement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) } - fn countInStatement(&self, name: &str, statement: &Statement) -> Appearances { + fn count_in_statement(&self, name: &str, statement: &Statement) -> Appearances { match statement { Statement::Let(binding) => { // Handle let bindings, possibly involving pattern matching - self.countInExpression(name, &binding.value) + self.count_in_expression(name, &binding.value) }, Statement::If(if_stmt) => { // Process all components of an if expression - let cond_apps = self.countInExpression(name, &if_stmt.value); - //let then_apps = self.countInStatement(name, &if_stmt.contents); - let then_apps = self.countInStatements(name, &if_stmt.contents); + let cond_apps = self.count_in_expression(name, &if_stmt.value); + //let then_apps = self.count_in_statement(name, &if_stmt.contents); + let then_apps = self.count_in_statements(name, &if_stmt.contents); let else_apps; let else_statements = &if_stmt.r#else; if let Some(else_statements) = else_statements { - else_apps = self.countInStatements(name, &else_statements); + else_apps = self.count_in_statements(name, &else_statements); } else { else_apps = Appearances::zero(); } @@ -295,7 +300,7 @@ impl LinearityChecker { let cond= &while_expr.value; let block = &while_expr.contents; // Handle while loops - self.countInExpression(name, cond).merge(&&self.countInStatements(name, block)) + self.count_in_expression(name, cond).merge(&&self.count_in_statements(name, block)) }, Statement::For(for_expr) => { // Handle for loops @@ -308,7 +313,7 @@ impl LinearityChecker { if let Some(init) = init{ if let Some(cond) = cond{ if let Some(post) = post{ - apps = self.countInLetStatement(name, init).merge(&self.countInExpression(name, cond)).merge(&self.countInAssignStatement(name, post)).merge(&self.countInStatements(name, block)) + apps = self.count_in_let_statements(name, init).merge(&self.count_in_expression(name, cond)).merge(&self.count_in_assign_statement(name, post)).merge(&self.count_in_statements(name, block)) } } } @@ -318,31 +323,31 @@ impl LinearityChecker { Statement::Block(statements) => { // Handle blocks of statements //statements.iter().map(|stmt| self.count(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - self.countInStatements(name, statements) + self.count_in_statements(name, statements) },*/ _ => Appearances::zero(), } } - fn countInAssignStatement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { + fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { match assign_stmt { AssignStmt { target, derefs, value, span } => { // Handle assignments - self.countInExpression(name, value) + self.count_in_expression(name, value) }, } } - fn countInLetStatement(&self, name: &str, let_stmt: &LetStmt) -> Appearances { + fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances { match let_stmt { LetStmt { is_mutable, target, value, span } => { // Handle let bindings, possibly involving pattern matching - self.countInExpression(name, value) + self.count_in_expression(name, value) }, } } - fn countInExpression(&self, name: &str, expr: &Expression) -> Appearances { + fn count_in_expression(&self, name: &str, expr: &Expression) -> Appearances { match expr { Expression::Value(value_expr, _) => { // Handle value expressions, typically constant or simple values @@ -350,7 +355,7 @@ impl LinearityChecker { }, Expression::FnCall(fn_call_op) => { // Process function call arguments - fn_call_op.args.iter().map(|arg| self.countInExpression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + fn_call_op.args.iter().map(|arg| self.count_in_expression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) }, Expression::Match(match_expr) => todo!(), /* @@ -360,19 +365,23 @@ impl LinearityChecker { },*/ Expression::If(if_expr) => { // Process all components of an if expression - let cond_apps = self.countInExpression(name, &if_expr.value); - let then_apps = self.countInStatements(name, &if_expr.contents); + let cond_apps = self.count_in_expression(name, &if_expr.value); + let then_apps = self.count_in_statements(name, &if_expr.contents); //let else_apps = if_expr.else.as_ref().map(|e| self.count(name, e)).unwrap_or_default(); - let else_apps = if_expr.r#else.map(|e| self.countInStatement(name, e)).unwrap_or_default(); - cond_apps.merge(&then_apps).merge(&else_apps) + cond_apps.merge(&then_apps); + if let Some(else_block) = &if_expr.r#else { + let else_apps = self.count_in_statements(name, else_block); + cond_apps.merge(&then_apps).merge(&else_apps); + } + cond_apps }, Expression::UnaryOp(_, expr) => { // Unary operations likely don't change the count but process the inner expression - self.countInExpression(name, expr) + self.count_in_expression(name, expr) }, Expression::BinaryOp(left, _, right) => { // Handle binary operations by processing both sides - self.countInExpression(name, left).merge(&&self.countInExpression(name, right)) + self.count_in_expression(name, left).merge(&&self.count_in_expression(name, right)) }, Expression::StructInit(_) => todo!(), /* @@ -382,23 +391,23 @@ impl LinearityChecker { },*/ Expression::ArrayInit(array_init_expr) => { // Handle array initializations - array_init_expr.values.iter().map(|expr| self.countInExpression(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + array_init_expr.values.iter().map(|expr| self.count_in_expression(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) }, Expression::Deref(expr, _) | Expression::AsRef(expr, _, _) | Expression::Cast(expr, _, _) => { // Deref, AsRef, and Cast are handled by just checking the inner expression - self.countInExpression(name, expr) + self.count_in_expression(name, expr) }, // Add more cases as necessary based on the Expression types you expect } } - fn check_bindings(&mut self, depth: u32, binding: &LetStmt) -> Result<(), LinearityError> { + fn check_stmt_let(&mut self, depth: u32, binding: &LetStmt) -> Result<(), LinearityError> { // Handle let bindings, possibly involving pattern matching let LetStmt { is_mutable, target, value, span } = binding; match target { LetStmtTarget::Simple { name, r#type } => { - self.check_var_in_expr(depth, name, &VarState::Unconsumed, value) + self.check_var_in_expr(depth, &name.name, &VarState::Unconsumed, value) }, LetStmtTarget::Destructure(bindings) => { for binding in bindings { @@ -409,6 +418,11 @@ impl LinearityChecker { } } + fn check_bindings(&mut self, depth: u32, binding: &Binding) -> Result<(), LinearityError> { + // Do something with the bindings + Ok(()) + } + fn check_stmts(&mut self, depth: u32, stmts: &Vec) -> Result<(), LinearityError> { for stmt in stmts { self.check_stmt(depth, stmt)?; @@ -425,15 +439,15 @@ impl LinearityChecker { },*/ Statement::Let(binding) => { // Handle let bindings, possibly involving pattern matching - self.check_bindings(depth, binding) + self.check_stmt_let(depth, binding) }, //Statement::If(cond, then_block, else_block) => { Statement::If(if_stmt) => { // Handle conditional statements self.check_expr(depth, &if_stmt.value)?; self.check_stmts(depth + 1, &if_stmt.contents)?; - if let Some(else_block) = if_stmt.else_block { - self.check_stmt(depth + 1, else_block)?; + if let Some(else_block) = &if_stmt.r#else { + self.check_stmts(depth + 1, &else_block)?; } Ok(()) }, @@ -444,25 +458,28 @@ impl LinearityChecker { self.check_stmts(depth + 1, &while_stmt.contents) }, //Statement::For(init, cond, post, block) => { - Statement::For(for_stmt) => { + Statement::For(for_stmt) => { // Handle for loops - if let Some(init) = for_stmt.init { - self.check_stmt(depth, init)?; + if let Some(init) = &for_stmt.init { + self.check_stmt_let(depth, &init)?; + } + if let Some(condition) = &for_stmt.condition { + self.check_expr(depth, &condition)?; } - self.check_stmt(depth, &for_stmt.init)?; - self.check_expr(depth, &for_stmt.condition)?; if let Some(post) = &for_stmt.post { - self.check_stmt(depth, post)?; + //TODO check assign statement + //self.check_stmt_assign(depth, post)?; } - self.check_stmt(depth + 1, &for_stmt.block) + self.check_stmts(depth + 1, &for_stmt.contents) }, + /* Statement::Block(statements) => { // Handle blocks of statements for statement in statements { self.check_stmt(depth + 1, statement)?; } Ok(()) - }, + },*/ _ => Err(LinearityError::UnhandledStatementType { r#type: format!("{:?}", stmt) }), } } @@ -475,7 +492,7 @@ impl LinearityChecker { //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expr) -> Result<(), LinearityError> { fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expression) -> Result<(), LinearityError> { - let apps = self.countInExpression(name, expr); // Assume count function implementation + let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state @@ -507,10 +524,10 @@ impl LinearityChecker { -// Do nothing implementation of linearity check //#[cfg(feature = "linearity")] #[allow(unused_variables)] -pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> Result { +//pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> Result { +pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, session: &Session) -> Result { /* let mut linearity_table = StateTbl::new(); linearity_table.update_state("x", VarState::Unconsumed); @@ -522,7 +539,29 @@ pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> let state = linearity_table.get_state("y"); */ let mut checker = LinearityChecker::new(); - checker.linearity_check(program_ir)?; + for (path, name, program) in programs { + println!("Checking linearity for program: {}", name); + for module in &program.modules { + println!("Checking linearity for module: {}", module.name.name); + for module_content in &module.contents { + match module_content { + ModuleDefItem::Function(function) => { + //checker.check_function(&function)?; + for statement in &function.body { + //checker.check_function(&function)?; + checker.check_stmt(0, &statement)?; + } + //checker.linearity_check(&function)?; + } + _ => + { + println!("Skipping linear check for module content: {:?}", module_content); + () + }, + } + } + } + } Ok("OK".to_string()) } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index d185da6..089da5e 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -593,7 +593,9 @@ pub fn compile(args: &CompilerArgs) -> Result { }; #[allow(unused_variables)] - let linearity_result = match concrete_check::linearity_check::linearity_check_program(&program_ir, &session) { + //When tried to use ir representation for linearity check + //let linearity_result = match concrete_check::linearity_check::linearity_check_program(&program_ir, &session) { + let linearity_result = match concrete_check::linearity_check::linearity_check_program(&programs, &session) { Ok(ir) => ir, Err(error) => { println!("TODO error message when linearity fails"); From 0953bee47b30864b48c13723a975e2ae8226f5b5 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Wed, 15 May 2024 17:58:55 -0300 Subject: [PATCH 14/31] Now it runs the linear checker --- crates/concrete_check/src/linearity_check.rs | 70 ++++---------------- crates/concrete_driver/src/lib.rs | 2 +- 2 files changed, 14 insertions(+), 58 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 2696b37..ddf7ae1 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -13,8 +13,9 @@ use std::path::PathBuf; use concrete_ast::Program; use concrete_ast::modules::ModuleDefItem; //use concrete_ast::functions::FunctionDef; -use concrete_ast::expressions::Expression; -use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; +use concrete_ast::expressions::{Expression, StructInitField}; +//use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; +use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Binding}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { @@ -143,10 +144,12 @@ impl StateTbl { self.vars.insert(var.to_string(), state); } + /* // Remove a variable from the state table fn remove_entry(&mut self, var: &str) { self.vars.remove(var); } + */ // Retrieve a variable's state fn get_state(&self, var: &str) -> Option<&VarState> { @@ -221,56 +224,6 @@ impl LinearityChecker { Ok(()) } - #[allow(dead_code)] - #[allow(unused_variables)] - /* - fn count(name: &str, expr: &Expr) -> Appearances { - match expr { - Expr::NilConstant | Expr::BoolConstant(_) | Expr::IntConstant(_) | Expr::FloatConstant(_) | Expr::StringConstant(_) | Expr::ConstVar | Expr::FunVar | Expr::SizeOf => - Appearances::zero(), - - Expr::ParamVar(var_name) | Expr::LocalVar(var_name) => - if var_name == name { Appearances::consumed_once() } else { Appearances::zero() }, - - Expr::Funcall(func, args) | Expr::MethodCall(func, args) | Expr::FptrCall(func, args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::VarMethodCall(args) | Expr::Embed(args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::Cast(e, _) | Expr::Negation(e) | Expr::Deref(e) => - count(name, e), - - Expr::Comparison(lhs, rhs) | Expr::Conjunction(lhs, rhs) | Expr::Disjunction(lhs, rhs) => - count(name, lhs).merge(&count(name, rhs)), - - Expr::IfExpression(cond, then_expr, else_expr) => - count(name, cond).merge(&count(name, then_expr)).merge(&count(name, else_expr)), - - Expr::RecordConstructor(args) | Expr::UnionConstructor(args) => - args.iter().map(|arg| count(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)), - - Expr::Path { head, elems } => { - let head_apps = count(name, head); - let elems_apps = elems.iter().map(|elem| count(name, elem)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); - head_apps.merge(&elems_apps) - }, - - Expr::BorrowExpr(mode, var_name) => - if var_name == name { - match mode { - BorrowMode::ReadBorrow => Appearances::read_once(), - BorrowMode::WriteBorrow => Appearances::write_once(), - } - } else { - Appearances::zero() - }, - - Expr::ArrayIndex(e) => - count(name, e), - } - } - */ fn count_in_statements(&self, name: &str, statements: &Vec) -> Appearances { statements.iter().map(|stmt| self.count_in_statement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) @@ -300,7 +253,7 @@ impl LinearityChecker { let cond= &while_expr.value; let block = &while_expr.contents; // Handle while loops - self.count_in_expression(name, cond).merge(&&self.count_in_statements(name, block)) + self.count_in_expression(name, cond).merge(&self.count_in_statements(name, block)) }, Statement::For(for_expr) => { // Handle for loops @@ -383,12 +336,11 @@ impl LinearityChecker { // Handle binary operations by processing both sides self.count_in_expression(name, left).merge(&&self.count_in_expression(name, right)) }, - Expression::StructInit(_) => todo!(), - /* + //Expression::StructInit(_) => todo!(), Expression::StructInit(struct_init_expr) => { // Handle struct initialization - struct_init_expr.fields.iter().map(|(_, expr)| self.count(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - },*/ + struct_init_expr.fields.iter().map(|(_, expr)| self.count_struct_init(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + }, Expression::ArrayInit(array_init_expr) => { // Handle array initializations array_init_expr.values.iter().map(|expr| self.count_in_expression(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) @@ -402,6 +354,10 @@ impl LinearityChecker { } + fn count_struct_init(&self, name: &str, struct_init: &StructInitField) -> Appearances { + self.count_in_expression(name, &struct_init.value) + } + fn check_stmt_let(&mut self, depth: u32, binding: &LetStmt) -> Result<(), LinearityError> { // Handle let bindings, possibly involving pattern matching let LetStmt { is_mutable, target, value, span } = binding; diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 089da5e..22117c6 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -598,7 +598,7 @@ pub fn compile(args: &CompilerArgs) -> Result { let linearity_result = match concrete_check::linearity_check::linearity_check_program(&programs, &session) { Ok(ir) => ir, Err(error) => { - println!("TODO error message when linearity fails"); + println!("Linearity check failed: {:#?}", error); std::process::exit(1); } }; From bc9c8da6b283437c58d7d2d63017e346f0cc214c Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Wed, 15 May 2024 18:39:51 -0300 Subject: [PATCH 15/31] Shows the StateTbl for each function call --- crates/concrete_check/src/linearity_check.rs | 30 +++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index ddf7ae1..065867a 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -428,15 +428,29 @@ impl LinearityChecker { } self.check_stmts(depth + 1, &for_stmt.contents) }, - /* - Statement::Block(statements) => { - // Handle blocks of statements - for statement in statements { - self.check_stmt(depth + 1, statement)?; + Statement::Assign(assign_stmt) => { + // Handle assignments + let AssignStmt { target, derefs, value, span } = assign_stmt; + self.check_expr(depth, value) + }, + Statement::Return(return_stmt) => { + if let Some(value) = &return_stmt.value { + self.check_expr(depth, value) + } else { + Ok(()) + } + }, + Statement::FnCall(fn_call_op) => { + // Process function call arguments + for arg in &fn_call_op.args { + self.check_expr(depth, arg)?; } Ok(()) - },*/ - _ => Err(LinearityError::UnhandledStatementType { r#type: format!("{:?}", stmt) }), + }, + Statement::Match(_) => { + println!("Skipping linearity check for statement type: \n{:?}", stmt); + todo!() + } } } @@ -494,6 +508,7 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi linearity_table.remove_entry("x"); let state = linearity_table.get_state("y"); */ + println!("Starting linearity check"); let mut checker = LinearityChecker::new(); for (path, name, program) in programs { println!("Checking linearity for program: {}", name); @@ -507,6 +522,7 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi //checker.check_function(&function)?; checker.check_stmt(0, &statement)?; } + println!("Finished checking linearity for function: {} {:?}", function.decl.name.name, checker.state_tbl); //checker.linearity_check(&function)?; } _ => From 65e768e5c5b1a34792368a94fc02db0ab5702548 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 12:08:16 -0300 Subject: [PATCH 16/31] StateTbl::update_state implemented --- crates/concrete_check/src/linearity_check.rs | 109 ++++++++++++++----- 1 file changed, 84 insertions(+), 25 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 065867a..27fe876 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -25,6 +25,13 @@ enum VarState { BorrowedMut, } +#[derive(Debug, Clone, PartialEq)] +pub struct VarInfo { + ty: String, //TODO Define 'Type' as needed + depth: usize, + state: VarState, +} + #[derive(Debug, Clone, PartialEq)] enum CountResult { Zero, @@ -40,6 +47,8 @@ struct Appearances { path: u32, } + +// TODO remove. This are structures translated from Austral #[allow(dead_code)] enum Expr { NilConstant, @@ -128,7 +137,7 @@ impl Appearances { #[derive(Debug, Clone)] struct StateTbl { - vars: HashMap, + vars: HashMap, } impl StateTbl { @@ -140,7 +149,7 @@ impl StateTbl { } // Example of updating the state table - fn update_state(&mut self, var: & str, state: VarState) { + fn update_info(&mut self, var: & str, state: VarInfo) { self.vars.insert(var.to_string(), state); } @@ -151,9 +160,38 @@ impl StateTbl { } */ + fn get_info(&mut self, var: &str) -> Option<&mut VarInfo> { + if !self.vars.contains_key(var){ + self.vars.insert(var.to_string(), VarInfo{ty: "".to_string(), depth: 0, state: VarState::Unconsumed}); + } + self.vars.get_mut(var) + } + // Retrieve a variable's state - fn get_state(&self, var: &str) -> Option<&VarState> { - self.vars.get(var) + fn get_state(&mut self, var: &str) -> Option<&VarState> { + if let Some(info) = self.get_info(var) { + Some(&info.state) + } else { + None + } + } + + // Retrieve a variable's state + fn update_state(&mut self, var: &str, new_state: &VarState){ + let info = self.get_info(var); + if let Some(info) = info { + info.state = new_state.clone(); + } + } + + + fn get_loop_depth(&mut self, name: &str) -> usize { + let state = self.get_info(name); + if let Some(state) = state { + state.depth + } else { + 0 + } } } @@ -169,13 +207,6 @@ impl StateTbl { -#[allow(dead_code)] -#[allow(unused_variables)] -fn consume_once(state_tbl: &mut StateTbl, depth: u32, name: &str) -> Result { - // TODO Implement the logic to consume a variable once, updating the state table and handling depth - state_tbl.update_state(name, VarState::Consumed); - Ok(state_tbl.clone()) -} struct LinearityChecker { @@ -214,12 +245,51 @@ impl LinearityChecker { Ok(()) }*/ + + + fn consume_once(&mut self, depth: usize, name: &String) -> Result<(), LinearityError> { + if depth == self.state_tbl.get_loop_depth(name) { + self.state_tbl.update_state(name, &VarState::Consumed); + /* + let mut state = self.state_tbl.get_state(name); + if let Some(state) = state { + state = &VarState::Consumed; + } + else{ + //self.state_tbl.update_state(name, VarInfo{"".to_string(), depth, VarState::Unconsumed}); + }*/ + + Ok(()) + } + else{ + Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string()}) + } + } + + /* + fn consume_once(&mut self, name: &str, depth: usize) -> Result<(), LinearityError> { + if let Some(var_state) = self.state_tbl.get_state(name) { + if var_state.depth == depth { + var_state.state = VarState::Consumed; + Ok(()) + } else { + Err(LinearityError::InvalidLoopDepth) + } + } else { + Err(LinearityError::VariableNotFound) + } + } + */ + fn check_expr(&mut self, depth: u32, expr: &Expression) -> Result<(), LinearityError> { // Assuming you have a method to get all variable names and types //let vars = &mut self.state_tbl.vars; let vars = self.state_tbl.vars.clone(); - for (name, ty) in vars.iter() { - self.check_var_in_expr(depth, &name, &ty, expr)?; + for (name, info) in vars.iter() { + //self.check_var_in_expr(depth, &name, &info.ty, expr)?; + self.check_var_in_expr(depth, &name, &info.state, expr)?; + //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expression) -> Result<(), LinearityError> { + } Ok(()) } @@ -238,7 +308,6 @@ impl LinearityChecker { Statement::If(if_stmt) => { // Process all components of an if expression let cond_apps = self.count_in_expression(name, &if_stmt.value); - //let then_apps = self.count_in_statement(name, &if_stmt.contents); let then_apps = self.count_in_statements(name, &if_stmt.contents); let else_apps; let else_statements = &if_stmt.r#else; @@ -461,7 +530,7 @@ impl LinearityChecker { }*/ //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expr) -> Result<(), LinearityError> { - fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expression) -> Result<(), LinearityError> { + fn check_var_in_expr(&mut self, depth: u32, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; @@ -498,16 +567,6 @@ impl LinearityChecker { #[allow(unused_variables)] //pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> Result { pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, session: &Session) -> Result { - /* - let mut linearity_table = StateTbl::new(); - linearity_table.update_state("x", VarState::Unconsumed); - linearity_table.update_state("y", VarState::Consumed); - linearity_table.update_state("z", VarState::Borrowed); - linearity_table.update_state("w", VarState::BorrowedMut); - - linearity_table.remove_entry("x"); - let state = linearity_table.get_state("y"); - */ println!("Starting linearity check"); let mut checker = LinearityChecker::new(); for (path, name, program) in programs { From b70b7859f79ee357fcc1a0046ccf38648bb92bc6 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 13:29:58 -0300 Subject: [PATCH 17/31] check_var_in_expr: Covered all 13 cases --- crates/concrete_check/src/linearity_check.rs | 159 ++++++++++++------ .../src/linearity_check/errors.rs | 4 + 2 files changed, 113 insertions(+), 50 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 27fe876..ee4318b 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -13,7 +13,7 @@ use std::path::PathBuf; use concrete_ast::Program; use concrete_ast::modules::ModuleDefItem; //use concrete_ast::functions::FunctionDef; -use concrete_ast::expressions::{Expression, StructInitField}; +use concrete_ast::expressions::{Expression, StructInitField, PathOp}; //use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Binding}; @@ -149,8 +149,8 @@ impl StateTbl { } // Example of updating the state table - fn update_info(&mut self, var: & str, state: VarInfo) { - self.vars.insert(var.to_string(), state); + fn update_info(&mut self, var: & str, info: VarInfo) { + self.vars.insert(var.to_string(), info); } /* @@ -247,9 +247,12 @@ impl LinearityChecker { - fn consume_once(&mut self, depth: usize, name: &String) -> Result<(), LinearityError> { + fn consume_once(&mut self, depth: usize, name: &str) -> Result<(), LinearityError> { + let loop_depth = self.state_tbl.get_loop_depth(name); + println!("Consuming variable: {} depth {} loop_depth {}", name, depth, loop_depth); if depth == self.state_tbl.get_loop_depth(name) { self.state_tbl.update_state(name, &VarState::Consumed); + println!("Consumed variable: {}", name); /* let mut state = self.state_tbl.get_state(name); if let Some(state) = state { @@ -266,30 +269,16 @@ impl LinearityChecker { } } - /* - fn consume_once(&mut self, name: &str, depth: usize) -> Result<(), LinearityError> { - if let Some(var_state) = self.state_tbl.get_state(name) { - if var_state.depth == depth { - var_state.state = VarState::Consumed; - Ok(()) - } else { - Err(LinearityError::InvalidLoopDepth) - } - } else { - Err(LinearityError::VariableNotFound) - } - } - */ - fn check_expr(&mut self, depth: u32, expr: &Expression) -> Result<(), LinearityError> { + + fn check_expr(&mut self, depth: usize, expr: &Expression) -> Result<(), LinearityError> { // Assuming you have a method to get all variable names and types //let vars = &mut self.state_tbl.vars; + //TODO check if we can avoid cloning let vars = self.state_tbl.vars.clone(); for (name, info) in vars.iter() { //self.check_var_in_expr(depth, &name, &info.ty, expr)?; self.check_var_in_expr(depth, &name, &info.state, expr)?; - //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expression) -> Result<(), LinearityError> { - } Ok(()) } @@ -427,7 +416,7 @@ impl LinearityChecker { self.count_in_expression(name, &struct_init.value) } - fn check_stmt_let(&mut self, depth: u32, binding: &LetStmt) -> Result<(), LinearityError> { + fn check_stmt_let(&mut self, depth: usize, binding: &LetStmt) -> Result<(), LinearityError> { // Handle let bindings, possibly involving pattern matching let LetStmt { is_mutable, target, value, span } = binding; match target { @@ -443,19 +432,19 @@ impl LinearityChecker { } } - fn check_bindings(&mut self, depth: u32, binding: &Binding) -> Result<(), LinearityError> { + fn check_bindings(&mut self, depth: usize, binding: &Binding) -> Result<(), LinearityError> { // Do something with the bindings Ok(()) } - fn check_stmts(&mut self, depth: u32, stmts: &Vec) -> Result<(), LinearityError> { + fn check_stmts(&mut self, depth: usize, stmts: &Vec) -> Result<(), LinearityError> { for stmt in stmts { self.check_stmt(depth, stmt)?; } Ok(()) } - fn check_stmt(&mut self, depth: u32, stmt: &Statement) -> Result<(), LinearityError> { + fn check_stmt(&mut self, depth: usize, stmt: &Statement) -> Result<(), LinearityError> { match stmt { /* Statement::Expression(expr) => { @@ -500,6 +489,9 @@ impl LinearityChecker { Statement::Assign(assign_stmt) => { // Handle assignments let AssignStmt { target, derefs, value, span } = assign_stmt; + println!("Checking assignment: {:?}", assign_stmt); + // TODO check target + //self.check_expr(depth, &self.path_op_to_expression(target))?; self.check_expr(depth, value) }, Statement::Return(return_stmt) => { @@ -523,41 +515,107 @@ impl LinearityChecker { } } - /* - fn check_expr(&self, depth: usize, expr: &Expression) -> Result<(), LinearityError> { - // Expression checking logic here - Ok(()) + /* + fn path_op_to_expression(&self, path_op: &PathOp) -> Expression { + // Convert the first identifier part of the path into an Expression + let mut expr = Expression::Variable(path_op.first.clone()); + + // Process additional path segments + for segment in &path_op.extra { + match segment { + PathSegment::Field(field) => { + expr = Expression::Field(Box::new(expr), field.clone()); + }, + PathSegment::Index(index) => { + // Assuming index is an Expression + expr = Expression::Index(Box::new(expr), Box::new(Expression::Variable(index.clone()))); + }, + // Add other cases as necessary + } + } + + expr + } + + fn build_complex_expression_from_path(&self, components: &[PathComponent]) -> Expression { + // Construct a complex expression from path components + // This is just a placeholder, real implementation will depend on your specific case + Expression::Variable(components.iter().map(|c| c.to_string()).collect::>().join(".")) }*/ - //fn check_var_in_expr(&mut self, depth: u32, name: &str, ty: &VarState, expr: &Expr) -> Result<(), LinearityError> { - fn check_var_in_expr(&mut self, depth: u32, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { + fn check_var_in_expr(&mut self, depth: usize, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state - + println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { - /*( State Consumed WBorrow RBorrow Path ) + /*( State Consumed WBorrow RBorrow Path ) (* ------------------|-------------------|-----------------|------------------|----------------)*/ - //(VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), - //(VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + // Not yet consumed, and at most used through immutable borrows or path reads. (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), - (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => - Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), - (VarState::Unconsumed, CountResult::One, _, _, _) => - Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), - (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => - Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), - (VarState::Borrowed, _, _, _, _) => - Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), - (VarState::BorrowedMut, _, _, _, _) => - Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), - (VarState::Consumed, _, _, _, _) => - Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), - _ => Err(LinearityError::UnhandledStateOrCount { variable: name.to_string() }), + // Not yet consumed, borrowed mutably once, and nothing else. + (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), + // Not yet consumed, borrowed mutably, then either borrowed immutably or accessed through a path. + (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err(LinearityError::BorrowedMutUsed { variable: name.to_string() }), + // Not yet consumed, borrowed mutably more than once. + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + // Not yet consumed, consumed once, and nothing else. Valid IF the loop depth matches. + (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => self.consume_once(depth, name), + // Not yet consumed, consumed once, then either borrowed or accessed through a path. + (VarState::Unconsumed, CountResult::One, _, _, _) => Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + // Not yet consumed, consumed more than once. + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + // Read borrowed, and at most accessed through a path. + (VarState::Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), + // Read borrowed, and either consumed or borrowed again. + (VarState::Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + // Write borrowed, unused. + (VarState::BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + // Write borrowed, used in some way. + (VarState::BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + // Already consumed, and unused. + (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + // Already consumed, and used in some way. + (VarState::Consumed, _, _, _, _) => Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), } } + /* + fn check_var_in_expr(&mut self, depth: u32, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { + let apps = self.count_in_expression(name, expr); // Assume count function implementation + let Appearances { consumed, write, read, path } = apps; + + //let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state + let state = self.state_tbl.get_state(name);// Assume default state + if let Some(state) = state{ + println!("Checking variable: {} with state: {:?} and appearances: {:?}", name, state, apps); + match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { + /*( State Consumed WBorrow RBorrow Path ) + (* ------------------|-------------------|-----------------|------------------|----------------)*/ + //(VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(state_tbl.clone()), + //(VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(state_tbl.clone()), + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => + Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::One, _, _, _) => + Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => + Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + (VarState::Borrowed, _, _, _, _) => + Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + (VarState::BorrowedMut, _, _, _, _) => + Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + (VarState::Consumed, _, _, _, _) => + Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + _ => Err(LinearityError::UnhandledStateOrCount { variable: name.to_string() }), + } + } + else { + Err(LinearityError::VariableNotFound { variable: name.to_string() }) + } + + }*/ } @@ -576,9 +634,10 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi for module_content in &module.contents { match module_content { ModuleDefItem::Function(function) => { + //println!("Checking linearity for function: {:?}", function); //checker.check_function(&function)?; for statement in &function.body { - //checker.check_function(&function)?; + //println!("Checking linearity for function body: {:?}", function.body); checker.check_stmt(0, &statement)?; } println!("Finished checking linearity for function: {} {:?}", function.decl.name.name, checker.state_tbl); diff --git a/crates/concrete_check/src/linearity_check/errors.rs b/crates/concrete_check/src/linearity_check/errors.rs index 4d261c1..7b6e5bf 100644 --- a/crates/concrete_check/src/linearity_check/errors.rs +++ b/crates/concrete_check/src/linearity_check/errors.rs @@ -45,6 +45,10 @@ pub enum LinearityError { variable: String, message: String, }, + #[error("Variable {variable} not found")] + VariableNotFound{ + variable: String, + }, #[error("Unhandled statement type {r#type}")] UnhandledStatementType{ r#type: String, From 946ce9d62ecf77f546599e2dbe0fd087032454e0 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 15:53:00 -0300 Subject: [PATCH 18/31] More info with TODO in Match case --- crates/concrete_check/src/linearity_check.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index ee4318b..81c5d8f 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -368,7 +368,7 @@ impl LinearityChecker { // Process function call arguments fn_call_op.args.iter().map(|arg| self.count_in_expression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) }, - Expression::Match(match_expr) => todo!(), + Expression::Match(match_expr) => todo!("do not support match expression"), /* Expression::Match(match_expr) => { // Handle match arms @@ -394,7 +394,6 @@ impl LinearityChecker { // Handle binary operations by processing both sides self.count_in_expression(name, left).merge(&&self.count_in_expression(name, right)) }, - //Expression::StructInit(_) => todo!(), Expression::StructInit(struct_init_expr) => { // Handle struct initialization struct_init_expr.fields.iter().map(|(_, expr)| self.count_struct_init(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) From 8256951a38338b6c2add37d8eab4440d9ba60527 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 15:53:17 -0300 Subject: [PATCH 19/31] _Borrowed types because Borrow not covered yet in linearity check --- crates/concrete_check/src/linearity_check.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 81c5d8f..fda69eb 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -21,8 +21,8 @@ use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Bi enum VarState { Unconsumed, Consumed, - Borrowed, - BorrowedMut, + _Borrowed, + _BorrowedMut, } #[derive(Debug, Clone, PartialEq)] @@ -566,13 +566,13 @@ impl LinearityChecker { // Not yet consumed, consumed more than once. (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), // Read borrowed, and at most accessed through a path. - (VarState::Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), + (VarState::_Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), // Read borrowed, and either consumed or borrowed again. - (VarState::Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + (VarState::_Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), // Write borrowed, unused. - (VarState::BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + (VarState::_BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), // Write borrowed, used in some way. - (VarState::BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + (VarState::_BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), // Already consumed, and unused. (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), // Already consumed, and used in some way. From 3d6083ed15fe52077461f20f2bdd3ff9a5f7a38f Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 17:47:59 -0300 Subject: [PATCH 20/31] state_tbl is initialized with type --- crates/concrete_check/src/linearity_check.rs | 119 +++++++++++++++++-- 1 file changed, 110 insertions(+), 9 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index fda69eb..3cc8b40 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -11,12 +11,13 @@ use std::path::PathBuf; //use concrete_ir::{ProgramBody, FnBody}; //use concrete_ast::Program{ ProgramBody, FnBody }; use concrete_ast::Program; +//use concrete_ast::modules::{Module, ModuleDefItem}; use concrete_ast::modules::ModuleDefItem; //use concrete_ast::functions::FunctionDef; use concrete_ast::expressions::{Expression, StructInitField, PathOp}; //use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Binding}; - +use concrete_ast::types::TypeSpec; #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { Unconsumed, @@ -120,6 +121,14 @@ impl Appearances { Self::new(0, 0, 0, 1) } + /* TODO implementation of merge without copying + fn merge(&mut self, other: &Appearances) { + self.consumed += other.consumed; + self.write += other.write; + self.read += other.read; + self.path += other.path; + }*/ + fn merge(&self, other: &Appearances) -> Self { Appearances { consumed: self.consumed + other.consumed, @@ -128,6 +137,7 @@ impl Appearances { path: self.path + other.path, } } + fn merge_list(appearances: Vec) -> Self { appearances.into_iter().fold(Self::zero(), |acc, x| acc.merge(&x)) @@ -163,6 +173,7 @@ impl StateTbl { fn get_info(&mut self, var: &str) -> Option<&mut VarInfo> { if !self.vars.contains_key(var){ self.vars.insert(var.to_string(), VarInfo{ty: "".to_string(), depth: 0, state: VarState::Unconsumed}); + println!("Variable {} not found in state table. Inserting with default state", var); } self.vars.get_mut(var) } @@ -330,13 +341,33 @@ impl LinearityChecker { } apps }, - /* + /* Alucination of GPT Statement::Block(statements) => { // Handle blocks of statements //statements.iter().map(|stmt| self.count(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) self.count_in_statements(name, statements) },*/ - _ => Appearances::zero(), + Statement::Assign(assign_stmt) => { + // Handle assignments + self.count_in_assign_statement(name, assign_stmt) + }, + Statement::Return(return_stmt) => { + // Handle return statements + if let Some(value) = &return_stmt.value { + self.count_in_expression(name, value) + } else { + Appearances::zero() + } + }, + Statement::FnCall(fn_call_op) => { + // Process function call arguments + //fn_call_op.target.iter().map(|arg| self.count_in_path_op(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); + fn_call_op.args.iter().map(|arg| self.count_in_expression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + }, + Statement::Match(_) => { + todo!("do not support match statement") + }, + //_ => Appearances::zero(), } } @@ -344,11 +375,24 @@ impl LinearityChecker { match assign_stmt { AssignStmt { target, derefs, value, span } => { // Handle assignments - self.count_in_expression(name, value) + let ret = self.count_in_path_op(name, target); + ret.merge(&self.count_in_expression(name, value)); + ret }, } } + fn count_in_path_op(&self, name: &str, path_op: &PathOp) -> Appearances { + let apps: Appearances; + if name == path_op.first.name{ + apps = Appearances::path_once(); + } + else{ + apps = Appearances::zero(); + } + apps + } + fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances { match let_stmt { LetStmt { is_mutable, target, value, span } => { @@ -412,6 +456,7 @@ impl LinearityChecker { fn count_struct_init(&self, name: &str, struct_init: &StructInitField) -> Appearances { + println!("Checking struct init: {:?}", struct_init); self.count_in_expression(name, &struct_init.value) } @@ -420,6 +465,18 @@ impl LinearityChecker { let LetStmt { is_mutable, target, value, span } = binding; match target { LetStmtTarget::Simple { name, r#type } => { + match r#type { + TypeSpec::Simple{ name: variable_type , qualifiers, span} => { + self.state_tbl.update_info(&name.name, VarInfo{ty: variable_type.name.clone(), depth, state: VarState::Unconsumed}); + }, + TypeSpec::Generic { name: variable_type, qualifiers, type_params, span } =>{ + self.state_tbl.update_info(&name.name, VarInfo{ty: variable_type.name.clone(), depth, state: VarState::Unconsumed}); + }, + TypeSpec::Array { of_type, size, qualifiers, span } => { + let array_type = "Array<".to_string() + &of_type.get_name() + ">"; + self.state_tbl.update_info(&name.name, VarInfo{ty: array_type, depth, state: VarState::Unconsumed}); + }, + } self.check_var_in_expr(depth, &name.name, &VarState::Unconsumed, value) }, LetStmtTarget::Destructure(bindings) => { @@ -490,6 +547,7 @@ impl LinearityChecker { let AssignStmt { target, derefs, value, span } = assign_stmt; println!("Checking assignment: {:?}", assign_stmt); // TODO check target + self.check_path_opt(depth, target)?; //self.check_expr(depth, &self.path_op_to_expression(target))?; self.check_expr(depth, value) }, @@ -510,10 +568,16 @@ impl LinearityChecker { Statement::Match(_) => { println!("Skipping linearity check for statement type: \n{:?}", stmt); todo!() - } + } } } + fn check_path_opt(&mut self, depth: usize, path_op: &PathOp) -> Result<(), LinearityError> { + println!("Checking path: {:?}", path_op); + println!("TODO add to: {:?}", path_op); + //path_op.first.name; + Ok(()) + } /* fn path_op_to_expression(&self, path_op: &PathOp) -> Expression { // Convert the first identifier part of the path into an Expression @@ -641,12 +705,49 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi } println!("Finished checking linearity for function: {} {:?}", function.decl.name.name, checker.state_tbl); //checker.linearity_check(&function)?; - } - _ => + }, + ModuleDefItem::FunctionDecl(function_decl) => + { + println!("Skipping linearity check for FunctionDecl: {:?}", module_content); + () + }, + ModuleDefItem::Module(module) => { - println!("Skipping linear check for module content: {:?}", module_content); + println!("Skipping linearity check for Module: {:?}", module_content); () - }, + }, + ModuleDefItem::Struct(struc) => + { + //println!("Skipping linearity check for Struct: {:?}", module_content); + //checker. + checker.state_tbl.update_info(&struc.name.name, VarInfo{ty: "Struct".to_string(), depth: 0, state: VarState::Unconsumed}); + () + }, + ModuleDefItem::Enum(_) => + { + println!("Skipping linearity check for Enum: {:?}", module_content); + () + }, + ModuleDefItem::Constant(_) => + { + println!("Skipping linearity check for Constant: {:?}", module_content); + () + }, + ModuleDefItem::Union(_) => + { + println!("Skipping linearity check for Uinon: {:?}", module_content); + () + }, + ModuleDefItem::Type(_) => + { + println!("Skipping linearity check for module content: {:?}", module_content); + () + }, + /*_ => + { + println!("Skipping linearity check for module content: {:?}", module_content); + () + },*/ } } } From 5c97ebed4b2817fa02cac0a747aa91c27ab2c4d0 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 18:04:01 -0300 Subject: [PATCH 21/31] Only makes linearityChecks for variables of type Linear --- crates/concrete_check/src/linearity_check.rs | 87 ++++++++++++-------- 1 file changed, 51 insertions(+), 36 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 3cc8b40..7f6c9e7 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -179,7 +179,7 @@ impl StateTbl { } // Retrieve a variable's state - fn get_state(&mut self, var: &str) -> Option<&VarState> { + fn _get_state(&mut self, var: &str) -> Option<&VarState> { if let Some(info) = self.get_info(var) { Some(&info.state) } else { @@ -489,7 +489,8 @@ impl LinearityChecker { } fn check_bindings(&mut self, depth: usize, binding: &Binding) -> Result<(), LinearityError> { - // Do something with the bindings + // TODO Do something with the bindings + println!("TODO implement Checking bindings: {:?}", binding); Ok(()) } @@ -607,40 +608,54 @@ impl LinearityChecker { }*/ fn check_var_in_expr(&mut self, depth: usize, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { - let apps = self.count_in_expression(name, expr); // Assume count function implementation - let Appearances { consumed, write, read, path } = apps; - - let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state - println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); - match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { - /*( State Consumed WBorrow RBorrow Path ) - (* ------------------|-------------------|-----------------|------------------|----------------)*/ - // Not yet consumed, and at most used through immutable borrows or path reads. - (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), - // Not yet consumed, borrowed mutably once, and nothing else. - (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), - // Not yet consumed, borrowed mutably, then either borrowed immutably or accessed through a path. - (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err(LinearityError::BorrowedMutUsed { variable: name.to_string() }), - // Not yet consumed, borrowed mutably more than once. - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), - // Not yet consumed, consumed once, and nothing else. Valid IF the loop depth matches. - (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => self.consume_once(depth, name), - // Not yet consumed, consumed once, then either borrowed or accessed through a path. - (VarState::Unconsumed, CountResult::One, _, _, _) => Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), - // Not yet consumed, consumed more than once. - (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), - // Read borrowed, and at most accessed through a path. - (VarState::_Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), - // Read borrowed, and either consumed or borrowed again. - (VarState::_Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), - // Write borrowed, unused. - (VarState::_BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), - // Write borrowed, used in some way. - (VarState::_BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), - // Already consumed, and unused. - (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), - // Already consumed, and used in some way. - (VarState::Consumed, _, _, _, _) => Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + + let info = self.state_tbl.get_info(name); // Assume default state + if let Some(info) = info{ + //Only checks Linearity for types of name Linear + // TODO improve this approach + if info.ty == "Linear".to_string(){ + let apps = self.count_in_expression(name, expr); // Assume count function implementation + let Appearances { consumed, write, read, path } = apps; + + println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); + match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { + /*( State Consumed WBorrow RBorrow Path ) + (* ------------------|-------------------|-----------------|------------------|----------------)*/ + // Not yet consumed, and at most used through immutable borrows or path reads. + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), + // Not yet consumed, borrowed mutably once, and nothing else. + (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), + // Not yet consumed, borrowed mutably, then either borrowed immutably or accessed through a path. + (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err(LinearityError::BorrowedMutUsed { variable: name.to_string() }), + // Not yet consumed, borrowed mutably more than once. + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + // Not yet consumed, consumed once, and nothing else. Valid IF the loop depth matches. + (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => self.consume_once(depth, name), + // Not yet consumed, consumed once, then either borrowed or accessed through a path. + (VarState::Unconsumed, CountResult::One, _, _, _) => Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + // Not yet consumed, consumed more than once. + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + // Read borrowed, and at most accessed through a path. + (VarState::_Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), + // Read borrowed, and either consumed or borrowed again. + (VarState::_Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + // Write borrowed, unused. + (VarState::_BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + // Write borrowed, used in some way. + (VarState::_BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + // Already consumed, and unused. + (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + // Already consumed, and used in some way. + (VarState::Consumed, _, _, _, _) => Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + } + } + else{ + //Only checks Linearity for types of name Linear + Ok(()) + } + } + else { + Err(LinearityError::VariableNotFound { variable: name.to_string()}) } } /* From 869c2ae352dd9804b6a1b7e733e223e1b9974d1e Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Fri, 17 May 2024 18:49:48 -0300 Subject: [PATCH 22/31] New ValueVar struct for trivial expressions building. Does not compile --- crates/concrete_ast/src/expressions.rs | 1 + crates/concrete_check/src/linearity_check.rs | 16 ++++++++-------- crates/concrete_ir/src/lowering.rs | 7 +++++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/crates/concrete_ast/src/expressions.rs b/crates/concrete_ast/src/expressions.rs index 818c666..3cefc80 100644 --- a/crates/concrete_ast/src/expressions.rs +++ b/crates/concrete_ast/src/expressions.rs @@ -29,6 +29,7 @@ pub enum ValueExpr { ConstFloat(String, Span), ConstStr(String, Span), Path(PathOp), + ValueVar(Ident, Span), } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 7f6c9e7..977026d 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -287,9 +287,9 @@ impl LinearityChecker { //let vars = &mut self.state_tbl.vars; //TODO check if we can avoid cloning let vars = self.state_tbl.vars.clone(); - for (name, info) in vars.iter() { + for (name, _info) in vars.iter() { //self.check_var_in_expr(depth, &name, &info.ty, expr)?; - self.check_var_in_expr(depth, &name, &info.state, expr)?; + self.check_var_in_expr(depth, &name, expr)?; } Ok(()) } @@ -477,7 +477,7 @@ impl LinearityChecker { self.state_tbl.update_info(&name.name, VarInfo{ty: array_type, depth, state: VarState::Unconsumed}); }, } - self.check_var_in_expr(depth, &name.name, &VarState::Unconsumed, value) + self.check_var_in_expr(depth, &name.name, value) }, LetStmtTarget::Destructure(bindings) => { for binding in bindings { @@ -547,7 +547,6 @@ impl LinearityChecker { // Handle assignments let AssignStmt { target, derefs, value, span } = assign_stmt; println!("Checking assignment: {:?}", assign_stmt); - // TODO check target self.check_path_opt(depth, target)?; //self.check_expr(depth, &self.path_op_to_expression(target))?; self.check_expr(depth, value) @@ -575,7 +574,8 @@ impl LinearityChecker { fn check_path_opt(&mut self, depth: usize, path_op: &PathOp) -> Result<(), LinearityError> { println!("Checking path: {:?}", path_op); - println!("TODO add to: {:?}", path_op); + let var_expression = Expression::Value{ValueVar{path_op.first.clone(), path_op.span}}; + self.check_var_in_expr(depth, &path_op.first.name, &var_expression ); //path_op.first.name; Ok(()) } @@ -607,16 +607,16 @@ impl LinearityChecker { Expression::Variable(components.iter().map(|c| c.to_string()).collect::>().join(".")) }*/ - fn check_var_in_expr(&mut self, depth: usize, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { + fn check_var_in_expr(&mut self, depth: usize, name: &str, expr: &Expression) -> Result<(), LinearityError> { let info = self.state_tbl.get_info(name); // Assume default state if let Some(info) = info{ //Only checks Linearity for types of name Linear // TODO improve this approach if info.ty == "Linear".to_string(){ + let state = &info.state; let apps = self.count_in_expression(name, expr); // Assume count function implementation - let Appearances { consumed, write, read, path } = apps; - + let Appearances { consumed, write, read, path } = apps; println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { /*( State Consumed WBorrow RBorrow Path ) diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index 777c5dd..7ae3a50 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -844,7 +844,9 @@ fn find_expression_type(builder: &mut FnBodyBuilder, info: &Expression) -> Optio ValueExpr::Path(path) => { let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments Some(local.ty.clone()) - } + }, + //TODO check this behavior + ValueExpr::ValueVar(_, _span) => None, }, Expression::FnCall(info) => { let fn_id = { @@ -1549,7 +1551,8 @@ fn lower_value_expr( ValueExpr::Path(info) => { let (place, place_ty, _span) = lower_path(builder, info)?; (Rvalue::Use(Operand::Place(place.clone())), place_ty) - } + }, + ValueExpr::ValueVar(_, _) => todo!("ValueVar not yet implemented"), }) } From eeccb0f4788652e37c968559ab9e77a5acae236e Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 10:44:35 -0300 Subject: [PATCH 23/31] New ValueVar struct for trivial expressions building --- crates/concrete_check/src/linearity_check.rs | 22 ++++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 977026d..a82eabc 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -14,10 +14,11 @@ use concrete_ast::Program; //use concrete_ast::modules::{Module, ModuleDefItem}; use concrete_ast::modules::ModuleDefItem; //use concrete_ast::functions::FunctionDef; -use concrete_ast::expressions::{Expression, StructInitField, PathOp}; +use concrete_ast::expressions::{Expression, StructInitField, PathOp, ValueExpr}; //use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Binding}; use concrete_ast::types::TypeSpec; +//use concrete_ast::expressions::Value; // Import the missing module #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum VarState { Unconsumed, @@ -170,7 +171,7 @@ impl StateTbl { } */ - fn get_info(&mut self, var: &str) -> Option<&mut VarInfo> { + fn get_info_mut(&mut self, var: &str) -> Option<&mut VarInfo> { if !self.vars.contains_key(var){ self.vars.insert(var.to_string(), VarInfo{ty: "".to_string(), depth: 0, state: VarState::Unconsumed}); println!("Variable {} not found in state table. Inserting with default state", var); @@ -178,6 +179,10 @@ impl StateTbl { self.vars.get_mut(var) } + fn get_info(& self, var: &str) -> Option<& VarInfo> { + self.vars.get(var) + } + // Retrieve a variable's state fn _get_state(&mut self, var: &str) -> Option<&VarState> { if let Some(info) = self.get_info(var) { @@ -189,7 +194,7 @@ impl StateTbl { // Retrieve a variable's state fn update_state(&mut self, var: &str, new_state: &VarState){ - let info = self.get_info(var); + let info = self.get_info_mut(var); if let Some(info) = info { info.state = new_state.clone(); } @@ -571,13 +576,12 @@ impl LinearityChecker { } } } - + fn check_path_opt(&mut self, depth: usize, path_op: &PathOp) -> Result<(), LinearityError> { println!("Checking path: {:?}", path_op); - let var_expression = Expression::Value{ValueVar{path_op.first.clone(), path_op.span}}; - self.check_var_in_expr(depth, &path_op.first.name, &var_expression ); - //path_op.first.name; - Ok(()) + //let var_expression = Value::new(path_op.first.clone(), path_op.span); // Use the imported module + let var_expression = Expression::Value(ValueExpr::ValueVar(path_op.first.clone(), path_op.span), path_op.span); + self.check_var_in_expr(depth, &path_op.first.name, &var_expression) } /* fn path_op_to_expression(&self, path_op: &PathOp) -> Expression { @@ -617,7 +621,7 @@ impl LinearityChecker { let state = &info.state; let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; - println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); + //println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { /*( State Consumed WBorrow RBorrow Path ) (* ------------------|-------------------|-----------------|------------------|----------------)*/ From 542cdd0edd65e028a32e83b0299d973f8bd4d936 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 14:38:39 -0300 Subject: [PATCH 24/31] linearExample01.con xy consumed --- crates/concrete_check/src/linearity_check.rs | 23 ++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index a82eabc..7a847d7 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -88,7 +88,6 @@ enum BorrowMode { WriteBorrow, } -#[allow(dead_code)] impl Appearances { fn new(consumed: u32, write: u32, read: u32, path: u32) -> Self { Appearances { consumed, write, read, path } @@ -110,11 +109,12 @@ impl Appearances { Self::new(1, 0, 0, 0) } - fn read_once() -> Self { + // When borrowed implemented + fn _read_once() -> Self { Self::new(0, 0, 1, 0) } - fn write_once() -> Self { + fn _write_once() -> Self { Self::new(0, 1, 0, 0) } @@ -411,7 +411,21 @@ impl LinearityChecker { match expr { Expression::Value(value_expr, _) => { // Handle value expressions, typically constant or simple values - Appearances::zero() + //Appearances::zero() + match value_expr { + ValueExpr::ValueVar(ident, _) => { + if name == ident.name { + Appearances::consumed_once() + } else { + Appearances::zero() + } + }, + ValueExpr::Path(path) => { + //path.first.name == name; + Appearances::zero() + }, + _ => Appearances::zero(), + } }, Expression::FnCall(fn_call_op) => { // Process function call arguments @@ -622,6 +636,7 @@ impl LinearityChecker { let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; //println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); + println!("Checking state_tbl variable: {}: {:?} {:?} in expression {:?}", name, info, apps, expr); match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { /*( State Consumed WBorrow RBorrow Path ) (* ------------------|-------------------|-----------------|------------------|----------------)*/ From 036b4eb0902c13af400f655bd001d591309eac67 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 15:41:43 -0300 Subject: [PATCH 25/31] Inclued an option for checking types (linearity by now) --- crates/concrete_check/src/linearity_check.rs | 2 +- crates/concrete_driver/src/lib.rs | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 7a847d7..71d778f 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -140,7 +140,7 @@ impl Appearances { } - fn merge_list(appearances: Vec) -> Self { + fn _merge_list(appearances: Vec) -> Self { appearances.into_iter().fold(Self::zero(), |acc, x| acc.merge(&x)) } } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 22117c6..7cc8faf 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -100,6 +100,11 @@ pub struct BuildArgs { /// Also output the object file. #[arg(long, default_value_t = false)] object: bool, + + /// This option is for checking the program for linearity. + #[arg(long, default_value_t = false)] + check: bool, + } #[derive(Parser, Debug)] @@ -150,6 +155,11 @@ pub struct CompilerArgs { /// Also output the object file. #[arg(long, default_value_t = false)] object: bool, + + /// This option is for checking the program for linearity. + #[arg(long, default_value_t = false)] + check: bool, + } pub fn main() -> Result<()> { @@ -298,6 +308,7 @@ fn handle_build( asm, object, lib, + check, }: BuildArgs, ) -> Result { match path { @@ -325,6 +336,7 @@ fn handle_build( asm, object, mlir, + check, }; println!( @@ -453,6 +465,7 @@ fn handle_build( asm, object, mlir, + check, }; let object = compile(&compile_args)?; From e6c6568867dc82ffbad9ac33c58b2a9a25328fd6 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 15:46:04 -0300 Subject: [PATCH 26/31] flag for checking types included in code --- crates/concrete_driver/src/lib.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 7cc8faf..229338e 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -606,15 +606,16 @@ pub fn compile(args: &CompilerArgs) -> Result { }; #[allow(unused_variables)] - //When tried to use ir representation for linearity check - //let linearity_result = match concrete_check::linearity_check::linearity_check_program(&program_ir, &session) { - let linearity_result = match concrete_check::linearity_check::linearity_check_program(&programs, &session) { - Ok(ir) => ir, - Err(error) => { - println!("Linearity check failed: {:#?}", error); - std::process::exit(1); - } - }; + if args.check { + let linearity_result = match concrete_check::linearity_check::linearity_check_program(&programs, &session) { + Ok(ir) => ir, + Err(error) => { + //TODO improve reporting + println!("Linearity check failed: {:#?}", error); + std::process::exit(1); + } + }; + } if args.ir { From e3f66750b970d039cb0004225ba1665998a35336 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 16:34:09 -0300 Subject: [PATCH 27/31] Lintr corrections --- crates/concrete_check/src/linearity_check.rs | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 71d778f..24bd049 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -78,7 +78,7 @@ enum Expr { Embed(Vec), Deref(Box), SizeOf, - BorrowExpr(BorrowMode, String), + Borrow(BorrowMode, String), ArrayIndex(Box), } @@ -294,7 +294,7 @@ impl LinearityChecker { let vars = self.state_tbl.vars.clone(); for (name, _info) in vars.iter() { //self.check_var_in_expr(depth, &name, &info.ty, expr)?; - self.check_var_in_expr(depth, &name, expr)?; + self.check_var_in_expr(depth, name, expr)?; } Ok(()) } @@ -551,10 +551,10 @@ impl LinearityChecker { Statement::For(for_stmt) => { // Handle for loops if let Some(init) = &for_stmt.init { - self.check_stmt_let(depth, &init)?; + self.check_stmt_let(depth, init)?; } if let Some(condition) = &for_stmt.condition { - self.check_expr(depth, &condition)?; + self.check_expr(depth, condition)?; } if let Some(post) = &for_stmt.post { //TODO check assign statement @@ -631,7 +631,7 @@ impl LinearityChecker { if let Some(info) = info{ //Only checks Linearity for types of name Linear // TODO improve this approach - if info.ty == "Linear".to_string(){ + if info.ty == *"Linear".to_string(){ let state = &info.state; let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; @@ -735,7 +735,7 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi //checker.check_function(&function)?; for statement in &function.body { //println!("Checking linearity for function body: {:?}", function.body); - checker.check_stmt(0, &statement)?; + checker.check_stmt(0, statement)?; } println!("Finished checking linearity for function: {} {:?}", function.decl.name.name, checker.state_tbl); //checker.linearity_check(&function)?; @@ -743,39 +743,32 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi ModuleDefItem::FunctionDecl(function_decl) => { println!("Skipping linearity check for FunctionDecl: {:?}", module_content); - () }, ModuleDefItem::Module(module) => { println!("Skipping linearity check for Module: {:?}", module_content); - () }, ModuleDefItem::Struct(struc) => { //println!("Skipping linearity check for Struct: {:?}", module_content); //checker. checker.state_tbl.update_info(&struc.name.name, VarInfo{ty: "Struct".to_string(), depth: 0, state: VarState::Unconsumed}); - () }, ModuleDefItem::Enum(_) => { println!("Skipping linearity check for Enum: {:?}", module_content); - () }, ModuleDefItem::Constant(_) => { println!("Skipping linearity check for Constant: {:?}", module_content); - () }, ModuleDefItem::Union(_) => { println!("Skipping linearity check for Uinon: {:?}", module_content); - () }, ModuleDefItem::Type(_) => { println!("Skipping linearity check for module content: {:?}", module_content); - () }, /*_ => { From c8b395b87e426bed28693e6a0f496089ddff562a Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 16:51:29 -0300 Subject: [PATCH 28/31] Clippy produces no warnings --- crates/concrete_check/src/linearity_check.rs | 39 ++++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 24bd049..519016a 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -300,7 +300,7 @@ impl LinearityChecker { } - fn count_in_statements(&self, name: &str, statements: &Vec) -> Appearances { + fn count_in_statements(&self, name: &str, statements: &[Statement]) -> Appearances { statements.iter().map(|stmt| self.count_in_statement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) } @@ -317,7 +317,7 @@ impl LinearityChecker { let else_apps; let else_statements = &if_stmt.r#else; if let Some(else_statements) = else_statements { - else_apps = self.count_in_statements(name, &else_statements); + else_apps = self.count_in_statements(name, else_statements); } else { else_apps = Appearances::zero(); } @@ -377,34 +377,25 @@ impl LinearityChecker { } fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { - match assign_stmt { - AssignStmt { target, derefs, value, span } => { - // Handle assignments - let ret = self.count_in_path_op(name, target); - ret.merge(&self.count_in_expression(name, value)); - ret - }, - } + let AssignStmt { target, derefs, value, span } = assign_stmt; + // Handle assignments + let ret = self.count_in_path_op(name, target); + ret.merge(&self.count_in_expression(name, value)); + ret } fn count_in_path_op(&self, name: &str, path_op: &PathOp) -> Appearances { - let apps: Appearances; if name == path_op.first.name{ - apps = Appearances::path_once(); + Appearances::path_once() } else{ - apps = Appearances::zero(); - } - apps + Appearances::zero() + } } fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances { - match let_stmt { - LetStmt { is_mutable, target, value, span } => { - // Handle let bindings, possibly involving pattern matching - self.count_in_expression(name, value) - }, - } + let LetStmt { is_mutable, target, value, span } = let_stmt; + self.count_in_expression(name, value) } fn count_in_expression(&self, name: &str, expr: &Expression) -> Appearances { @@ -455,11 +446,11 @@ impl LinearityChecker { }, Expression::BinaryOp(left, _, right) => { // Handle binary operations by processing both sides - self.count_in_expression(name, left).merge(&&self.count_in_expression(name, right)) + self.count_in_expression(name, left).merge(&self.count_in_expression(name, right)) }, Expression::StructInit(struct_init_expr) => { // Handle struct initialization - struct_init_expr.fields.iter().map(|(_, expr)| self.count_struct_init(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + struct_init_expr.fields.values().map(|expr| self.count_struct_init(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) }, Expression::ArrayInit(array_init_expr) => { // Handle array initializations @@ -537,7 +528,7 @@ impl LinearityChecker { self.check_expr(depth, &if_stmt.value)?; self.check_stmts(depth + 1, &if_stmt.contents)?; if let Some(else_block) = &if_stmt.r#else { - self.check_stmts(depth + 1, &else_block)?; + self.check_stmts(depth + 1, else_block)?; } Ok(()) }, From 5779cd613899914065543835d41a41d42812512e Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 16:52:09 -0300 Subject: [PATCH 29/31] Ran Cargo fmt --- crates/concrete_check/src/lib.rs | 1 - crates/concrete_check/src/linearity_check.rs | 544 ++++++++++++------ .../src/linearity_check/errors.rs | 49 +- crates/concrete_driver/src/lib.rs | 20 +- crates/concrete_ir/src/lowering.rs | 4 +- 5 files changed, 377 insertions(+), 241 deletions(-) diff --git a/crates/concrete_check/src/lib.rs b/crates/concrete_check/src/lib.rs index 11062e9..7554c15 100644 --- a/crates/concrete_check/src/lib.rs +++ b/crates/concrete_check/src/lib.rs @@ -6,7 +6,6 @@ use concrete_session::Session; pub mod linearity_check; - /// Creates a report from a lowering error. pub fn lowering_error_to_report( error: LoweringError, diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 519016a..16b77e1 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -2,11 +2,9 @@ use std::collections::HashMap; use concrete_session::Session; - use self::errors::LinearityError; pub mod errors; - use std::path::PathBuf; //use concrete_ir::{ProgramBody, FnBody}; //use concrete_ast::Program{ ProgramBody, FnBody }; @@ -14,9 +12,9 @@ use concrete_ast::Program; //use concrete_ast::modules::{Module, ModuleDefItem}; use concrete_ast::modules::ModuleDefItem; //use concrete_ast::functions::FunctionDef; -use concrete_ast::expressions::{Expression, StructInitField, PathOp, ValueExpr}; +use concrete_ast::expressions::{Expression, PathOp, StructInitField, ValueExpr}; //use concrete_ast::statements::{Statement, AssignStmt, LetStmt, WhileStmt, ForStmt, LetStmtTarget, Binding}; -use concrete_ast::statements::{Statement, AssignStmt, LetStmt, LetStmtTarget, Binding}; +use concrete_ast::statements::{AssignStmt, Binding, LetStmt, LetStmtTarget, Statement}; use concrete_ast::types::TypeSpec; //use concrete_ast::expressions::Value; // Import the missing module #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -49,7 +47,6 @@ struct Appearances { path: u32, } - // TODO remove. This are structures translated from Austral #[allow(dead_code)] enum Expr { @@ -90,7 +87,12 @@ enum BorrowMode { impl Appearances { fn new(consumed: u32, write: u32, read: u32, path: u32) -> Self { - Appearances { consumed, write, read, path } + Appearances { + consumed, + write, + read, + path, + } } fn partition(count: u32) -> CountResult { @@ -138,14 +140,14 @@ impl Appearances { path: self.path + other.path, } } - fn _merge_list(appearances: Vec) -> Self { - appearances.into_iter().fold(Self::zero(), |acc, x| acc.merge(&x)) + appearances + .into_iter() + .fold(Self::zero(), |acc, x| acc.merge(&x)) } } - #[derive(Debug, Clone)] struct StateTbl { vars: HashMap, @@ -160,11 +162,11 @@ impl StateTbl { } // Example of updating the state table - fn update_info(&mut self, var: & str, info: VarInfo) { + fn update_info(&mut self, var: &str, info: VarInfo) { self.vars.insert(var.to_string(), info); } - /* + /* // Remove a variable from the state table fn remove_entry(&mut self, var: &str) { self.vars.remove(var); @@ -172,15 +174,25 @@ impl StateTbl { */ fn get_info_mut(&mut self, var: &str) -> Option<&mut VarInfo> { - if !self.vars.contains_key(var){ - self.vars.insert(var.to_string(), VarInfo{ty: "".to_string(), depth: 0, state: VarState::Unconsumed}); - println!("Variable {} not found in state table. Inserting with default state", var); + if !self.vars.contains_key(var) { + self.vars.insert( + var.to_string(), + VarInfo { + ty: "".to_string(), + depth: 0, + state: VarState::Unconsumed, + }, + ); + println!( + "Variable {} not found in state table. Inserting with default state", + var + ); } - self.vars.get_mut(var) + self.vars.get_mut(var) } - fn get_info(& self, var: &str) -> Option<& VarInfo> { - self.vars.get(var) + fn get_info(&self, var: &str) -> Option<&VarInfo> { + self.vars.get(var) } // Retrieve a variable's state @@ -193,14 +205,13 @@ impl StateTbl { } // Retrieve a variable's state - fn update_state(&mut self, var: &str, new_state: &VarState){ + fn update_state(&mut self, var: &str, new_state: &VarState) { let info = self.get_info_mut(var); if let Some(info) = info { info.state = new_state.clone(); - } + } } - fn get_loop_depth(&mut self, name: &str) -> usize { let state = self.get_info(name); if let Some(state) = state { @@ -211,20 +222,6 @@ impl StateTbl { } } - - - - - - - - - - - - - - struct LinearityChecker { state_tbl: StateTbl, } @@ -239,7 +236,7 @@ impl LinearityChecker { } //TODO remove - /* + /* fn linearity_check(&mut self, program: &FunctionDef) -> Result<(), LinearityError> { // Assume Program is a struct that represents the entire program. for statement in &program.body { @@ -248,7 +245,7 @@ impl LinearityChecker { Ok(()) } */ - /* + /* fn check_function(&mut self, function: &FnBody) -> Result<(), LinearityError> { // Logic to check linearity within a function // This may involve iterating over statements and expressions, similar to OCaml's recursion. @@ -256,20 +253,21 @@ impl LinearityChecker { for statement in &basic_block.statements { self.check_stmt(0, &statement)?; } - - } + + } Ok(()) }*/ - - fn consume_once(&mut self, depth: usize, name: &str) -> Result<(), LinearityError> { let loop_depth = self.state_tbl.get_loop_depth(name); - println!("Consuming variable: {} depth {} loop_depth {}", name, depth, loop_depth); + println!( + "Consuming variable: {} depth {} loop_depth {}", + name, depth, loop_depth + ); if depth == self.state_tbl.get_loop_depth(name) { self.state_tbl.update_state(name, &VarState::Consumed); println!("Consumed variable: {}", name); - /* + /* let mut state = self.state_tbl.get_state(name); if let Some(state) = state { state = &VarState::Consumed; @@ -277,21 +275,20 @@ impl LinearityChecker { else{ //self.state_tbl.update_state(name, VarInfo{"".to_string(), depth, VarState::Unconsumed}); }*/ - + Ok(()) - } - else{ - Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string()}) + } else { + Err(LinearityError::ConsumedMoreThanOnce { + variable: name.to_string(), + }) } } - - fn check_expr(&mut self, depth: usize, expr: &Expression) -> Result<(), LinearityError> { // Assuming you have a method to get all variable names and types - //let vars = &mut self.state_tbl.vars; + //let vars = &mut self.state_tbl.vars; //TODO check if we can avoid cloning - let vars = self.state_tbl.vars.clone(); + let vars = self.state_tbl.vars.clone(); for (name, _info) in vars.iter() { //self.check_var_in_expr(depth, &name, &info.ty, expr)?; self.check_var_in_expr(depth, name, expr)?; @@ -299,17 +296,19 @@ impl LinearityChecker { Ok(()) } - fn count_in_statements(&self, name: &str, statements: &[Statement]) -> Appearances { - statements.iter().map(|stmt| self.count_in_statement(name, stmt)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) + statements + .iter() + .map(|stmt| self.count_in_statement(name, stmt)) + .fold(Appearances::zero(), |acc, x| acc.merge(&x)) } - + fn count_in_statement(&self, name: &str, statement: &Statement) -> Appearances { match statement { Statement::Let(binding) => { // Handle let bindings, possibly involving pattern matching self.count_in_expression(name, &binding.value) - }, + } Statement::If(if_stmt) => { // Process all components of an if expression let cond_apps = self.count_in_expression(name, &if_stmt.value); @@ -322,13 +321,14 @@ impl LinearityChecker { else_apps = Appearances::zero(); } cond_apps.merge(&then_apps).merge(&else_apps) - }, + } Statement::While(while_expr) => { - let cond= &while_expr.value; + let cond = &while_expr.value; let block = &while_expr.contents; // Handle while loops - self.count_in_expression(name, cond).merge(&self.count_in_statements(name, block)) - }, + self.count_in_expression(name, cond) + .merge(&self.count_in_statements(name, block)) + } Statement::For(for_expr) => { // Handle for loops //init, cond, post, block @@ -337,15 +337,19 @@ impl LinearityChecker { let post = &for_expr.post; let block = &for_expr.contents; let mut apps = Appearances::zero(); - if let Some(init) = init{ - if let Some(cond) = cond{ - if let Some(post) = post{ - apps = self.count_in_let_statements(name, init).merge(&self.count_in_expression(name, cond)).merge(&self.count_in_assign_statement(name, post)).merge(&self.count_in_statements(name, block)) + if let Some(init) = init { + if let Some(cond) = cond { + if let Some(post) = post { + apps = self + .count_in_let_statements(name, init) + .merge(&self.count_in_expression(name, cond)) + .merge(&self.count_in_assign_statement(name, post)) + .merge(&self.count_in_statements(name, block)) } } } apps - }, + } /* Alucination of GPT Statement::Block(statements) => { // Handle blocks of statements @@ -355,7 +359,7 @@ impl LinearityChecker { Statement::Assign(assign_stmt) => { // Handle assignments self.count_in_assign_statement(name, assign_stmt) - }, + } Statement::Return(return_stmt) => { // Handle return statements if let Some(value) = &return_stmt.value { @@ -363,39 +367,52 @@ impl LinearityChecker { } else { Appearances::zero() } - }, + } Statement::FnCall(fn_call_op) => { // Process function call arguments //fn_call_op.target.iter().map(|arg| self.count_in_path_op(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)); - fn_call_op.args.iter().map(|arg| self.count_in_expression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - }, + fn_call_op + .args + .iter() + .map(|arg| self.count_in_expression(name, arg)) + .fold(Appearances::zero(), |acc, x| acc.merge(&x)) + } Statement::Match(_) => { todo!("do not support match statement") - }, + } //_ => Appearances::zero(), - } + } } fn count_in_assign_statement(&self, name: &str, assign_stmt: &AssignStmt) -> Appearances { - let AssignStmt { target, derefs, value, span } = assign_stmt; + let AssignStmt { + target, + derefs, + value, + span, + } = assign_stmt; // Handle assignments let ret = self.count_in_path_op(name, target); - ret.merge(&self.count_in_expression(name, value)); + ret.merge(&self.count_in_expression(name, value)); ret } fn count_in_path_op(&self, name: &str, path_op: &PathOp) -> Appearances { - if name == path_op.first.name{ + if name == path_op.first.name { Appearances::path_once() - } - else{ + } else { Appearances::zero() - } + } } fn count_in_let_statements(&self, name: &str, let_stmt: &LetStmt) -> Appearances { - let LetStmt { is_mutable, target, value, span } = let_stmt; - self.count_in_expression(name, value) + let LetStmt { + is_mutable, + target, + value, + span, + } = let_stmt; + self.count_in_expression(name, value) } fn count_in_expression(&self, name: &str, expr: &Expression) -> Appearances { @@ -410,20 +427,24 @@ impl LinearityChecker { } else { Appearances::zero() } - }, + } ValueExpr::Path(path) => { //path.first.name == name; Appearances::zero() - }, + } _ => Appearances::zero(), } - }, + } Expression::FnCall(fn_call_op) => { // Process function call arguments - fn_call_op.args.iter().map(|arg| self.count_in_expression(name, arg)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - }, - Expression::Match(match_expr) => todo!("do not support match expression"), - /* + fn_call_op + .args + .iter() + .map(|arg| self.count_in_expression(name, arg)) + .fold(Appearances::zero(), |acc, x| acc.merge(&x)) + } + Expression::Match(match_expr) => todo!("do not support match expression"), + /* Expression::Match(match_expr) => { // Handle match arms match_expr.variants.iter().map(|(_, expr)| self.count(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) @@ -439,32 +460,42 @@ impl LinearityChecker { cond_apps.merge(&then_apps).merge(&else_apps); } cond_apps - }, + } Expression::UnaryOp(_, expr) => { // Unary operations likely don't change the count but process the inner expression self.count_in_expression(name, expr) - }, + } Expression::BinaryOp(left, _, right) => { // Handle binary operations by processing both sides - self.count_in_expression(name, left).merge(&self.count_in_expression(name, right)) - }, + self.count_in_expression(name, left) + .merge(&self.count_in_expression(name, right)) + } Expression::StructInit(struct_init_expr) => { // Handle struct initialization - struct_init_expr.fields.values().map(|expr| self.count_struct_init(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - }, + struct_init_expr + .fields + .values() + .map(|expr| self.count_struct_init(name, expr)) + .fold(Appearances::zero(), |acc, x| acc.merge(&x)) + } Expression::ArrayInit(array_init_expr) => { // Handle array initializations - array_init_expr.values.iter().map(|expr| self.count_in_expression(name, expr)).fold(Appearances::zero(), |acc, x| acc.merge(&x)) - }, - Expression::Deref(expr, _) | Expression::AsRef(expr, _, _) | Expression::Cast(expr, _, _) => { + array_init_expr + .values + .iter() + .map(|expr| self.count_in_expression(name, expr)) + .fold(Appearances::zero(), |acc, x| acc.merge(&x)) + } + Expression::Deref(expr, _) + | Expression::AsRef(expr, _, _) + | Expression::Cast(expr, _, _) => { // Deref, AsRef, and Cast are handled by just checking the inner expression self.count_in_expression(name, expr) - }, + } // Add more cases as necessary based on the Expression types you expect } } - fn count_struct_init(&self, name: &str, struct_init: &StructInitField) -> Appearances { println!("Checking struct init: {:?}", struct_init); self.count_in_expression(name, &struct_init.value) @@ -472,34 +503,74 @@ impl LinearityChecker { fn check_stmt_let(&mut self, depth: usize, binding: &LetStmt) -> Result<(), LinearityError> { // Handle let bindings, possibly involving pattern matching - let LetStmt { is_mutable, target, value, span } = binding; + let LetStmt { + is_mutable, + target, + value, + span, + } = binding; match target { LetStmtTarget::Simple { name, r#type } => { match r#type { - TypeSpec::Simple{ name: variable_type , qualifiers, span} => { - self.state_tbl.update_info(&name.name, VarInfo{ty: variable_type.name.clone(), depth, state: VarState::Unconsumed}); - }, - TypeSpec::Generic { name: variable_type, qualifiers, type_params, span } =>{ - self.state_tbl.update_info(&name.name, VarInfo{ty: variable_type.name.clone(), depth, state: VarState::Unconsumed}); - }, - TypeSpec::Array { of_type, size, qualifiers, span } => { + TypeSpec::Simple { + name: variable_type, + qualifiers, + span, + } => { + self.state_tbl.update_info( + &name.name, + VarInfo { + ty: variable_type.name.clone(), + depth, + state: VarState::Unconsumed, + }, + ); + } + TypeSpec::Generic { + name: variable_type, + qualifiers, + type_params, + span, + } => { + self.state_tbl.update_info( + &name.name, + VarInfo { + ty: variable_type.name.clone(), + depth, + state: VarState::Unconsumed, + }, + ); + } + TypeSpec::Array { + of_type, + size, + qualifiers, + span, + } => { let array_type = "Array<".to_string() + &of_type.get_name() + ">"; - self.state_tbl.update_info(&name.name, VarInfo{ty: array_type, depth, state: VarState::Unconsumed}); - }, + self.state_tbl.update_info( + &name.name, + VarInfo { + ty: array_type, + depth, + state: VarState::Unconsumed, + }, + ); + } } self.check_var_in_expr(depth, &name.name, value) - }, + } LetStmtTarget::Destructure(bindings) => { for binding in bindings { self.check_bindings(depth, binding)?; } Ok(()) - }, + } } } fn check_bindings(&mut self, depth: usize, binding: &Binding) -> Result<(), LinearityError> { - // TODO Do something with the bindings + // TODO Do something with the bindings println!("TODO implement Checking bindings: {:?}", binding); Ok(()) } @@ -513,7 +584,7 @@ impl LinearityChecker { fn check_stmt(&mut self, depth: usize, stmt: &Statement) -> Result<(), LinearityError> { match stmt { - /* + /* Statement::Expression(expr) => { // Handle expressions (e.g., variable assignments, function calls) self.check_expr(depth, expr) @@ -521,7 +592,7 @@ impl LinearityChecker { Statement::Let(binding) => { // Handle let bindings, possibly involving pattern matching self.check_stmt_let(depth, binding) - }, + } //Statement::If(cond, then_block, else_block) => { Statement::If(if_stmt) => { // Handle conditional statements @@ -531,20 +602,20 @@ impl LinearityChecker { self.check_stmts(depth + 1, else_block)?; } Ok(()) - }, + } //Statement::While(cond, block) => { - Statement::While(while_stmt) => { + Statement::While(while_stmt) => { // Handle while loops self.check_expr(depth, &while_stmt.value)?; self.check_stmts(depth + 1, &while_stmt.contents) - }, + } //Statement::For(init, cond, post, block) => { Statement::For(for_stmt) => { // Handle for loops - if let Some(init) = &for_stmt.init { + if let Some(init) = &for_stmt.init { self.check_stmt_let(depth, init)?; } - if let Some(condition) = &for_stmt.condition { + if let Some(condition) = &for_stmt.condition { self.check_expr(depth, condition)?; } if let Some(post) = &for_stmt.post { @@ -552,41 +623,49 @@ impl LinearityChecker { //self.check_stmt_assign(depth, post)?; } self.check_stmts(depth + 1, &for_stmt.contents) - }, + } Statement::Assign(assign_stmt) => { // Handle assignments - let AssignStmt { target, derefs, value, span } = assign_stmt; + let AssignStmt { + target, + derefs, + value, + span, + } = assign_stmt; println!("Checking assignment: {:?}", assign_stmt); self.check_path_opt(depth, target)?; //self.check_expr(depth, &self.path_op_to_expression(target))?; self.check_expr(depth, value) - }, + } Statement::Return(return_stmt) => { if let Some(value) = &return_stmt.value { self.check_expr(depth, value) } else { Ok(()) } - }, + } Statement::FnCall(fn_call_op) => { // Process function call arguments for arg in &fn_call_op.args { self.check_expr(depth, arg)?; } Ok(()) - }, + } Statement::Match(_) => { println!("Skipping linearity check for statement type: \n{:?}", stmt); todo!() - } + } } } - + fn check_path_opt(&mut self, depth: usize, path_op: &PathOp) -> Result<(), LinearityError> { println!("Checking path: {:?}", path_op); - //let var_expression = Value::new(path_op.first.clone(), path_op.span); // Use the imported module - let var_expression = Expression::Value(ValueExpr::ValueVar(path_op.first.clone(), path_op.span), path_op.span); - self.check_var_in_expr(depth, &path_op.first.name, &var_expression) + //let var_expression = Value::new(path_op.first.clone(), path_op.span); // Use the imported module + let var_expression = Expression::Value( + ValueExpr::ValueVar(path_op.first.clone(), path_op.span), + path_op.span, + ); + self.check_var_in_expr(depth, &path_op.first.name, &var_expression) } /* fn path_op_to_expression(&self, path_op: &PathOp) -> Expression { @@ -616,63 +695,137 @@ impl LinearityChecker { Expression::Variable(components.iter().map(|c| c.to_string()).collect::>().join(".")) }*/ - fn check_var_in_expr(&mut self, depth: usize, name: &str, expr: &Expression) -> Result<(), LinearityError> { - + fn check_var_in_expr( + &mut self, + depth: usize, + name: &str, + expr: &Expression, + ) -> Result<(), LinearityError> { let info = self.state_tbl.get_info(name); // Assume default state - if let Some(info) = info{ + if let Some(info) = info { //Only checks Linearity for types of name Linear // TODO improve this approach - if info.ty == *"Linear".to_string(){ + if info.ty == *"Linear".to_string() { let state = &info.state; let apps = self.count_in_expression(name, expr); // Assume count function implementation - let Appearances { consumed, write, read, path } = apps; + let Appearances { + consumed, + write, + read, + path, + } = apps; //println!("Checking variable: {} with state: {:?} and appearances: {:?} in expression {:?}", name, state, apps, expr); - println!("Checking state_tbl variable: {}: {:?} {:?} in expression {:?}", name, info, apps, expr); - match (state, Appearances::partition(consumed), Appearances::partition(write), Appearances::partition(read), Appearances::partition(path)) { - /*( State Consumed WBorrow RBorrow Path ) + println!( + "Checking state_tbl variable: {}: {:?} {:?} in expression {:?}", + name, info, apps, expr + ); + match ( + state, + Appearances::partition(consumed), + Appearances::partition(write), + Appearances::partition(read), + Appearances::partition(path), + ) { + /*( State Consumed WBorrow RBorrow Path ) (* ------------------|-------------------|-----------------|------------------|----------------)*/ // Not yet consumed, and at most used through immutable borrows or path reads. - (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), + (VarState::Unconsumed, CountResult::Zero, CountResult::Zero, _, _) => Ok(()), // Not yet consumed, borrowed mutably once, and nothing else. - (VarState::Unconsumed, CountResult::Zero, CountResult::One, CountResult::Zero, CountResult::Zero) => Ok(()), + ( + VarState::Unconsumed, + CountResult::Zero, + CountResult::One, + CountResult::Zero, + CountResult::Zero, + ) => Ok(()), // Not yet consumed, borrowed mutably, then either borrowed immutably or accessed through a path. - (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => Err(LinearityError::BorrowedMutUsed { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::Zero, CountResult::One, _, _) => { + Err(LinearityError::BorrowedMutUsed { + variable: name.to_string(), + }) + } // Not yet consumed, borrowed mutably more than once. - (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => Err(LinearityError::BorrowedMutMoreThanOnce { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::Zero, CountResult::MoreThanOne, _, _) => { + Err(LinearityError::BorrowedMutMoreThanOnce { + variable: name.to_string(), + }) + } // Not yet consumed, consumed once, and nothing else. Valid IF the loop depth matches. - (VarState::Unconsumed, CountResult::One, CountResult::Zero, CountResult::Zero, CountResult::Zero) => self.consume_once(depth, name), + ( + VarState::Unconsumed, + CountResult::One, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + ) => self.consume_once(depth, name), // Not yet consumed, consumed once, then either borrowed or accessed through a path. - (VarState::Unconsumed, CountResult::One, _, _, _) => Err(LinearityError::ConsumedAndUsed { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::One, _, _, _) => { + Err(LinearityError::ConsumedAndUsed { + variable: name.to_string(), + }) + } // Not yet consumed, consumed more than once. - (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => Err(LinearityError::ConsumedMoreThanOnce { variable: name.to_string() }), + (VarState::Unconsumed, CountResult::MoreThanOne, _, _, _) => { + Err(LinearityError::ConsumedMoreThanOnce { + variable: name.to_string(), + }) + } // Read borrowed, and at most accessed through a path. - (VarState::_Borrowed, CountResult::Zero, CountResult::Zero, CountResult::Zero, _) => Ok(()), + ( + VarState::_Borrowed, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + _, + ) => Ok(()), // Read borrowed, and either consumed or borrowed again. - (VarState::_Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { variable: name.to_string() }), + (VarState::_Borrowed, _, _, _, _) => Err(LinearityError::ReadBorrowedAndUsed { + variable: name.to_string(), + }), // Write borrowed, unused. - (VarState::_BorrowedMut, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + ( + VarState::_BorrowedMut, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + ) => Ok(()), // Write borrowed, used in some way. - (VarState::_BorrowedMut, _, _, _, _) => Err(LinearityError::WriteBorrowedAndUsed { variable: name.to_string() }), + (VarState::_BorrowedMut, _, _, _, _) => { + Err(LinearityError::WriteBorrowedAndUsed { + variable: name.to_string(), + }) + } // Already consumed, and unused. - (VarState::Consumed, CountResult::Zero, CountResult::Zero, CountResult::Zero, CountResult::Zero) => Ok(()), + ( + VarState::Consumed, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + CountResult::Zero, + ) => Ok(()), // Already consumed, and used in some way. - (VarState::Consumed, _, _, _, _) => Err(LinearityError::AlreadyConsumedAndUsed { variable: name.to_string() }), + (VarState::Consumed, _, _, _, _) => { + Err(LinearityError::AlreadyConsumedAndUsed { + variable: name.to_string(), + }) + } } - } - else{ + } else { //Only checks Linearity for types of name Linear Ok(()) } - } - else { - Err(LinearityError::VariableNotFound { variable: name.to_string()}) + } else { + Err(LinearityError::VariableNotFound { + variable: name.to_string(), + }) } } /* fn check_var_in_expr(&mut self, depth: u32, name: &str, state: &VarState, expr: &Expression) -> Result<(), LinearityError> { let apps = self.count_in_expression(name, expr); // Assume count function implementation let Appearances { consumed, write, read, path } = apps; - + //let state = self.state_tbl.get_state(name).unwrap_or(&VarState::Unconsumed); // Assume default state let state = self.state_tbl.get_state(name);// Assume default state if let Some(state) = state{ @@ -702,17 +855,17 @@ impl LinearityChecker { else { Err(LinearityError::VariableNotFound { variable: name.to_string() }) } - + }*/ - } - - //#[cfg(feature = "linearity")] #[allow(unused_variables)] //pub fn linearity_check_program(program_ir: &FunctionDef, session: &Session) -> Result { -pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, session: &Session) -> Result { +pub fn linearity_check_program( + programs: &Vec<(PathBuf, String, Program)>, + session: &Session, +) -> Result { println!("Starting linearity check"); let mut checker = LinearityChecker::new(); for (path, name, program) in programs { @@ -725,51 +878,62 @@ pub fn linearity_check_program(programs: &Vec<(PathBuf, String, Program)>, sessi //println!("Checking linearity for function: {:?}", function); //checker.check_function(&function)?; for statement in &function.body { - //println!("Checking linearity for function body: {:?}", function.body); + //println!("Checking linearity for function body: {:?}", function.body); checker.check_stmt(0, statement)?; } - println!("Finished checking linearity for function: {} {:?}", function.decl.name.name, checker.state_tbl); + println!( + "Finished checking linearity for function: {} {:?}", + function.decl.name.name, checker.state_tbl + ); //checker.linearity_check(&function)?; - }, - ModuleDefItem::FunctionDecl(function_decl) => - { - println!("Skipping linearity check for FunctionDecl: {:?}", module_content); - }, - ModuleDefItem::Module(module) => - { + } + ModuleDefItem::FunctionDecl(function_decl) => { + println!( + "Skipping linearity check for FunctionDecl: {:?}", + module_content + ); + } + ModuleDefItem::Module(module) => { println!("Skipping linearity check for Module: {:?}", module_content); - }, - ModuleDefItem::Struct(struc) => - { + } + ModuleDefItem::Struct(struc) => { //println!("Skipping linearity check for Struct: {:?}", module_content); //checker. - checker.state_tbl.update_info(&struc.name.name, VarInfo{ty: "Struct".to_string(), depth: 0, state: VarState::Unconsumed}); - }, - ModuleDefItem::Enum(_) => - { + checker.state_tbl.update_info( + &struc.name.name, + VarInfo { + ty: "Struct".to_string(), + depth: 0, + state: VarState::Unconsumed, + }, + ); + } + ModuleDefItem::Enum(_) => { println!("Skipping linearity check for Enum: {:?}", module_content); - }, - ModuleDefItem::Constant(_) => - { - println!("Skipping linearity check for Constant: {:?}", module_content); - }, - ModuleDefItem::Union(_) => - { + } + ModuleDefItem::Constant(_) => { + println!( + "Skipping linearity check for Constant: {:?}", + module_content + ); + } + ModuleDefItem::Union(_) => { println!("Skipping linearity check for Uinon: {:?}", module_content); - }, - ModuleDefItem::Type(_) => - { - println!("Skipping linearity check for module content: {:?}", module_content); - }, - /*_ => - { + } + ModuleDefItem::Type(_) => { + println!( + "Skipping linearity check for module content: {:?}", + module_content + ); + } + /*_ => + { println!("Skipping linearity check for module content: {:?}", module_content); () - },*/ - } + },*/ + } } } } Ok("OK".to_string()) } - diff --git a/crates/concrete_check/src/linearity_check/errors.rs b/crates/concrete_check/src/linearity_check/errors.rs index 7b6e5bf..9dce425 100644 --- a/crates/concrete_check/src/linearity_check/errors.rs +++ b/crates/concrete_check/src/linearity_check/errors.rs @@ -5,52 +5,27 @@ use thiserror::Error; #[derive(Debug, Error, Clone)] pub enum LinearityError { #[error("Variable {variable} not consumed")] - NotConsumed { - variable: String, - }, + NotConsumed { variable: String }, #[error("Borrowed mutably and used for Variable {variable}")] - BorrowedMutUsed { - variable: String, - }, + BorrowedMutUsed { variable: String }, #[error("Variable {variable} borrowed mutably more than once")] - BorrowedMutMoreThanOnce { - variable: String, - }, + BorrowedMutMoreThanOnce { variable: String }, #[error("Variable {variable} consumed once and then used again")] - ConsumedAndUsed { - variable: String, - }, + ConsumedAndUsed { variable: String }, #[error("Variable {variable} consumed more than once")] - ConsumedMoreThanOnce { - variable: String, - }, + ConsumedMoreThanOnce { variable: String }, #[error("Variable {variable} read borrowed and used in other ways")] - ReadBorrowedAndUsed { - variable: String, - }, + ReadBorrowedAndUsed { variable: String }, #[error("Variable {variable} write borrowed and used")] - WriteBorrowedAndUsed { - variable: String, - }, + WriteBorrowedAndUsed { variable: String }, #[error("Variable {variable} already consumed and used again")] - AlreadyConsumedAndUsed { - variable: String, - }, + AlreadyConsumedAndUsed { variable: String }, #[error("Unhandled state or appearance count for Variable {variable}")] - UnhandledStateOrCount { - variable: String, - }, + UnhandledStateOrCount { variable: String }, #[error("Linearity error. Variable {variable} generated {message}")] - Unspecified { - variable: String, - message: String, - }, + Unspecified { variable: String, message: String }, #[error("Variable {variable} not found")] - VariableNotFound{ - variable: String, - }, + VariableNotFound { variable: String }, #[error("Unhandled statement type {r#type}")] - UnhandledStatementType{ - r#type: String, - }, + UnhandledStatementType { r#type: String }, } diff --git a/crates/concrete_driver/src/lib.rs b/crates/concrete_driver/src/lib.rs index 229338e..76ac10e 100644 --- a/crates/concrete_driver/src/lib.rs +++ b/crates/concrete_driver/src/lib.rs @@ -104,7 +104,6 @@ pub struct BuildArgs { /// This option is for checking the program for linearity. #[arg(long, default_value_t = false)] check: bool, - } #[derive(Parser, Debug)] @@ -159,7 +158,6 @@ pub struct CompilerArgs { /// This option is for checking the program for linearity. #[arg(long, default_value_t = false)] check: bool, - } pub fn main() -> Result<()> { @@ -607,17 +605,17 @@ pub fn compile(args: &CompilerArgs) -> Result { #[allow(unused_variables)] if args.check { - let linearity_result = match concrete_check::linearity_check::linearity_check_program(&programs, &session) { - Ok(ir) => ir, - Err(error) => { - //TODO improve reporting - println!("Linearity check failed: {:#?}", error); - std::process::exit(1); - } - }; + let linearity_result = + match concrete_check::linearity_check::linearity_check_program(&programs, &session) { + Ok(ir) => ir, + Err(error) => { + //TODO improve reporting + println!("Linearity check failed: {:#?}", error); + std::process::exit(1); + } + }; } - if args.ir { std::fs::write( session.output_file.with_extension("ir"), diff --git a/crates/concrete_ir/src/lowering.rs b/crates/concrete_ir/src/lowering.rs index 7ae3a50..9c8e60d 100644 --- a/crates/concrete_ir/src/lowering.rs +++ b/crates/concrete_ir/src/lowering.rs @@ -844,7 +844,7 @@ fn find_expression_type(builder: &mut FnBodyBuilder, info: &Expression) -> Optio ValueExpr::Path(path) => { let local = builder.get_local(&path.first.name).unwrap(); // todo handle segments Some(local.ty.clone()) - }, + } //TODO check this behavior ValueExpr::ValueVar(_, _span) => None, }, @@ -1551,7 +1551,7 @@ fn lower_value_expr( ValueExpr::Path(info) => { let (place, place_ty, _span) = lower_path(builder, info)?; (Rvalue::Use(Operand::Place(place.clone())), place_ty) - }, + } ValueExpr::ValueVar(_, _) => todo!("ValueVar not yet implemented"), }) } From 18a0bdb59db6ce1f8a0e3339ceddef97714a63d3 Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 17:03:38 -0300 Subject: [PATCH 30/31] Removed FIXME from linearExamples --- examples/linearExample01.con | 2 -- examples/linearExample02.con | 2 -- examples/linearExample03if.con | 2 -- 3 files changed, 6 deletions(-) diff --git a/examples/linearExample01.con b/examples/linearExample01.con index 3c226e4..cc26c82 100644 --- a/examples/linearExample01.con +++ b/examples/linearExample01.con @@ -11,8 +11,6 @@ mod LinearExampleStub { x: 1, y: 0, }; - // FIXME Prefered initialization but not yet implemented - // [1, 0]; // linear value is written/consumed xy.x = xy.x + 1; return xy.x; diff --git a/examples/linearExample02.con b/examples/linearExample02.con index 81d6fce..8365611 100644 --- a/examples/linearExample02.con +++ b/examples/linearExample02.con @@ -11,8 +11,6 @@ mod LinearExampleStub { x: 1, y: 0, }; - // FIXME Prefered initialization but not yet implemented - // [1, 0]; // linear value is written/consumed consume_x(&mut xy); return xy.x; diff --git a/examples/linearExample03if.con b/examples/linearExample03if.con index 25ba01f..57566f0 100644 --- a/examples/linearExample03if.con +++ b/examples/linearExample03if.con @@ -11,8 +11,6 @@ mod LinearExampleIfStub { x: 1, y: 0, }; - // FIXME Prefered initialization but not yet implemented - // [1, 0]; if xy.x < xy.y{ consume_x(&mut xy, 1); } From 22016aef6f0dfc85f00f7e2f14c912486dd2fced Mon Sep 17 00:00:00 2001 From: alejandro baranek Date: Mon, 20 May 2024 17:04:19 -0300 Subject: [PATCH 31/31] cargo fmt --- crates/concrete_check/src/linearity_check.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/crates/concrete_check/src/linearity_check.rs b/crates/concrete_check/src/linearity_check.rs index 16b77e1..742ced1 100644 --- a/crates/concrete_check/src/linearity_check.rs +++ b/crates/concrete_check/src/linearity_check.rs @@ -379,8 +379,7 @@ impl LinearityChecker { } Statement::Match(_) => { todo!("do not support match statement") - } - //_ => Appearances::zero(), + } //_ => Appearances::zero(), } } @@ -491,8 +490,7 @@ impl LinearityChecker { | Expression::Cast(expr, _, _) => { // Deref, AsRef, and Cast are handled by just checking the inner expression self.count_in_expression(name, expr) - } - // Add more cases as necessary based on the Expression types you expect + } // Add more cases as necessary based on the Expression types you expect } } @@ -925,12 +923,11 @@ pub fn linearity_check_program( "Skipping linearity check for module content: {:?}", module_content ); - } - /*_ => - { - println!("Skipping linearity check for module content: {:?}", module_content); - () - },*/ + } /*_ => + { + println!("Skipping linearity check for module content: {:?}", module_content); + () + },*/ } } }