Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
enter_expr refactored
Browse files Browse the repository at this point in the history
  • Loading branch information
ThakeeNathees committed Aug 29, 2024
1 parent a9c6a2b commit d42cc95
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 113 deletions.
217 changes: 105 additions & 112 deletions jaclang/compiler/passes/main/fuse_typeinfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

from __future__ import annotations

from types import MethodType
from typing import Callable, Optional, TypeVar

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.transform import Transform
from jaclang.settings import settings
from jaclang.utils.helpers import pascal_to_snake
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack
Expand All @@ -24,97 +22,32 @@
T = TypeVar("T", bound=ast.AstSymbolNode)


# List of expression nodes which we'll be extracting the type info from.
JAC_EXPR_NODES = (
ast.AwaitExpr,
ast.BinaryExpr,
ast.CompareExpr,
ast.BoolExpr,
ast.LambdaExpr,
ast.UnaryExpr,
ast.IfElseExpr,
ast.AtomTrailer,
ast.AtomUnit,
ast.YieldExpr,
ast.YieldExpr,
ast.FuncCall,
ast.EdgeRefTrailer,
ast.ListVal,
ast.SetVal,
ast.TupleVal,
ast.DictVal,
ast.ListCompr,
ast.DictCompr,
)


class FuseTypeInfoPass(Pass):
"""Python and bytecode file self.__debug_printing pass."""

node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}

@staticmethod
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""
Enter an expression node.
This function is dynamically bound as a method on insntace of this class, since the
group of functions to handle expressions has a the exact same logic.
"""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = str(mytype)

# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
# expression. Time and memory wasted here.
collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# Set they symbol type for collection expression.
if type(node) in tuple(collection_types_map.keys()):
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(mytype)
collection_type = collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type

def __init__(self, input_ir: T, prior: Optional[Transform]) -> None:
"""Initialize the FuseTpeInfoPass instance."""
for expr_node in JAC_EXPR_NODES:
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
method = MethodType(
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
)
setattr(self, method_name, method)
super().__init__(input_ir, prior)
# Override this to support enter expression.
def enter_node(self, node: ast.AstNode) -> None:
"""Run on entering node."""
if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"):
getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node)
elif isinstance(node, ast.Expr):
self.enter_expr(node)

def __debug_print(self, *argv: object) -> None:
if settings.fuse_type_info_debug:
self.log_info("FuseTypeInfo::", *argv)

def __call_type_handler(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.ProperType
) -> None:
def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]:
mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__)
type_handler_name = f"get_type_from_{mypy_type_name}"
if hasattr(self, type_handler_name):
getattr(self, type_handler_name)(node, mypy_type)
else:
self.__debug_print(
f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)
return getattr(self, type_handler_name)(mypy_type)
self.__debug_print(
f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)
return None

def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None:
typ = node.sym_type.split(".")
Expand Down Expand Up @@ -244,7 +177,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
mypy_node = mypy_node.node

if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)

elif isinstance(mypy_node, MypyNodes.MypyFile):
node.name_spec.sym_type = "types.ModuleType"
Expand All @@ -253,7 +188,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
node.name_spec.sym_type = mypy_node.fullname

elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef):
self.__call_type_handler(node, mypy_node.items[0].func.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.items[0].func.type)
or node.name_spec.sym_type
)

elif mypy_node is None:
node.name_spec.sym_type = "None"
Expand All @@ -269,17 +207,67 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
node.name_spec.sym_type = mypy_node.fullname
self.__set_sym_table_link(node)
elif isinstance(mypy_node, MypyNodes.FuncDef):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Argument):
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Decorator):
self.__call_type_handler(node, mypy_node.func.type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.func.type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported',
type(mypy_node),
)

# NOTE: Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
# and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
# valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
# otherwise.
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""
Enter an expression node.
This function is dynamically bound as a method on insntace of this class, since the
group of functions to handle expressions has a the exact same logic.
"""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = self.__call_type_handler(mytype) or ""

# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
# expression. Time and memory wasted here.
collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# Set they symbol type for collection expression.
if type(node) in tuple(collection_types_map.keys()):
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
collection_type = collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = (
self.__call_type_handler(mytype) or node.name_spec.sym_type
)

@__handle_node
def enter_name(self, node: ast.NameAtom) -> None:
"""Pass handler for name nodes."""
Expand Down Expand Up @@ -319,7 +307,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
def enter_ability(self, node: ast.Ability) -> None:
"""Pass handler for Ability nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an ability from mypy node other than Ability.",
Expand All @@ -330,7 +321,10 @@ def enter_ability(self, node: ast.Ability) -> None:
def enter_ability_def(self, node: ast.AbilityDef) -> None:
"""Pass handler for AbilityDef nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef.",
Expand All @@ -343,7 +337,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument):
mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0]
if mypy_node.variable.type:
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get parameter value from mypyNode other than Argument"
Expand All @@ -357,7 +354,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
if isinstance(mypy_node, MypyNodes.AssignmentStmt):
n = mypy_node.lvalues[0].node
if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, n.type)
node.name_spec.sym_type = (
self.__call_type_handler(n.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
"Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
Expand Down Expand Up @@ -396,7 +395,9 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
self.__set_sym_table_link(node)
elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0]
self.__call_type_handler(node, mypy_node2.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef",
Expand Down Expand Up @@ -448,42 +449,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
"""Pass handler for BuiltinType nodes."""
self.__collect_type_from_symbol(node)

def get_type_from_instance(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance
) -> None:
def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]:
"""Get type info from mypy type Instance."""
node.name_spec.sym_type = str(mypy_type)
return str(mypy_type)

def get_type_from_callable_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType
) -> None:
self, mypy_type: MypyTypes.CallableType
) -> Optional[str]:
"""Get type info from mypy type CallableType."""
node.name_spec.sym_type = str(mypy_type.ret_type)
return str(mypy_type.ret_type)

# TODO: Which overloaded function to get the return value from?
def get_type_from_overloaded(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded
) -> None:
self, mypy_type: MypyTypes.Overloaded
) -> Optional[str]:
"""Get type info from mypy type Overloaded."""
self.__call_type_handler(node, mypy_type.items[0])
return self.__call_type_handler(mypy_type.items[0])

def get_type_from_none_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType
) -> None:
def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "None"
return "None"

def get_type_from_any_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType
) -> None:
def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "Any"
return "Any"

def get_type_from_tuple_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType
) -> None:
def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]:
"""Get type info from mypy type TupleType."""
node.name_spec.sym_type = "builtins.tuple"
return "builtins.tuple"

def exit_assignment(self, node: ast.Assignment) -> None:
"""Add new symbols in the symbol table in case of self."""
Expand Down
2 changes: 1 addition & 1 deletion jaclang/compiler/passes/main/tests/test_type_check_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def test_type_coverage(self) -> None:
self.assertIn("HasVar - species - Type: builtins.str", out)
self.assertIn("myDog - Type: type_info.Dog", out)
self.assertIn("Body - Type: type_info.Dog.Body", out)
self.assertEqual(out.count("Type: builtins.str"), 28)
self.assertEqual(out.count("Type: builtins.str"), 29)
for i in lis:
self.assertNotIn(i, out)
2 changes: 2 additions & 0 deletions jaclang/utils/treeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def __node_repr_in_tree(node: AstNode) -> str:
return out
elif isinstance(node, Token):
return f"{node.__class__.__name__} - {node.value}, {access}"
elif isinstance(node, ast.Expr):
return f"{node.__class__.__name__} - Type: {node.expr_type}"
elif (
isinstance(node, ast.Module)
and node.is_raised_from_py
Expand Down

0 comments on commit d42cc95

Please sign in to comment.