diff --git a/crates/samlang-core/src/ast/source.rs b/crates/samlang-core/src/ast/source.rs index 01428baf8..7ea074ecd 100644 --- a/crates/samlang-core/src/ast/source.rs +++ b/crates/samlang-core/src/ast/source.rs @@ -219,7 +219,7 @@ pub(crate) mod pattern { pub(crate) enum DestructuringPattern { Tuple(Location, Vec, T>>), Object(Location, Vec, T>>), - Id(Id), + Id(Id, T), Wildcard(Location), } @@ -228,7 +228,7 @@ pub(crate) mod pattern { match self { Self::Tuple(loc, _) | Self::Object(loc, _) - | Self::Id(Id { loc, .. }) + | Self::Id(Id { loc, .. }, _) | Self::Wildcard(loc) => loc, } } @@ -239,16 +239,16 @@ pub(crate) mod pattern { pub(crate) loc: Location, pub(crate) tag_order: usize, pub(crate) tag: Id, - pub(crate) data_variables: Vec>, + pub(crate) data_variables: Vec<(MatchingPattern, T)>, pub(crate) type_: T, } #[derive(Clone, PartialEq, Eq)] pub(crate) enum MatchingPattern { - Tuple(Location, Vec, T>>), - Object(Location, Vec, T>>), + Tuple(Location, Vec, T>>), + Object(Location, Vec, T>>), Variant(VariantPattern), - Id(Id), + Id(Id, T), Wildcard(Location), } @@ -258,7 +258,7 @@ pub(crate) mod pattern { Self::Tuple(loc, _) | Self::Object(loc, _) | Self::Variant(VariantPattern { loc, .. }) - | Self::Id(Id { loc, .. }) + | Self::Id(Id { loc, .. }, _) | Self::Wildcard(loc) => loc, } } @@ -424,10 +424,16 @@ pub(crate) mod expr { pub(crate) e2: Box>, } + #[derive(Clone, PartialEq, Eq)] + pub(crate) enum IfElseCondition { + Expression(E), + Guard(super::pattern::MatchingPattern, E), + } + #[derive(Clone, PartialEq, Eq)] pub(crate) struct IfElse { pub(crate) common: ExpressionCommon, - pub(crate) condition: Box>, + pub(crate) condition: Box>, pub(crate) e1: Box>, pub(crate) e2: Box>, } diff --git a/crates/samlang-core/src/ast/source_tests.rs b/crates/samlang-core/src/ast/source_tests.rs index 3c3bd07f6..d9ca841c6 100644 --- a/crates/samlang-core/src/ast/source_tests.rs +++ b/crates/samlang-core/src/ast/source_tests.rs @@ -45,7 +45,7 @@ mod tests { assert_eq!(*destructuring_pattern.loc(), Location::dummy()); destructuring_pattern = pattern::DestructuringPattern::Tuple(Location::dummy(), vec![]); assert_eq!(*destructuring_pattern.loc(), Location::dummy()); - destructuring_pattern = pattern::DestructuringPattern::Id(Id::from(PStr::LOWER_A)); + destructuring_pattern = pattern::DestructuringPattern::Id(Id::from(PStr::LOWER_A), ()); assert_eq!(*destructuring_pattern.loc(), Location::dummy()); destructuring_pattern = pattern::DestructuringPattern::Wildcard(Location::dummy()); assert_eq!(*destructuring_pattern.loc(), Location::dummy()); @@ -63,7 +63,7 @@ mod tests { type_: (), }); assert_eq!(*matching_pattern.clone().loc(), Location::dummy()); - matching_pattern = pattern::MatchingPattern::Id(Id::from(PStr::LOWER_A)); + matching_pattern = pattern::MatchingPattern::Id(Id::from(PStr::LOWER_A), ()); assert_eq!(*matching_pattern.loc(), Location::dummy()); matching_pattern = pattern::MatchingPattern::Wildcard(Location::dummy()); assert_eq!(*matching_pattern.loc(), Location::dummy()); @@ -205,7 +205,16 @@ mod tests { })); coverage_hack_for_expr(E::IfElse(IfElse { common: common.clone(), - condition: Box::new(zero_expr.clone()), + condition: Box::new(IfElseCondition::Expression(zero_expr.clone())), + e1: Box::new(zero_expr.clone()), + e2: Box::new(zero_expr.clone()), + })); + coverage_hack_for_expr(E::IfElse(IfElse { + common: common.clone(), + condition: Box::new(IfElseCondition::Guard( + pattern::MatchingPattern::Wildcard(Location::dummy()), + zero_expr.clone(), + )), e1: Box::new(zero_expr.clone()), e2: Box::new(zero_expr.clone()), })); @@ -242,9 +251,10 @@ mod tests { loc: Location::dummy(), field_order: 0, field_name: Id::from(heap.alloc_str_for_test("name")), - pattern: Box::new(pattern::DestructuringPattern::Id(Id::from( - heap.alloc_str_for_test("name"), - ))), + pattern: Box::new(pattern::DestructuringPattern::Id( + Id::from(heap.alloc_str_for_test("name")), + (), + )), shorthand: true, type_: (), }], @@ -262,9 +272,10 @@ mod tests { pattern: pattern::DestructuringPattern::Tuple( Location::dummy(), vec![pattern::TuplePatternElement { - pattern: Box::new(pattern::DestructuringPattern::Id(Id::from( - heap.alloc_str_for_test("name"), - ))), + pattern: Box::new(pattern::DestructuringPattern::Id( + Id::from(heap.alloc_str_for_test("name")), + (), + )), type_: (), }], ), @@ -289,7 +300,7 @@ mod tests { expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: pattern::DestructuringPattern::Id(Id::from(heap.alloc_str_for_test("s"))), + pattern: pattern::DestructuringPattern::Id(Id::from(heap.alloc_str_for_test("s")), ()), annotation: Some(annotation::T::Fn(annotation::Function { location: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, diff --git a/crates/samlang-core/src/checker/checker_tests.rs b/crates/samlang-core/src/checker/checker_tests.rs index 8423536d7..deb8ab57f 100644 --- a/crates/samlang-core/src/checker/checker_tests.rs +++ b/crates/samlang-core/src/checker/checker_tests.rs @@ -2586,11 +2586,21 @@ Found 1 error. "{ let _ = (b: bool, t: int, f: int) -> if b then t else f; }", &builder.unit_type(), ); + assert_checks( + heap, + "{ let _ = (t: Test) -> if let {foo, bar as _} = t then 1 else 2; }", + &builder.unit_type(), + ); assert_checks( heap, "{ let _ = (t: Test2) -> match (t) { Foo(_) -> 1, Bar(s) -> 2 }; }", &builder.unit_type(), ); + assert_checks( + heap, + "{ let _ = (t: Test2) -> if let Foo(_) = t then 1 else 2; }", + &builder.unit_type(), + ); assert_errors_full_customization( heap, "{ let _ = (t: Test2) -> match (t) { Foo(_) -> 1, Bar(s) -> 2 }; }", @@ -2716,6 +2726,79 @@ Error ---------------------------------- DUMMY.sam:3:22-3:23 Found 1 error. "#, ); + assert_errors_full_customization( + heap, + "{ let _ = (t: Test) -> if let [a, b, _] = [1, 2] then 1 else 2; }", + &builder.unit_type(), + r#" +Error ---------------------------------- DUMMY.sam:1:38-1:39 + +Cannot access member of `Pair` at index 2. + + 1| { let _ = (t: Test) -> if let [a, b, _] = [1, 2] then 1 else 2; } + ^ + + +Found 1 error. +"#, + "Test", + true, + ); + assert_errors_full_customization( + heap, + r#"{ let _ = (t: Test) -> if let {bar, boo} = t then 1 else 2; +let _ = (t: Test) -> if let [_, bar] = t then 1 else 2; +let _ = (t: Test2) -> if let Foo(_) = t then 1 else 2; +let _ = (t: Test2) -> if let Foo(_, _) = t then 1 else 2; +let _ = (t: Test2) -> if let Foo111(_) = t then 1 else 2; +}"#, + &builder.unit_type(), + r#" +Error ---------------------------------- DUMMY.sam:1:32-1:35 + +Cannot find member `bar` on `Test`. + + 1| { let _ = (t: Test) -> if let {bar, boo} = t then 1 else 2; + ^^^ + + +Error ---------------------------------- DUMMY.sam:1:37-1:40 + +Cannot find member `boo` on `Test`. + + 1| { let _ = (t: Test) -> if let {bar, boo} = t then 1 else 2; + ^^^ + + +Error ---------------------------------- DUMMY.sam:2:33-2:36 + +Cannot access member of `Test` at index 1. + + 2| let _ = (t: Test) -> if let [_, bar] = t then 1 else 2; + ^^^ + + +Error ---------------------------------- DUMMY.sam:4:37-4:38 + +Cannot access member of `Test2` at index 1. + + 4| let _ = (t: Test2) -> if let Foo(_, _) = t then 1 else 2; + ^ + + +Error ---------------------------------- DUMMY.sam:5:30-5:36 + +Cannot find member `Foo111` on `Test2`. + + 5| let _ = (t: Test2) -> if let Foo111(_) = t then 1 else 2; + ^^^^^^ + + +Found 5 errors. +"#, + "Test2", + false, + ); assert_errors( heap, "match (3) { Foo(_) -> 1, Bar(s) -> 2 }", diff --git a/crates/samlang-core/src/checker/main_checker.rs b/crates/samlang-core/src/checker/main_checker.rs index 493f882e2..2f8eae0b4 100644 --- a/crates/samlang-core/src/checker/main_checker.rs +++ b/crates/samlang-core/src/checker/main_checker.rs @@ -896,7 +896,16 @@ fn check_if_else( expression: &expr::IfElse<()>, hint: type_hint::Hint, ) -> expr::E> { - let condition = Box::new(type_check_expression(cx, &expression.condition, type_hint::MISSING)); + let condition = Box::new(match expression.condition.as_ref() { + expr::IfElseCondition::Expression(e) => { + expr::IfElseCondition::Expression(type_check_expression(cx, e, type_hint::MISSING)) + } + expr::IfElseCondition::Guard(p, e) => { + let e = type_check_expression(cx, e, type_hint::MISSING); + let p = check_matching_pattern(cx, p, e.type_()); + expr::IfElseCondition::Guard(p, e) + } + }); let e1 = Box::new(type_check_expression(cx, &expression.e1, hint)); let e2 = Box::new(type_check_expression(cx, &expression.e2, type_hint::available(e1.type_()))); assignability_check(cx, e2.loc(), e2.type_(), e1.type_()); @@ -1140,14 +1149,153 @@ fn check_destructuring_pattern( } pattern::DestructuringPattern::Object(*pattern_loc, checked_destructured_names) } - pattern::DestructuringPattern::Id(id) => { + pattern::DestructuringPattern::Id(id, ()) => { cx.local_typing_context.write(id.loc, pattern_type.clone()); - pattern::DestructuringPattern::Id(*id) + pattern::DestructuringPattern::Id(*id, pattern_type.clone()) } pattern::DestructuringPattern::Wildcard(loc) => pattern::DestructuringPattern::Wildcard(*loc), } } +fn check_matching_pattern( + cx: &mut TypingContext, + pattern: &pattern::MatchingPattern<()>, + pattern_type: &Rc, +) -> pattern::MatchingPattern> { + match pattern { + pattern::MatchingPattern::Tuple(pattern_loc, destructured_names) => { + let fields = cx.resolve_struct_definitions(pattern_type); + let mut checked_destructured_names = vec![]; + for (index, pattern::TuplePatternElement { pattern, type_: _ }) in + destructured_names.iter().enumerate() + { + let loc = pattern.loc(); + if let Some(field_sig) = fields.get(index) { + if !field_sig.is_public { + cx.error_set.report_element_missing_error(*loc, pattern_type.to_description(), index); + } + let checked = Box::new(check_matching_pattern(cx, pattern, &field_sig.type_)); + checked_destructured_names.push(pattern::TuplePatternElement { + pattern: checked, + type_: Rc::new(field_sig.type_.reposition(*loc)), + }); + continue; + } + cx.error_set.report_element_missing_error(*loc, pattern_type.to_description(), index); + let type_ = Rc::new(Type::Any(Reason::new(*loc, Some(*loc)), false)); + let checked = Box::new(check_matching_pattern(cx, pattern, &type_)); + checked_destructured_names.push(pattern::TuplePatternElement { pattern: checked, type_ }); + } + pattern::MatchingPattern::Tuple(*pattern_loc, checked_destructured_names) + } + pattern::MatchingPattern::Object(pattern_loc, destructed_names) => { + let fields = cx.resolve_struct_definitions(pattern_type); + let mut field_order_mapping = HashMap::new(); + let mut field_mappings = HashMap::new(); + for (i, field) in fields.into_iter().enumerate() { + field_order_mapping.insert(field.name, i); + field_mappings.insert(field.name, (field.type_, field.is_public)); + } + let mut checked_destructured_names = vec![]; + for pattern::ObjectPatternElement { + loc, + field_order, + field_name, + pattern, + shorthand, + type_: _, + } in destructed_names + { + if let Some((field_type, is_public)) = field_mappings.get(&field_name.name) { + if !is_public { + cx.error_set.report_member_missing_error( + field_name.loc, + pattern_type.to_description(), + field_name.name, + ); + } + let checked = Box::new(check_matching_pattern(cx, pattern, field_type)); + let field_order = field_order_mapping.get(&field_name.name).unwrap(); + checked_destructured_names.push(pattern::ObjectPatternElement { + loc: *loc, + field_order: *field_order, + field_name: *field_name, + pattern: checked, + shorthand: *shorthand, + type_: Rc::new(field_type.reposition(*loc)), + }); + continue; + } + cx.error_set.report_member_missing_error( + field_name.loc, + pattern_type.to_description(), + field_name.name, + ); + let type_ = Rc::new(Type::Any(Reason::new(*loc, Some(*loc)), false)); + let checked = Box::new(check_matching_pattern(cx, pattern, &type_)); + checked_destructured_names.push(pattern::ObjectPatternElement { + loc: *loc, + field_order: *field_order, + field_name: *field_name, + pattern: checked, + shorthand: *shorthand, + type_: Rc::new(Type::Any(Reason::new(*loc, Some(*loc)), false)), + }); + } + pattern::MatchingPattern::Object(*pattern_loc, checked_destructured_names) + } + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc, + tag_order, + tag, + data_variables, + type_: _, + }) => { + let Some((tag_order, resolved_enum)) = + cx.resolve_enum_definitions(pattern_type).into_iter().find_position(|e| e.name == tag.name) + else { + cx.error_set.report_member_missing_error(tag.loc, pattern_type.to_description(), tag.name); + let type_ = Rc::new(Type::Any(Reason::new(*loc, Some(*loc)), false)); + return pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: *loc, + tag_order: *tag_order, + tag: *tag, + data_variables: data_variables + .iter() + .map(|(p, ())| (check_matching_pattern(cx, p, &type_), type_.clone())) + .collect(), + type_, + }); + }; + let mut checked_data_variables = Vec::with_capacity(data_variables.len()); + for (index, (p, ())) in data_variables.iter().enumerate() { + if let Some(resolved_pattern_type) = resolved_enum.types.get(index) { + checked_data_variables.push(( + check_matching_pattern(cx, p, resolved_pattern_type), + resolved_pattern_type.clone(), + )); + } else { + cx.error_set.report_element_missing_error(*p.loc(), pattern_type.to_description(), index); + let type_ = Rc::new(Type::Any(Reason::new(*p.loc(), Some(*p.loc())), false)); + checked_data_variables.push((check_matching_pattern(cx, p, &type_), type_)); + } + } + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: *loc, + tag_order, + tag: *tag, + data_variables: checked_data_variables, + type_: pattern_type.clone(), + }) + } + pattern::MatchingPattern::Id(id, ()) => { + cx.local_typing_context.write(id.loc, pattern_type.clone()); + pattern::MatchingPattern::Id(*id, pattern_type.clone()) + } + pattern::MatchingPattern::Wildcard(loc) => pattern::MatchingPattern::Wildcard(*loc), + } +} + fn check_statement( cx: &mut TypingContext, statement: &expr::DeclarationStatement<()>, diff --git a/crates/samlang-core/src/checker/ssa_analysis.rs b/crates/samlang-core/src/checker/ssa_analysis.rs index 15b35888e..bdd687828 100644 --- a/crates/samlang-core/src/checker/ssa_analysis.rs +++ b/crates/samlang-core/src/checker/ssa_analysis.rs @@ -271,11 +271,21 @@ impl<'a> SsaAnalysisState<'a> { self.visit_expression(&e.e1); self.visit_expression(&e.e2); } - expr::E::IfElse(e) => { - self.visit_expression(&e.condition); - self.visit_expression(&e.e1); - self.visit_expression(&e.e2); - } + expr::E::IfElse(e) => match e.condition.as_ref() { + expr::IfElseCondition::Expression(guard) => { + self.visit_expression(guard); + self.visit_expression(&e.e1); + self.visit_expression(&e.e2); + } + expr::IfElseCondition::Guard(p, guard) => { + self.visit_expression(guard); + self.context.push_scope(); + self.visit_matching_pattern(p); + self.visit_expression(&e.e1); + self.context.pop_scope(); + self.visit_expression(&e.e2); + } + }, expr::E::Match(e) => { self.visit_expression(&e.matched); for case in &e.cases { @@ -315,7 +325,7 @@ impl<'a> SsaAnalysisState<'a> { if let Some(annot) = annotation { self.visit_annot(annot); } - self.visit_pattern(pattern); + self.visit_destructuring_pattern(pattern); } if let Some(final_expr) = &e.expression { self.visit_expression(final_expr); @@ -326,23 +336,51 @@ impl<'a> SsaAnalysisState<'a> { } } - fn visit_pattern(&mut self, pattern: &pattern::DestructuringPattern<()>) { + fn visit_destructuring_pattern(&mut self, pattern: &pattern::DestructuringPattern<()>) { match pattern { pattern::DestructuringPattern::Tuple(_, names) => { for pattern::TuplePatternElement { pattern, type_: _ } in names { - self.visit_pattern(pattern); + self.visit_destructuring_pattern(pattern); } } pattern::DestructuringPattern::Object(_, names) => { for name in names { - self.visit_pattern(&name.pattern); + self.visit_destructuring_pattern(&name.pattern); } } - pattern::DestructuringPattern::Id(id) => self.define_id(id.name, id.loc), + pattern::DestructuringPattern::Id(id, ()) => self.define_id(id.name, id.loc), pattern::DestructuringPattern::Wildcard(_) => {} } } + fn visit_matching_pattern(&mut self, pattern: &pattern::MatchingPattern<()>) { + match pattern { + pattern::MatchingPattern::Tuple(_, names) => { + for pattern::TuplePatternElement { pattern, type_: _ } in names { + self.visit_matching_pattern(pattern); + } + } + pattern::MatchingPattern::Object(_, names) => { + for name in names { + self.visit_matching_pattern(&name.pattern); + } + } + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: _, + tag_order: _, + tag: _, + data_variables, + type_: _, + }) => { + for (p, _) in data_variables { + self.visit_matching_pattern(p); + } + } + pattern::MatchingPattern::Id(id, ()) => self.define_id(id.name, id.loc), + pattern::MatchingPattern::Wildcard(_) => {} + } + } + fn visit_id_annot( &mut self, annotation::Id { location, module_reference, id, type_arguments }: &annotation::Id, diff --git a/crates/samlang-core/src/checker/ssa_analysis_tests.rs b/crates/samlang-core/src/checker/ssa_analysis_tests.rs index c01785732..fb24b3f70 100644 --- a/crates/samlang-core/src/checker/ssa_analysis_tests.rs +++ b/crates/samlang-core/src/checker/ssa_analysis_tests.rs @@ -51,6 +51,7 @@ mod tests { } else { (p1: Foo, p2) -> Baz.ouch(p2).ahha(p1) + a }; + let _ = if let {pat1 as {pat2 as [Fizz(pat3), Buzz, _], pat4}} = true then 1 else 2; let a = 3; let {o1, o2 as o3} = {}; o1 + o3 @@ -64,12 +65,12 @@ mod tests { assert!(!error_set.has_errors()); let expected = r#" Unbound names: [Foo] -Invalid defines: [13:7-13:8] +Invalid defines: [14:7-14:8] Locally Scoped Defs: 10:10-12:4: [] 11:5-11:57: [p1, p2] -14:24-14:26: [] -1:1-16:2: [a, b, c, o1, o3] +15:24-15:26: [] +1:1-17:2: [a, b, c, o1, o3] 5:26-10:4: [] 7:7-7:23: [d] 8:7-8:18: [] @@ -77,9 +78,11 @@ Lambda Capture Locs: [11:5-11:57] def_to_use_map: 11:15-11:17 -> [11:15-11:17, 11:36-11:38] 11:6-11:8 -> [11:50-11:52, 11:6-11:8] -13:7-13:8 -> [13:7-13:8] -14:18-14:20 -> [14:18-14:20, 15:8-15:10] -14:8-14:10 -> [14:8-14:10, 15:3-15:5] +13:42-13:46 -> [13:42-13:46] +13:59-13:63 -> [13:59-13:63] +14:7-14:8 -> [14:7-14:8] +15:18-15:20 -> [15:18-15:20, 16:8-16:10] +15:8-15:10 -> [15:8-15:10, 16:3-16:5] 2:7-2:8 -> [11:56-11:57, 2:7-2:8, 4:11-4:12, 5:14-5:15, 7:21-7:22] 3:7-3:8 -> [3:7-3:8, 5:19-5:20, 8:17-8:18] 4:7-4:8 -> [4:7-4:8, 6:13-6:14] diff --git a/crates/samlang-core/src/compiler/hir_lowering.rs b/crates/samlang-core/src/compiler/hir_lowering.rs index 4eaff5d85..04fb15b25 100644 --- a/crates/samlang-core/src/compiler/hir_lowering.rs +++ b/crates/samlang-core/src/compiler/hir_lowering.rs @@ -590,7 +590,13 @@ impl<'a> ExpressionLoweringManager<'a> { expression: &source::expr::IfElse>, ) -> LoweringResult { let mut lowered_stmts = vec![]; - let condition = self.lowered_and_add_statements(&expression.condition, &mut lowered_stmts); + let condition = self.lowered_and_add_statements( + match expression.condition.as_ref() { + source::expr::IfElseCondition::Expression(e) => e, + source::expr::IfElseCondition::Guard(_, _) => panic!("TODO IF_LET"), + }, + &mut lowered_stmts, + ); let final_var_name = self.allocate_temp_variable(); let LoweringResult { statements: s1, expression: e1 } = self.lower(&expression.e1); let LoweringResult { statements: s2, expression: e2 } = self.lower(&expression.e2); @@ -859,13 +865,82 @@ impl<'a> ExpressionLoweringManager<'a> { ); } } - source::pattern::DestructuringPattern::Id(id) => { + source::pattern::DestructuringPattern::Id(id, _) => { bind_value(&mut self.variable_cx, id.name, assigned_expr); } source::pattern::DestructuringPattern::Wildcard(_) => {} } } + /* + fn lower_matching_pattern( + &mut self, + pattern: &source::pattern::MatchingPattern>, + lowered_stmts: &mut Vec, + assigned_expr: hir::Expression, + e1: hir::Expression, + e2: hir::Expression, + ) { + match pattern { + source::pattern::MatchingPattern::Tuple(_, destructured_names) => { + let id_type = assigned_expr.type_().as_id().unwrap(); + let resolved_struct_mappings = self.resolve_struct_mapping_of_id_type(id_type); + for (index, nested) in destructured_names.iter().enumerate() { + let field_type = &resolved_struct_mappings[index]; + let name = self.allocate_temp_variable(); + lowered_stmts.push(hir::Statement::IndexedAccess { + name, + type_: field_type.clone(), + pointer_expression: assigned_expr.clone(), + index, + }); + self.lower_matching_pattern( + &nested.pattern, + lowered_stmts, + hir::Expression::var_name(name, field_type.clone()), + e1, + e2, + ); + } + } + source::pattern::MatchingPattern::Object(_, destructured_names) => { + let id_type = assigned_expr.type_().as_id().unwrap(); + let resolved_struct_mappings = self.resolve_struct_mapping_of_id_type(id_type); + for destructured_name in destructured_names { + let field_type = &resolved_struct_mappings[destructured_name.field_order]; + let name = self.allocate_temp_variable(); + lowered_stmts.push(hir::Statement::IndexedAccess { + name, + type_: field_type.clone(), + pointer_expression: assigned_expr.clone(), + index: destructured_name.field_order, + }); + self.lower_matching_pattern( + &destructured_name.pattern, + lowered_stmts, + hir::Expression::var_name(name, field_type.clone()), + e1, + e2, + ); + } + } + source::pattern::MatchingPattern::Variant(source::pattern::VariantPattern { + loc: _, + tag_order, + tag, + data_variables, + type_, + }) => { + let id_type = assigned_expr.type_().as_id().unwrap(); + } + source::pattern::MatchingPattern::Id(id, _) => { + bind_value(&mut self.variable_cx, id.name, assigned_expr); + } + source::pattern::MatchingPattern::Wildcard(_) => {} + } + } + */ + fn lower_block(&mut self, expression: &source::expr::Block>) -> LoweringResult { let mut lowered_stmts = vec![]; self.variable_cx.push_scope(); @@ -1968,6 +2043,25 @@ return (_t3: _$SyntheticIDType0);"#, ); } + #[should_panic] + #[test] + fn if_let_unsupported_test() { + let heap = &mut Heap::new(); + assert_expr_correctly_lowered( + &source::expr::E::IfElse(source::expr::IfElse { + common: source::expr::ExpressionCommon::dummy(Rc::new(dummy_source_id_type(heap))), + condition: Box::new(source::expr::IfElseCondition::Guard( + source::pattern::MatchingPattern::Wildcard(Location::dummy()), + dummy_source_this(heap), + )), + e1: Box::new(dummy_source_this(heap)), + e2: Box::new(dummy_source_this(heap)), + }), + heap, + "", + ); + } + #[test] fn control_flow_lowering_tests() { let builder = test_type_builder::create(); @@ -1976,7 +2070,7 @@ return (_t3: _$SyntheticIDType0);"#, assert_expr_correctly_lowered( &source::expr::E::IfElse(source::expr::IfElse { common: source::expr::ExpressionCommon::dummy(Rc::new(dummy_source_id_type(heap))), - condition: Box::new(dummy_source_this(heap)), + condition: Box::new(source::expr::IfElseCondition::Expression(dummy_source_this(heap))), e1: Box::new(dummy_source_this(heap)), e2: Box::new(dummy_source_this(heap)), }), @@ -2102,7 +2196,10 @@ return (_t7: DUMMY_Dummy);"#, statements: vec![source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from(PStr::LOWER_A)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.unit_type(), + ), annotation: Some(annot_builder.unit_annot()), assigned_expression: Box::new(source::expr::E::Block(source::expr::Block { common: source::expr::ExpressionCommon::dummy(builder.unit_type()), @@ -2119,6 +2216,7 @@ return (_t7: DUMMY_Dummy);"#, field_name: source::Id::from(PStr::LOWER_A), pattern: Box::new(source::pattern::DestructuringPattern::Id( source::Id::from(PStr::LOWER_A), + builder.int_type(), )), shorthand: true, type_: builder.int_type(), @@ -2129,6 +2227,7 @@ return (_t7: DUMMY_Dummy);"#, field_name: source::Id::from(PStr::LOWER_B), pattern: Box::new(source::pattern::DestructuringPattern::Id( source::Id::from(PStr::LOWER_C), + builder.int_type(), )), shorthand: false, type_: builder.int_type(), @@ -2172,9 +2271,10 @@ return 0;"#, loc: Location::dummy(), field_order: 0, field_name: source::Id::from(PStr::LOWER_A), - pattern: Box::new(source::pattern::DestructuringPattern::Id(source::Id::from( - PStr::LOWER_A, - ))), + pattern: Box::new(source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.int_type(), + )), shorthand: true, type_: builder.int_type(), }, @@ -2182,9 +2282,10 @@ return 0;"#, loc: Location::dummy(), field_order: 1, field_name: source::Id::from(PStr::LOWER_B), - pattern: Box::new(source::pattern::DestructuringPattern::Id(source::Id::from( - PStr::LOWER_C, - ))), + pattern: Box::new(source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_C), + builder.int_type(), + )), shorthand: false, type_: builder.int_type(), }, @@ -2200,9 +2301,10 @@ return 0;"#, Location::dummy(), vec![ source::pattern::TuplePatternElement { - pattern: Box::new(source::pattern::DestructuringPattern::Id(source::Id::from( - PStr::LOWER_D, - ))), + pattern: Box::new(source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_D), + builder.int_type(), + )), type_: builder.int_type(), }, source::pattern::TuplePatternElement { @@ -2241,7 +2343,10 @@ return 0;"#, statements: vec![source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from( PStr::LOWER_A)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.int_type(), + ), annotation: Some(annot_builder.int_annot()), assigned_expression: Box::new(source::expr::E::Call(source::expr::Call { common: source::expr::ExpressionCommon::dummy(builder.int_type()), @@ -2281,7 +2386,10 @@ return 0;"#, source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from(PStr::LOWER_A)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.unit_type(), + ), annotation: Some(annot_builder.unit_annot()), assigned_expression: Box::new(source::expr::E::Literal( source::expr::ExpressionCommon::dummy(builder.string_type()), @@ -2291,7 +2399,10 @@ return 0;"#, source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from(PStr::LOWER_B)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_B), + builder.unit_type(), + ), annotation: Some(annot_builder.unit_annot()), assigned_expression: Box::new(id_expr(PStr::LOWER_A, builder.string_type())), }, @@ -2309,14 +2420,20 @@ return 0;"#, statements: vec![source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from(PStr::LOWER_A)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.unit_type(), + ), annotation: Some(annot_builder.unit_annot()), assigned_expression: Box::new(source::expr::E::Block(source::expr::Block { common: source::expr::ExpressionCommon::dummy(builder.unit_type()), statements: vec![source::expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: source::pattern::DestructuringPattern::Id(source::Id::from(PStr::LOWER_A)), + pattern: source::pattern::DestructuringPattern::Id( + source::Id::from(PStr::LOWER_A), + builder.unit_type(), + ), annotation: Some(annot_builder.int_annot()), assigned_expression: Box::new(dummy_source_this(heap)), }], @@ -2544,16 +2661,18 @@ return 0;"#, }, body: source::expr::E::IfElse(source::expr::IfElse { common: source::expr::ExpressionCommon::dummy(builder.int_type()), - condition: Box::new(source::expr::E::Binary(source::expr::Binary { - common: source::expr::ExpressionCommon::dummy(builder.int_type()), - operator_preceding_comments: NO_COMMENT_REFERENCE, - operator: source::expr::BinaryOperator::EQ, - e1: Box::new(id_expr(heap.alloc_str_for_test("n"), builder.int_type())), - e2: Box::new(source::expr::E::Literal( - source::expr::ExpressionCommon::dummy(builder.int_type()), - source::Literal::Int(0), - )), - })), + condition: Box::new(source::expr::IfElseCondition::Expression( + source::expr::E::Binary(source::expr::Binary { + common: source::expr::ExpressionCommon::dummy(builder.int_type()), + operator_preceding_comments: NO_COMMENT_REFERENCE, + operator: source::expr::BinaryOperator::EQ, + e1: Box::new(id_expr(heap.alloc_str_for_test("n"), builder.int_type())), + e2: Box::new(source::expr::E::Literal( + source::expr::ExpressionCommon::dummy(builder.int_type()), + source::Literal::Int(0), + )), + }), + )), e1: Box::new(source::expr::E::Literal( source::expr::ExpressionCommon::dummy(builder.int_type()), source::Literal::Int(1), diff --git a/crates/samlang-core/src/parser.rs b/crates/samlang-core/src/parser.rs index 7e9b5e3a4..334ff1f49 100644 --- a/crates/samlang-core/src/parser.rs +++ b/crates/samlang-core/src/parser.rs @@ -101,6 +101,7 @@ mod tests { expect_good_expr("false || true"); expect_good_expr("\"hello\"::\"world\""); expect_good_expr("if (true) then 3 else bar"); + expect_good_expr("if let {foo as {bar as [Fizz(baz), Buzz, _], boo}} = true then 3 else bar"); expect_good_expr("match (this) { None(_) -> 0, Some(d) -> d }"); expect_good_expr("match (this) { None(_) -> match this { None(_) -> 1 } Some(d) -> d }"); expect_good_expr("match (this) { None(_) -> {}, Some(d) -> d }"); diff --git a/crates/samlang-core/src/parser/source_parser.rs b/crates/samlang-core/src/parser/source_parser.rs index fa3be8ed3..4cc62cb0b 100644 --- a/crates/samlang-core/src/parser/source_parser.rs +++ b/crates/samlang-core/src/parser/source_parser.rs @@ -1,6 +1,9 @@ use super::lexer::{Keyword, Token, TokenContent, TokenOp}; use crate::{ - ast::{source::*, Location, Position}, + ast::{ + source::{expr::IfElseCondition, *}, + Location, Position, + }, errors::ErrorSet, }; use itertools::Itertools; @@ -724,7 +727,16 @@ impl<'a> SourceParser<'a> { let associated_comments = self.collect_preceding_comments(); if let Token(peeked_loc, TokenContent::Keyword(Keyword::IF)) = self.peek() { self.consume(); - let condition = self.parse_expression(); + let condition = + if let Token(_peeked_let_loc, TokenContent::Keyword(Keyword::LET)) = self.peek() { + self.consume(); + let pattern = self.parse_matching_pattern(); + self.assert_and_consume_operator(TokenOp::ASSIGN); + let expr = self.parse_expression(); + IfElseCondition::Guard(pattern, expr) + } else { + IfElseCondition::Expression(self.parse_expression()) + }; self.assert_and_consume_keyword(Keyword::THEN); let e1 = self.parse_expression(); self.assert_and_consume_keyword(Keyword::ELSE); @@ -1354,7 +1366,7 @@ impl<'a> SourceParser<'a> { let loc = field_name.loc.union(nested.loc()); (nested, loc, false) } else { - (Box::new(pattern::DestructuringPattern::Id(field_name)), field_name.loc, true) + (Box::new(pattern::DestructuringPattern::Id(field_name, ())), field_name.loc, true) }; pattern::ObjectPatternElement { loc, @@ -1375,11 +1387,88 @@ impl<'a> SourceParser<'a> { self.consume(); return pattern::DestructuringPattern::Wildcard(peeked_loc); } - pattern::DestructuringPattern::Id(Id { - loc: peeked.0, - associated_comments: NO_COMMENT_REFERENCE, - name: self.assert_and_peek_lower_id().1, - }) + pattern::DestructuringPattern::Id( + Id { + loc: peeked.0, + associated_comments: NO_COMMENT_REFERENCE, + name: self.assert_and_peek_lower_id().1, + }, + (), + ) + } + + pub(super) fn parse_matching_pattern(&mut self) -> pattern::MatchingPattern<()> { + let peeked = self.peek(); + if let Token(peeked_loc, TokenContent::Operator(TokenOp::LBRACKET)) = peeked { + self.consume(); + let destructured_names = + self.parse_comma_separated_list(Some(TokenOp::RBRACKET), &mut |s: &mut Self| { + pattern::TuplePatternElement { pattern: Box::new(s.parse_matching_pattern()), type_: () } + }); + let end_location = self.assert_and_consume_operator(TokenOp::RBRACKET); + return pattern::MatchingPattern::Tuple(peeked_loc.union(&end_location), destructured_names); + } + if let Token(peeked_loc, TokenContent::Operator(TokenOp::LBRACE)) = peeked { + self.consume(); + let destructured_names = + self.parse_comma_separated_list(Some(TokenOp::RBRACE), &mut |s: &mut Self| { + let field_name = s.parse_lower_id(); + let (pattern, loc, shorthand) = + if let Token(_, TokenContent::Keyword(Keyword::AS)) = s.peek() { + s.consume(); + let nested = Box::new(s.parse_matching_pattern()); + let loc = field_name.loc.union(nested.loc()); + (nested, loc, false) + } else { + (Box::new(pattern::MatchingPattern::Id(field_name, ())), field_name.loc, true) + }; + pattern::ObjectPatternElement { + loc, + field_name, + field_order: 0, + pattern, + shorthand, + type_: (), + } + }); + let end_location = self.assert_and_consume_operator(TokenOp::RBRACE); + return pattern::MatchingPattern::Object(peeked_loc.union(&end_location), destructured_names); + } + if let Token(peeked_loc, TokenContent::UpperId(id)) = peeked { + self.consume(); + let tag = Id { loc: peeked_loc, associated_comments: NO_COMMENT_REFERENCE, name: id }; + let (data_variables, loc) = + if let Token(_, TokenContent::Operator(TokenOp::LPAREN)) = self.peek() { + self.assert_and_consume_operator(TokenOp::LPAREN); + let data_variables = self + .parse_comma_separated_list(Some(TokenOp::RPAREN), &mut |s: &mut Self| { + (s.parse_matching_pattern(), ()) + }); + let end_loc = self.assert_and_consume_operator(TokenOp::RPAREN); + (data_variables, peeked_loc.union(&end_loc)) + } else { + (Vec::with_capacity(0), peeked_loc) + }; + return pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc, + tag_order: 0, + tag, + data_variables, + type_: (), + }); + } + if let Token(peeked_loc, TokenContent::Operator(TokenOp::UNDERSCORE)) = peeked { + self.consume(); + return pattern::MatchingPattern::Wildcard(peeked_loc); + } + pattern::MatchingPattern::Id( + Id { + loc: peeked.0, + associated_comments: NO_COMMENT_REFERENCE, + name: self.assert_and_peek_lower_id().1, + }, + (), + ) } fn parse_upper_id(&mut self) -> Id { diff --git a/crates/samlang-core/src/printer/source_printer.rs b/crates/samlang-core/src/printer/source_printer.rs index b5fcc295a..5f30ecefa 100644 --- a/crates/samlang-core/src/printer/source_printer.rs +++ b/crates/samlang-core/src/printer/source_printer.rs @@ -274,7 +274,15 @@ impl expr::E<()> { documents: &mut Vec, ) { documents.push(Document::Text(rcs("if "))); - documents.push(if_else.condition.create_doc(heap, comment_store)); + match if_else.condition.as_ref() { + expr::IfElseCondition::Expression(e) => documents.push(e.create_doc(heap, comment_store)), + expr::IfElseCondition::Guard(p, e) => { + documents.push(Document::Text(rcs("let "))); + documents.push(matching_pattern_to_document(heap, p)); + documents.push(Document::Text(rcs(" = "))); + documents.push(e.create_doc(heap, comment_store)); + } + }; documents.push(Document::Text(rcs(" then "))); documents.push(self.expr_wrapped_with_braces_expanded_in_if_else( heap, @@ -652,11 +660,14 @@ impl expr::E<()> { } } -fn pattern_to_document(heap: &Heap, pattern: &pattern::DestructuringPattern<()>) -> Document { +fn destructuring_pattern_to_document( + heap: &Heap, + pattern: &pattern::DestructuringPattern<()>, +) -> Document { match pattern { pattern::DestructuringPattern::Tuple(_, names) => { square_brackets_surrounded_doc(comma_sep_list(names, |it| { - pattern_to_document(heap, &it.pattern) + destructuring_pattern_to_document(heap, &it.pattern) })) } pattern::DestructuringPattern::Object(_, names) => { @@ -667,23 +678,60 @@ fn pattern_to_document(heap: &Heap, pattern: &pattern::DestructuringPattern<()>) Document::concat(vec![ Document::Text(rc_pstr(heap, it.field_name.name)), Document::Text(rcs(" as ")), - pattern_to_document(heap, &it.pattern), + destructuring_pattern_to_document(heap, &it.pattern), ]) } })) } - pattern::DestructuringPattern::Id(id) => Document::Text(rc_pstr(heap, id.name)), + pattern::DestructuringPattern::Id(id, _) => Document::Text(rc_pstr(heap, id.name)), pattern::DestructuringPattern::Wildcard(_) => Document::Text(rcs("_")), } } +fn matching_pattern_to_document(heap: &Heap, pattern: &pattern::MatchingPattern<()>) -> Document { + match pattern { + pattern::MatchingPattern::Tuple(_, names) => { + square_brackets_surrounded_doc(comma_sep_list(names, |it| { + matching_pattern_to_document(heap, &it.pattern) + })) + } + pattern::MatchingPattern::Object(_, names) => { + braces_surrounded_doc(comma_sep_list(names, |it| { + if it.shorthand { + Document::Text(rc_pstr(heap, it.field_name.name)) + } else { + Document::concat(vec![ + Document::Text(rc_pstr(heap, it.field_name.name)), + Document::Text(rcs(" as ")), + matching_pattern_to_document(heap, &it.pattern), + ]) + } + })) + } + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: _, + tag_order: _, + tag, + data_variables, + type_: (), + }) => Document::concat(vec![ + Document::Text(rc_pstr(heap, tag.name)), + parenthesis_surrounded_doc(comma_sep_list(data_variables, |(p, _)| { + matching_pattern_to_document(heap, p) + })), + ]), + pattern::MatchingPattern::Id(id, _) => Document::Text(rc_pstr(heap, id.name)), + pattern::MatchingPattern::Wildcard(_) => Document::Text(rcs("_")), + } +} + pub(super) fn statement_to_document( heap: &Heap, comment_store: &CommentStore, stmt: &expr::DeclarationStatement<()>, ) -> Document { let mut segments = vec![]; - let pattern_doc = pattern_to_document(heap, &stmt.pattern); + let pattern_doc = destructuring_pattern_to_document(heap, &stmt.pattern); segments.push( associated_comments_doc( heap, @@ -1305,6 +1353,21 @@ Test /* b */ /* c */.VariantName(42)"#, ]: int = 3; }"#, ); + assert_reprint_expr( + "{let_=if let {foo as {bar as [Fizz(baz), Buzz, _], boo}} = true then 3 else bar;}", + r#"{ + let _ = if let { + foo as { + bar as [Fizz(baz), Buzz(), _], + boo + } + } = true then { + 3 + } else { + bar + }; +}"#, + ); assert_reprint_expr( "{ let a: unit = { let b: unit = { let c: unit = { let d: unit = aVariableNameThatIsVeryVeryVeryVeryVeryLong; }; }; }; }", diff --git a/crates/samlang-core/src/services/ast_differ.rs b/crates/samlang-core/src/services/ast_differ.rs index 9e15925a0..afc3808b6 100644 --- a/crates/samlang-core/src/services/ast_differ.rs +++ b/crates/samlang-core/src/services/ast_differ.rs @@ -494,9 +494,10 @@ mod tests { let stmt = expr::DeclarationStatement { loc: Location::dummy(), associated_comments: NO_COMMENT_REFERENCE, - pattern: crate::ast::source::pattern::DestructuringPattern::Id(Id::from( - heap.alloc_str_for_test("v"), - )), + pattern: crate::ast::source::pattern::DestructuringPattern::Id( + Id::from(heap.alloc_str_for_test("v")), + (), + ), annotation: Some(builder.bool_annot()), assigned_expression: Box::new(expr::E::LocalId( expr::ExpressionCommon::dummy(()), diff --git a/crates/samlang-core/src/services/gc.rs b/crates/samlang-core/src/services/gc.rs index 828b48a65..c87f79f93 100644 --- a/crates/samlang-core/src/services/gc.rs +++ b/crates/samlang-core/src/services/gc.rs @@ -67,11 +67,11 @@ fn mark_id(heap: &mut Heap, id: &Id) { heap.mark(id.name); } -fn mark_pattern(heap: &mut Heap, pattern: &pattern::DestructuringPattern>) { +fn mark_destructuring_pattern(heap: &mut Heap, pattern: &pattern::DestructuringPattern>) { match pattern { pattern::DestructuringPattern::Tuple(_, names) => { for n in names { - mark_pattern(heap, &n.pattern); + mark_destructuring_pattern(heap, &n.pattern); mark_type(heap, &n.type_); } } @@ -79,14 +79,54 @@ fn mark_pattern(heap: &mut Heap, pattern: &pattern::DestructuringPattern mark_id(heap, id), + pattern::DestructuringPattern::Id(id, type_) => { + mark_id(heap, id); + mark_type(heap, type_); + } pattern::DestructuringPattern::Wildcard(_) => {} } } +fn mark_matching_pattern(heap: &mut Heap, pattern: &pattern::MatchingPattern>) { + match pattern { + pattern::MatchingPattern::Tuple(_, names) => { + for n in names { + mark_matching_pattern(heap, &n.pattern); + mark_type(heap, &n.type_); + } + } + pattern::MatchingPattern::Object(_, names) => { + for n in names { + mark_type(heap, &n.type_); + mark_id(heap, &n.field_name); + mark_matching_pattern(heap, &n.pattern); + } + } + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: _, + tag_order: _, + tag, + data_variables, + type_, + }) => { + mark_type(heap, type_); + mark_id(heap, tag); + for (p, type_) in data_variables { + mark_matching_pattern(heap, p); + mark_type(heap, type_); + } + } + pattern::MatchingPattern::Id(id, type_) => { + mark_id(heap, id); + mark_type(heap, type_); + } + pattern::MatchingPattern::Wildcard(_) => {} + } +} + fn mark_expression(heap: &mut Heap, expr: &expr::E>) { mark_type(heap, &expr.common().type_); match expr { @@ -122,7 +162,13 @@ fn mark_expression(heap: &mut Heap, expr: &expr::E>) { mark_expression(heap, &e.e2); } expr::E::IfElse(e) => { - mark_expression(heap, &e.condition); + match e.condition.as_ref() { + expr::IfElseCondition::Expression(e) => mark_expression(heap, e), + expr::IfElseCondition::Guard(p, e) => { + mark_matching_pattern(heap, p); + mark_expression(heap, e); + } + } mark_expression(heap, &e.e1); mark_expression(heap, &e.e2); } @@ -148,7 +194,7 @@ fn mark_expression(heap: &mut Heap, expr: &expr::E>) { for stmt in &e.statements { mark_expression(heap, &stmt.assigned_expression); mark_annot_opt(heap, &stmt.annotation); - mark_pattern(heap, &stmt.pattern); + mark_destructuring_pattern(heap, &stmt.pattern); } if let Some(e) = &e.expression { mark_expression(heap, e); @@ -292,6 +338,9 @@ mod tests { let { d } = Obj.init(5, 4); let [_, d1] = [1, 2]; let { e as d2 } = Obj.init(5, 4); // d = 4 + let _ = if let { e as d3 } = Obj.init(5, 4) then {} else {}; + let _ = if let Some(_) = Option.Some(1) then {} else {}; + let _ = if let [_, _] = [1,2] then {} else {}; let f = Obj.init(5, 4); // d = 4 let g = Obj.init(d, 4); // d = 4 let _ = f.d; diff --git a/crates/samlang-core/src/services/global_searcher.rs b/crates/samlang-core/src/services/global_searcher.rs index 1075cddb6..cdf848c7e 100644 --- a/crates/samlang-core/src/services/global_searcher.rs +++ b/crates/samlang-core/src/services/global_searcher.rs @@ -49,6 +49,76 @@ fn search_id_annot( } } +fn search_destructuring_pattern( + pattern: &pattern::DestructuringPattern>, + pattern_type: &Rc, + request: &GlobalNameSearchRequest, + collector: &mut Vec, +) { + match pattern { + pattern::DestructuringPattern::Tuple(_, patterns) => { + for p in patterns { + search_destructuring_pattern(&p.pattern, &p.type_, request, collector) + } + } + pattern::DestructuringPattern::Object(_, patterns) => { + match (pattern_type.as_ref(), request) { + ( + Type::Nominal(nominal_type), + GlobalNameSearchRequest::Property(mod_ref, toplevel_name, field_name), + ) if mod_ref.eq(&nominal_type.module_reference) && toplevel_name.eq(&nominal_type.id) => { + for n in patterns { + if field_name.eq(&n.field_name.name) { + collector.push(n.field_name.loc); + } + } + } + _ => {} + } + for p in patterns { + search_destructuring_pattern(&p.pattern, &p.type_, request, collector) + } + } + pattern::DestructuringPattern::Id(_, _) | pattern::DestructuringPattern::Wildcard(_) => {} + } +} + +fn search_matching_pattern( + pattern: &pattern::MatchingPattern>, + pattern_type: &Rc, + request: &GlobalNameSearchRequest, + collector: &mut Vec, +) { + match pattern { + pattern::MatchingPattern::Tuple(_, patterns) => { + for p in patterns { + search_matching_pattern(&p.pattern, &p.type_, request, collector) + } + } + pattern::MatchingPattern::Object(_, patterns) => { + match (pattern_type.as_ref(), request) { + ( + Type::Nominal(nominal_type), + GlobalNameSearchRequest::Property(mod_ref, toplevel_name, field_name), + ) if mod_ref.eq(&nominal_type.module_reference) && toplevel_name.eq(&nominal_type.id) => { + for n in patterns { + if field_name.eq(&n.field_name.name) { + collector.push(n.field_name.loc); + } + } + } + _ => {} + } + for p in patterns { + search_matching_pattern(&p.pattern, &p.type_, request, collector) + } + } + pattern::MatchingPattern::Variant(_) + | pattern::MatchingPattern::Id(_, _) + | pattern::MatchingPattern::Wildcard(_) => {} + } +} + fn search_expression( expr: &expr::E>, request: &GlobalNameSearchRequest, @@ -112,7 +182,13 @@ fn search_expression( search_expression(&e.e2, request, collector); } expr::E::IfElse(e) => { - search_expression(&e.condition, request, collector); + match e.condition.as_ref() { + expr::IfElseCondition::Expression(e) => search_expression(e, request, collector), + expr::IfElseCondition::Guard(p, e) => { + search_matching_pattern(p, e.type_(), request, collector); + search_expression(e, request, collector); + } + } search_expression(&e.e1, request, collector); search_expression(&e.e2, request, collector); } @@ -147,20 +223,12 @@ fn search_expression( if let Some(annot) = &stmt.annotation { search_annot(annot, request, collector); } - match (&stmt.pattern, request, stmt.assigned_expression.type_().as_nominal()) { - ( - pattern::DestructuringPattern::Object(_, destructured_names), - GlobalNameSearchRequest::Property(mod_ref, toplevel_name, field_name), - Some(nominal_type), - ) if mod_ref.eq(&nominal_type.module_reference) && toplevel_name.eq(&nominal_type.id) => { - for n in destructured_names { - if field_name.eq(&n.field_name.name) { - collector.push(n.field_name.loc); - } - } - } - _ => {} - } + search_destructuring_pattern( + &stmt.pattern, + stmt.assigned_expression.type_(), + request, + collector, + ); search_expression(&stmt.assigned_expression, request, collector); } if let Some(e) = &e.expression { @@ -276,6 +344,9 @@ mod tests { let { e as d } = Obj.init(5, 4); // d = 4 let f = Obj.init(5, 4); // d = 4 let g = Obj.init(d, 4); // d = 4 + let _ = if let { a as d3, b as d4 } = Foo.init(5, false) then {} else {}; + let _ = if let Some(_) = Option.Some(1) then {} else {}; + let _ = if let [_, _] = [1,2] then {} else {}; let _ = f.d; // 1 + 2 * 3 / 4 = 1 + 6/4 = 1 + 1 = 2 a + b * c / d diff --git a/crates/samlang-core/src/services/location_cover.rs b/crates/samlang-core/src/services/location_cover.rs index 052cf84c7..945fb9fcc 100644 --- a/crates/samlang-core/src/services/location_cover.rs +++ b/crates/samlang-core/src/services/location_cover.rs @@ -7,7 +7,7 @@ use crate::{ ModuleReference, }; use samlang_heap::PStr; -use std::{ops::Deref, rc::Rc}; +use std::rc::Rc; pub(super) enum LocationCoverSearchResult<'a> { Expression(&'a expr::E>), @@ -24,6 +24,45 @@ pub(super) enum LocationCoverSearchResult<'a> { TypedName(Location, PStr, Type), } +fn search_destructuring_pattern( + pattern: &pattern::DestructuringPattern>, + position: Position, +) -> Option { + match pattern { + pattern::DestructuringPattern::Tuple(_, patterns) => { + patterns.iter().find_map(|p| search_destructuring_pattern(&p.pattern, position)) + } + pattern::DestructuringPattern::Object(_, patterns) => { + patterns.iter().find_map(|p| search_destructuring_pattern(&p.pattern, position)) + } + pattern::DestructuringPattern::Id(id, type_) if id.loc.contains_position(position) => { + Some(LocationCoverSearchResult::TypedName(id.loc, id.name, type_.as_ref().clone())) + } + pattern::DestructuringPattern::Id(_, _) | pattern::DestructuringPattern::Wildcard(_) => None, + } +} + +fn search_matching_pattern( + pattern: &pattern::MatchingPattern>, + position: Position, +) -> Option { + match pattern { + pattern::MatchingPattern::Tuple(_, patterns) => { + patterns.iter().find_map(|p| search_matching_pattern(&p.pattern, position)) + } + pattern::MatchingPattern::Object(_, patterns) => { + patterns.iter().find_map(|p| search_matching_pattern(&p.pattern, position)) + } + pattern::MatchingPattern::Variant(variant_pattern) => { + variant_pattern.data_variables.iter().find_map(|(p, _)| search_matching_pattern(p, position)) + } + pattern::MatchingPattern::Id(id, type_) if id.loc.contains_position(position) => { + Some(LocationCoverSearchResult::TypedName(id.loc, id.name, type_.as_ref().clone())) + } + pattern::MatchingPattern::Id(_, _) | pattern::MatchingPattern::Wildcard(_) => None, + } +} + fn search_expression( expr: &expr::E>, position: Position, @@ -109,9 +148,13 @@ fn search_expression( } expr::E::Binary(e) => search_expression(&e.e1, position, stop_at_call) .or_else(|| search_expression(&e.e2, position, stop_at_call)), - expr::E::IfElse(e) => search_expression(&e.condition, position, stop_at_call) - .or_else(|| search_expression(&e.e1, position, stop_at_call)) - .or_else(|| search_expression(&e.e2, position, stop_at_call)), + expr::E::IfElse(e) => (match e.condition.as_ref() { + expr::IfElseCondition::Expression(e) => search_expression(e, position, stop_at_call), + expr::IfElseCondition::Guard(p, e) => search_matching_pattern(p, position) + .or_else(|| search_expression(e, position, stop_at_call)), + }) + .or_else(|| search_expression(&e.e1, position, stop_at_call)) + .or_else(|| search_expression(&e.e2, position, stop_at_call)), expr::E::Match(e) => { let mut found = search_expression(&e.matched, position, stop_at_call); for case in &e.cases { @@ -141,16 +184,9 @@ fn search_expression( if let Some(found) = search_expression(&stmt.assigned_expression, position, stop_at_call) { return Some(found); } - match &stmt.pattern { - pattern::DestructuringPattern::Id(id) if id.loc.contains_position(position) => { - return Some(LocationCoverSearchResult::TypedName( - id.loc, - id.name, - stmt.assigned_expression.common().type_.deref().clone(), - )) - } - _ => {} - }; + if let Some(found) = search_destructuring_pattern(&stmt.pattern, position) { + return Some(found); + } } if let Some(e) = &e.expression { return search_expression(e, position, stop_at_call); @@ -300,6 +336,9 @@ mod tests { let { e as d } = Obj.init(5, 4); // d = 4 let f = Obj.init(5, 4); // d = 4 let g = Obj.init(d, 4); // d = 4 + let _ = if let { a as d3 } = Foo.init(5) then {} else {}; + let _ = if let Some(_) = Option.Some(1) then {} else {}; + let _ = if let [_, _] = [1,2] then {} else {}; let _ = f.d; let [h, i] = [111, 122]; // 1 + 2 * 3 / 4 = 1 + 6/4 = 1 + 1 = 2 diff --git a/crates/samlang-core/src/services/variable_definition.rs b/crates/samlang-core/src/services/variable_definition.rs index da6bf4107..87bcd1e3a 100644 --- a/crates/samlang-core/src/services/variable_definition.rs +++ b/crates/samlang-core/src/services/variable_definition.rs @@ -119,7 +119,7 @@ fn apply_destructuring_pattern_renaming( )); let shorthand = matches!( pattern.as_ref(), - pattern::DestructuringPattern::Id(id) if id.name.eq(&field_name.name), + pattern::DestructuringPattern::Id(id, _) if id.name.eq(&field_name.name), ); pattern::ObjectPatternElement { loc: *loc, @@ -133,19 +133,97 @@ fn apply_destructuring_pattern_renaming( ) .collect(), ), - pattern::DestructuringPattern::Id(id) => { + pattern::DestructuringPattern::Id(id, ()) => { let name = if id.loc.eq(&definition_and_uses.definition_location) { new_name } else { id.name }; - pattern::DestructuringPattern::Id(Id { - loc: id.loc, - associated_comments: id.associated_comments, - name, - }) + pattern::DestructuringPattern::Id( + Id { loc: id.loc, associated_comments: id.associated_comments, name }, + (), + ) } pattern::DestructuringPattern::Wildcard(_) => pattern.clone(), } } +fn apply_matching_pattern_renaming( + pattern: &pattern::MatchingPattern<()>, + definition_and_uses: &DefinitionAndUses, + new_name: PStr, +) -> pattern::MatchingPattern<()> { + match pattern { + pattern::MatchingPattern::Tuple(l, names) => pattern::MatchingPattern::Tuple( + *l, + names + .iter() + .map(|pattern::TuplePatternElement { pattern, type_ }| pattern::TuplePatternElement { + pattern: Box::new(apply_matching_pattern_renaming( + pattern, + definition_and_uses, + new_name, + )), + type_: *type_, + }) + .collect(), + ), + pattern::MatchingPattern::Object(l, names) => pattern::MatchingPattern::Object( + *l, + names + .iter() + .map( + |pattern::ObjectPatternElement { + loc, + field_order, + field_name, + pattern, + shorthand: _, + type_, + }| { + let pattern = + Box::new(apply_matching_pattern_renaming(pattern, definition_and_uses, new_name)); + let shorthand = matches!( + pattern.as_ref(), + pattern::MatchingPattern::Id(id, _) if id.name.eq(&field_name.name), + ); + pattern::ObjectPatternElement { + loc: *loc, + field_order: *field_order, + field_name: *field_name, + pattern, + shorthand, + type_: *type_, + } + }, + ) + .collect(), + ), + pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc, + tag_order, + tag, + data_variables, + type_: (), + }) => pattern::MatchingPattern::Variant(pattern::VariantPattern { + loc: *loc, + tag_order: *tag_order, + tag: *tag, + data_variables: data_variables + .iter() + .map(|(p, ())| (apply_matching_pattern_renaming(p, definition_and_uses, new_name), ())) + .collect(), + type_: (), + }), + pattern::MatchingPattern::Id(id, ()) => { + let name = + if id.loc.eq(&definition_and_uses.definition_location) { new_name } else { id.name }; + pattern::MatchingPattern::Id( + Id { loc: id.loc, associated_comments: id.associated_comments, name }, + (), + ) + } + pattern::MatchingPattern::Wildcard(_) => pattern.clone(), + } +} + fn apply_expr_renaming( expr: &expr::E<()>, definition_and_uses: &DefinitionAndUses, @@ -200,7 +278,15 @@ fn apply_expr_renaming( }), expr::E::IfElse(e) => expr::E::IfElse(expr::IfElse { common: e.common.clone(), - condition: Box::new(apply_expr_renaming(&e.condition, definition_and_uses, new_name)), + condition: Box::new(match e.condition.as_ref() { + expr::IfElseCondition::Expression(e) => { + expr::IfElseCondition::Expression(apply_expr_renaming(e, definition_and_uses, new_name)) + } + expr::IfElseCondition::Guard(p, e) => expr::IfElseCondition::Guard( + apply_matching_pattern_renaming(p, definition_and_uses, new_name), + apply_expr_renaming(e, definition_and_uses, new_name), + ), + }), e1: Box::new(apply_expr_renaming(&e.e1, definition_and_uses, new_name)), e2: Box::new(apply_expr_renaming(&e.e2, definition_and_uses, new_name)), }), @@ -542,6 +628,100 @@ interface Foo #[test] fn rename_test_2() { let source = r#" +class Main { + function test(a: int, b: bool): unit = { + let _ = if let { e as d3, f } = Obj.init(5, 4) then 1 else 2; + } +}"#; + let (_, lookup) = prepare_lookup(source); + assert_correctly_rewritten( + source, + &lookup, + Location::from_pos(3, 26, 3, 28), + r#"class Main { + function test(a: int, b: bool): unit = { + let _ = if let { e as renAmeD, f } = Obj.init( + 5, + 4 + ) then { + 1 + } else { + 2 + }; + } +} +"#, + ); + + let source = r#" +class Main { + function test(a: int, b: bool): unit = { + let _ = if let { renAmeD as d3 } = Obj.init(5, 4) then 1 else 2; + } +}"#; + let (_, lookup) = prepare_lookup(source); + assert_correctly_rewritten( + source, + &lookup, + Location::from_pos(3, 32, 3, 34), + r#"class Main { + function test(a: int, b: bool): unit = { + let _ = if let { renAmeD } = Obj.init(5, 4) then { + 1 + } else { + 2 + }; + } +} +"#, + ); + + let source = r#" +class Main { + function test(a: int, b: bool): unit = { + let _ = if let Some(v) = Option.Some(1) then 1 else 2; + } +}"#; + let (_, lookup) = prepare_lookup(source); + assert_correctly_rewritten( + source, + &lookup, + Location::from_pos(3, 24, 3, 25), + r#"class Main { + function test(a: int, b: bool): unit = { + let _ = if let Some(renAmeD) = Option.Some(1) then { + 1 + } else { + 2 + }; + } +} +"#, + ); + + let source = r#" +class Main { + function test(a: int, b: bool): unit = { + let _ = if let [v, _] = [1, 2] then 1 else 2; + } +}"#; + let (_, lookup) = prepare_lookup(source); + assert_correctly_rewritten( + source, + &lookup, + Location::from_pos(3, 20, 3, 21), + r#"class Main { + function test(a: int, b: bool): unit = { + let _ = if let [renAmeD, _] = [1, 2] then 1 else 2; + } +} +"#, + ); + } + + #[test] + fn rename_test_3() { + let source = r#" class Main { function test(a: int, b: bool): unit = { let c = a;