diff --git a/jac/jaclang/compiler/absyntree.py b/jac/jaclang/compiler/absyntree.py index 6260d7836d..f5a85d890f 100644 --- a/jac/jaclang/compiler/absyntree.py +++ b/jac/jaclang/compiler/absyntree.py @@ -346,6 +346,10 @@ def __init__(self) -> None: class Expr(AstNode): """Expr node type for Jac Ast.""" + def __init__(self) -> None: + """Initialize expression node.""" + self.expr_type = "" + class AtomExpr(Expr, AstSymbolStubNode): """AtomExpr node type for Jac Ast.""" @@ -437,6 +441,7 @@ def __init__(self) -> None: self._py_ctx_func: Type[ast3.AST] = ast3.Load self._sym_type: str = "NoType" self._type_sym_tab: Optional[SymbolTable] = None + AtomExpr.__init__(self) @property def sym(self) -> Optional[Symbol]: @@ -2527,6 +2532,7 @@ def __init__( """Initialize sync statement node.""" self.target = target AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize sync statement node.""" @@ -2655,6 +2661,7 @@ def __init__( self.right = right self.op = op AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2689,6 +2696,7 @@ def __init__( self.rights = rights self.ops = ops AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2720,6 +2728,7 @@ def __init__( self.values = values self.op = op AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2750,6 +2759,7 @@ def __init__( self.signature = signature self.body = body AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2782,6 +2792,7 @@ def __init__( self.operand = operand self.op = op AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2809,6 +2820,7 @@ def __init__( self.value = value self.else_value = else_value AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize ast node.""" @@ -2839,6 +2851,7 @@ def __init__( """Initialize multi string expression node.""" self.strings = strings AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.STRING) def normalize(self, deep: bool = False) -> bool: @@ -2865,6 +2878,7 @@ def __init__( """Initialize fstring expression node.""" self.parts = parts AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.STRING) def normalize(self, deep: bool = False) -> bool: @@ -2895,6 +2909,7 @@ def __init__( """Initialize value node.""" self.values = values AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -2923,6 +2938,7 @@ def __init__( """Initialize value node.""" self.values = values AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -2951,6 +2967,7 @@ def __init__( """Initialize tuple value node.""" self.values = values AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -2994,6 +3011,7 @@ def __init__( """Initialize dict expression node.""" self.kv_pairs = kv_pairs AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -3127,6 +3145,7 @@ def __init__( self.out_expr = out_expr self.compr = compr AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -3202,6 +3221,7 @@ def __init__( self.kv_pair = kv_pair self.compr = compr AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -3240,6 +3260,7 @@ def __init__( self.is_null_ok = is_null_ok self.is_genai = is_genai AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = True) -> bool: """Normalize ast node.""" @@ -3285,6 +3306,7 @@ def __init__( """Initialize atom unit expression node.""" self.value = value AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = True) -> bool: """Normalize ast node.""" @@ -3312,6 +3334,7 @@ def __init__( self.expr = expr self.with_from = with_from AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = False) -> bool: """Normalize yield statement node.""" @@ -3343,6 +3366,7 @@ def __init__( self.params = params self.genai_call = genai_call AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = True) -> bool: """Normalize ast node.""" @@ -3377,6 +3401,7 @@ def __init__( self.step = step self.is_range = is_range AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = True) -> bool: @@ -3419,6 +3444,7 @@ def __init__( self.arch_name = arch_name self.arch_type = arch_type AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolNode.__init__( self, sym_name=arch_name.sym_name, @@ -3449,6 +3475,7 @@ def __init__( self.chain = chain self.edges_only = edges_only AstNode.__init__(self, kid=kid) + Expr.__init__(self) def normalize(self, deep: bool = True) -> bool: """Normalize ast node.""" @@ -3478,6 +3505,7 @@ def __init__( self.filter_cond = filter_cond self.edge_dir = edge_dir AstNode.__init__(self, kid=kid) + Expr.__init__(self) WalkerStmtOnlyNode.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) @@ -3608,6 +3636,7 @@ def __init__( self.f_type = f_type self.compares = compares AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -3645,6 +3674,7 @@ def __init__( """Initialize assign compr expression node.""" self.assigns = assigns AstNode.__init__(self, kid=kid) + Expr.__init__(self) AstSymbolStubNode.__init__(self, sym_type=SymbolType.SEQUENCE) def normalize(self, deep: bool = False) -> bool: @@ -4203,6 +4233,7 @@ def __init__( pos_end=pos_end, ) AstSymbolStubNode.__init__(self, sym_type=self.SYMBOL_TYPE) + Expr.__init__(self) @property def lit_value( diff --git a/jac/jaclang/compiler/passes/main/fuse_typeinfo_pass.py b/jac/jaclang/compiler/passes/main/fuse_typeinfo_pass.py index efe0a85251..5b4e2e2e0c 100644 --- a/jac/jaclang/compiler/passes/main/fuse_typeinfo_pass.py +++ b/jac/jaclang/compiler/passes/main/fuse_typeinfo_pass.py @@ -29,23 +29,32 @@ class FuseTypeInfoPass(Pass): node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {} + # Override this to support enter expression. + def enter_node(self, node: ast.AstNode) -> None: + """Run on entering node.""" + if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"): + getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node) + + # TODO: Make (AstSymbolNode::name_spec.sym_typ and Expr::expr_type) the same + # TODO: Introduce AstTypedNode to be a common parent for Expr and AstSymbolNode + if isinstance(node, ast.Expr): + self.enter_expr(node) + def __debug_print(self, msg: str) -> None: if settings.fuse_type_info_debug: self.log_info("FuseTypeInfo::" + msg) - def __call_type_handler( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Type - ) -> None: + def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]: mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__) type_handler_name = f"get_type_from_{mypy_type_name}" if hasattr(self, type_handler_name): - getattr(self, type_handler_name)(node, mypy_type) - else: - self.__debug_print( - f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet' - ) + return getattr(self, type_handler_name)(mypy_type) + self.__debug_print( + f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet' + ) + return None - def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None: + def __set_type_sym_table_link(self, node: ast.AstSymbolNode) -> None: typ = node.sym_type.split(".") typ_sym_table = self.ir.sym_tab @@ -117,7 +126,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None: # Jac node has only one mypy node linked to it if len(node.gen.mypy_ast) == 1: func(self, node) - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) self.__collect_python_dependencies(node) # Jac node has multiple mypy nodes linked to it @@ -141,7 +150,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None: f"{jac_node_str} has duplicate mypy nodes associated to it" ) func(self, node) - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) self.__collect_python_dependencies(node) # Jac node doesn't have mypy nodes linked to it @@ -164,7 +173,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: if isinstance(mypy_node, MypyNodes.MemberExpr): if mypy_node in self.node_type_hash: - self.__call_type_handler(node, self.node_type_hash[mypy_node]) + node.name_spec.sym_type = ( + self.__call_type_handler(self.node_type_hash[mypy_node]) + or node.name_spec.sym_type + ) else: self.__debug_print(f"{node.loc} MemberExpr type is not found") @@ -173,7 +185,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: mypy_node = mypy_node.node if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)): - self.__call_type_handler(node, mypy_node.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.MypyFile): node.name_spec.sym_type = "types.ModuleType" @@ -182,7 +196,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: node.name_spec.sym_type = mypy_node.fullname elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef): - self.__call_type_handler(node, mypy_node.items[0].func.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.items[0].func.type) + or node.name_spec.sym_type + ) elif mypy_node is None: node.name_spec.sym_type = "None" @@ -196,19 +213,66 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: else: if isinstance(mypy_node, MypyNodes.ClassDef): node.name_spec.sym_type = mypy_node.fullname - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) elif isinstance(mypy_node, MypyNodes.FuncDef): - self.__call_type_handler(node, mypy_node.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.Argument): - self.__call_type_handler(node, mypy_node.variable.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.variable.type) + or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.Decorator): - self.__call_type_handler(node, mypy_node.func.type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.func.type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported' f"{type(mypy_node)}" ) + collection_types_map = { + ast.ListVal: "builtins.list", + ast.SetVal: "builtins.set", + ast.TupleVal: "builtins.tuple", + ast.DictVal: "builtins.dict", + ast.ListCompr: None, + ast.DictCompr: None, + } + + # NOTE (Thakee): Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node + # and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as + # valid nodes that can have symbols. At this point I'm leaving this like this and lemme know + # otherwise. + # NOTE (GAMAL): This will be fixed through the AstTypedNode + def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None: + """Enter an expression node.""" + if len(node.gen.mypy_ast) == 0: + return + + # If the corrosponding mypy ast node type has stored here, get the values. + mypy_node = node.gen.mypy_ast[0] + if mypy_node in self.node_type_hash: + mytype: MyType = self.node_type_hash[mypy_node] + node.expr_type = self.__call_type_handler(mytype) or "" + + # Set they symbol type for collection expression. + # + # GenCompr is an instance of ListCompr but we don't handle it here. + # so the isinstace (node, ) doesn't work, I'm going with type(...) == ... + if type(node) in self.collection_types_map: + assert isinstance(node, ast.AtomExpr) # To make mypy happy. + collection_type = self.collection_types_map[type(node)] + if collection_type is not None: + node.name_spec.sym_type = collection_type + if mypy_node in self.node_type_hash: + node.name_spec.sym_type = ( + self.__call_type_handler(mytype) or node.name_spec.sym_type + ) + @__handle_node def enter_name(self, node: ast.NameAtom) -> None: """Pass handler for name nodes.""" @@ -248,7 +312,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None: def enter_ability(self, node: ast.Ability) -> None: """Pass handler for Ability nodes.""" if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): - self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get type of an ability from mypy node other than Ability. " @@ -259,7 +326,10 @@ def enter_ability(self, node: ast.Ability) -> None: def enter_ability_def(self, node: ast.AbilityDef) -> None: """Pass handler for AbilityDef nodes.""" if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): - self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef. " @@ -272,7 +342,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None: if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument): mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0] if mypy_node.variable.type: - self.__call_type_handler(node, mypy_node.variable.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.variable.type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get parameter value from mypyNode other than Argument" @@ -286,7 +359,9 @@ def enter_has_var(self, node: ast.HasVar) -> None: if isinstance(mypy_node, MypyNodes.AssignmentStmt): n = mypy_node.lvalues[0].node if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)): - self.__call_type_handler(node, n.type) + node.name_spec.sym_type = ( + self.__call_type_handler(n.type) or node.name_spec.sym_type + ) else: self.__debug_print( "Getting type of 'AssignmentStmt' is only supported with Var and FuncDef" @@ -311,54 +386,6 @@ def enter_f_string(self, node: ast.FString) -> None: """Pass handler for FString nodes.""" self.__debug_print(f"Getting type not supported in {type(node)}") - @__handle_node - def enter_list_val(self, node: ast.ListVal) -> None: - """Pass handler for ListVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.list" - - @__handle_node - def enter_set_val(self, node: ast.SetVal) -> None: - """Pass handler for SetVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.set" - - @__handle_node - def enter_tuple_val(self, node: ast.TupleVal) -> None: - """Pass handler for TupleVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.tuple" - - @__handle_node - def enter_dict_val(self, node: ast.DictVal) -> None: - """Pass handler for DictVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.dict" - - @__handle_node - def enter_list_compr(self, node: ast.ListCompr) -> None: - """Pass handler for ListCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - - @__handle_node - def enter_dict_compr(self, node: ast.DictCompr) -> None: - """Pass handler for DictCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - @__handle_node def enter_index_slice(self, node: ast.IndexSlice) -> None: """Pass handler for IndexSlice nodes.""" @@ -370,10 +397,12 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None: if isinstance(node.gen.mypy_ast[0], MypyNodes.ClassDef): mypy_node: MypyNodes.ClassDef = node.gen.mypy_ast[0] node.name_spec.sym_type = mypy_node.fullname - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0] - self.__call_type_handler(node, mypy_node2.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef " @@ -425,48 +454,38 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None: """Pass handler for BuiltinType nodes.""" self.__collect_type_from_symbol(node) - def get_type_from_instance( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance - ) -> None: + def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]: """Get type info from mypy type Instance.""" - node.name_spec.sym_type = str(mypy_type) + return str(mypy_type) def get_type_from_callable_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType - ) -> None: + self, mypy_type: MypyTypes.CallableType + ) -> Optional[str]: """Get type info from mypy type CallableType.""" - self.__call_type_handler(node, mypy_type.ret_type) + return self.__call_type_handler(mypy_type.ret_type) # TODO: Which overloaded function to get the return value from? def get_type_from_overloaded( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded - ) -> None: + self, mypy_type: MypyTypes.Overloaded + ) -> Optional[str]: """Get type info from mypy type Overloaded.""" - self.__call_type_handler(node, mypy_type.items[0]) + return self.__call_type_handler(mypy_type.items[0]) - def get_type_from_none_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType - ) -> None: + def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]: """Get type info from mypy type NoneType.""" - node.name_spec.sym_type = "None" + return "None" - def get_type_from_any_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType - ) -> None: + def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]: """Get type info from mypy type NoneType.""" - node.name_spec.sym_type = "Any" + return "Any" - def get_type_from_tuple_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType - ) -> None: + def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]: """Get type info from mypy type TupleType.""" - node.name_spec.sym_type = "builtins.tuple" + return "builtins.tuple" - def get_type_from_type_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TypeType - ) -> None: + def get_type_from_type_type(self, mypy_type: MypyTypes.TypeType) -> Optional[str]: """Get type info from mypy type TypeType.""" - node.name_spec.sym_type = str(mypy_type.item) + return str(mypy_type.item) def exit_assignment(self, node: ast.Assignment) -> None: """Add new symbols in the symbol table in case of self.""" diff --git a/jac/jaclang/compiler/passes/main/tests/test_type_check_pass.py b/jac/jaclang/compiler/passes/main/tests/test_type_check_pass.py index 6f0c863b2d..80d4b13e78 100644 --- a/jac/jaclang/compiler/passes/main/tests/test_type_check_pass.py +++ b/jac/jaclang/compiler/passes/main/tests/test_type_check_pass.py @@ -59,6 +59,6 @@ def test_type_coverage(self) -> None: self.assertIn("HasVar - species - Type: builtins.str", out) self.assertIn("myDog - Type: type_info.Dog", out) self.assertIn("Body - Type: type_info.Dog.Body", out) - self.assertEqual(out.count("Type: builtins.str"), 28) + self.assertEqual(out.count("Type: builtins.str"), 39) for i in lis: self.assertNotIn(i, out) diff --git a/jac/jaclang/compiler/passes/utils/mypy_ast_build.py b/jac/jaclang/compiler/passes/utils/mypy_ast_build.py index 536bdbeee4..fee57f155d 100644 --- a/jac/jaclang/compiler/passes/utils/mypy_ast_build.py +++ b/jac/jaclang/compiler/passes/utils/mypy_ast_build.py @@ -4,6 +4,7 @@ import ast import os +from types import MethodType from typing import Callable, TYPE_CHECKING, TextIO from jaclang.compiler.absyntree import AstNode @@ -11,11 +12,13 @@ from jaclang.compiler.passes.main.fuse_typeinfo_pass import ( FuseTypeInfoPass, ) +from jaclang.utils.helpers import pascal_to_snake import mypy.build as myb import mypy.checkexpr as mycke import mypy.errors as mye import mypy.fastparse as myfp +import mypy.nodes as mypy_nodes from mypy.build import BuildSource from mypy.build import BuildSourceSet from mypy.build import FileSystemCache @@ -36,6 +39,55 @@ from mypy.report import Reports # Avoid unconditional slow import +# All the expression nodes of mypy. +EXPRESSION_NODES = ( + mypy_nodes.AssertTypeExpr, + mypy_nodes.AssignmentExpr, + mypy_nodes.AwaitExpr, + mypy_nodes.BytesExpr, + mypy_nodes.CallExpr, + mypy_nodes.CastExpr, + mypy_nodes.ComparisonExpr, + mypy_nodes.ComplexExpr, + mypy_nodes.ConditionalExpr, + mypy_nodes.DictionaryComprehension, + mypy_nodes.DictExpr, + mypy_nodes.EllipsisExpr, + mypy_nodes.EnumCallExpr, + mypy_nodes.Expression, + mypy_nodes.FloatExpr, + mypy_nodes.GeneratorExpr, + mypy_nodes.IndexExpr, + mypy_nodes.IntExpr, + mypy_nodes.LambdaExpr, + mypy_nodes.ListComprehension, + mypy_nodes.ListExpr, + mypy_nodes.MemberExpr, + mypy_nodes.NamedTupleExpr, + mypy_nodes.NameExpr, + mypy_nodes.NewTypeExpr, + mypy_nodes.OpExpr, + mypy_nodes.ParamSpecExpr, + mypy_nodes.PromoteExpr, + mypy_nodes.RefExpr, + mypy_nodes.RevealExpr, + mypy_nodes.SetComprehension, + mypy_nodes.SetExpr, + mypy_nodes.SliceExpr, + mypy_nodes.StarExpr, + mypy_nodes.StrExpr, + mypy_nodes.SuperExpr, + mypy_nodes.TupleExpr, + mypy_nodes.TypeAliasExpr, + mypy_nodes.TypedDictExpr, + mypy_nodes.TypeVarExpr, + mypy_nodes.TypeVarTupleExpr, + mypy_nodes.UnaryExpr, + mypy_nodes.YieldExpr, + mypy_nodes.YieldFromExpr, +) + + mypy_to_jac_node_map: dict[ tuple[int, int | None, int | None, int | None], list[AstNode] ] = {} @@ -131,63 +183,45 @@ def __init__( """Override to mypy expression checker for direct AST pass through.""" super().__init__(tc, msg, plugin, per_line_checking_time_ns) - def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type: - """Type check a list expression [...].""" - out = super().visit_list_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type: - """Type check a set expression {...}.""" - out = super().visit_set_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type: - """Type check a tuple expression (...).""" - out = super().visit_tuple_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type: - """Type check a dictionary expression {...}.""" - out = super().visit_dict_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type: - """Type check a list comprehension.""" - out = super().visit_list_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type: - """Type check a set comprehension.""" - out = super().visit_set_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type: - """Type check a generator expression.""" - out = super().visit_generator_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dictionary_comprehension( - self, e: myfp.DictionaryComprehension - ) -> myb.Type: - """Type check a dict comprehension.""" - out = super().visit_dictionary_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_member_expr( - self, e: myfp.MemberExpr, is_lvalue: bool = False - ) -> myb.Type: - """Type check a member expr.""" - out = super().visit_member_expr(e, is_lvalue) - FuseTypeInfoPass.node_type_hash[e] = out - return out + # For every expression there, create attach a method on this instance (self) named "enter_expr()" + for expr_node in EXPRESSION_NODES: + method_name = "visit_" + pascal_to_snake(expr_node.__name__) + + # We call the super() version of the method so ensure the parent class has the method or else continue. + if not hasattr(mycke.ExpressionChecker, method_name): + continue + + # If the method already overriden then don't override it again here. Continue. Note that the method exists + # on the parent class and if it's also exists on this class and it's a different object that means it's + # overrident method. + if getattr(mycke.ExpressionChecker, method_name) != getattr( + ExpressionChecker, method_name + ): + continue + + # Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the + # "method_name" variable is used inside a loop and by the time the closure close the "method_name" value, + # it'll be changed by the loop, so we need another method ("make_closure") to persist the value. + def make_closure(method_name: str): # noqa: ANN201 + def closure( + self: ExpressionChecker, + e: mycke.Expression, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> mycke.Type: + # Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm + # (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269). + out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023 + self, e, *args, **kwargs + ) + FuseTypeInfoPass.node_type_hash[e] = out + return out + + return closure + + # Attach the new "visit_expr()" method to this instance. + method = make_closure(method_name) + setattr(self, method_name, MethodType(method, self)) class State(myb.State): diff --git a/jac/jaclang/utils/treeprinter.py b/jac/jaclang/utils/treeprinter.py index d4265e39d3..a896e78d81 100644 --- a/jac/jaclang/utils/treeprinter.py +++ b/jac/jaclang/utils/treeprinter.py @@ -143,6 +143,8 @@ def __node_repr_in_tree(node: AstNode) -> str: ) out += f" SymbolPath: {symbol}" return out + elif isinstance(node, ast.Expr): + return f"{node.__class__.__name__} - Type: {node.expr_type}" else: return f"{node.__class__.__name__}, {access}"