Skip to content

Commit

Permalink
[Issue #1452 fix] Support from x import *
Browse files Browse the repository at this point in the history
  • Loading branch information
mgtm98 committed Dec 10, 2024
1 parent a9e2d9e commit 9d8536a
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 31 deletions.
2 changes: 2 additions & 0 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,12 +886,14 @@ def __init__(
is_absorb: bool, # For includes
kid: Sequence[AstNode],
doc: Optional[String] = None,
import_all: bool = False,
) -> None:
"""Initialize import node."""
self.hint = hint
self.from_loc = from_loc
self.items = items
self.is_absorb = is_absorb
self.import_all = import_all
AstNode.__init__(self, kid=kid)
AstDocNode.__init__(self, doc=doc)

Expand Down
94 changes: 63 additions & 31 deletions jac/jaclang/compiler/passes/main/import_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.main import DefUsePass, SubNodeTabPass, SymTabBuildPass
from jaclang.compiler.passes.main.sym_tab_build_pass import PyInspectSymTabBuildPass
from jaclang.compiler.symtable import Symbol, SymbolTable
from jaclang.settings import settings
from jaclang.utils.log import logging

Expand Down Expand Up @@ -295,7 +296,6 @@ def __import_from_symbol_table_build(self) -> None:
is_symbol_tabled_refreshed: list[str] = []
self.import_from_build_list.reverse()
for imp_node, imported_mod in self.import_from_build_list:

# Need to build the symbol tables again to make sure that the
# complete symbol table is built.
#
Expand All @@ -319,35 +319,47 @@ def __import_from_symbol_table_build(self) -> None:
sym_tab = imported_mod.sym_tab
parent_sym_tab = imp_node.parent_of_type(ast.Module).sym_tab

for i in imp_node.items.items:
assert isinstance(i, ast.ModuleItem)
needed_sym = sym_tab.lookup(i.name.sym_name)

if needed_sym and needed_sym.defn[0].parent:
self.__debug_print(
f"\tAdding {needed_sym.sym_type}:{needed_sym.sym_name} into {parent_sym_tab.name}"
)
assert isinstance(needed_sym.defn[0], ast.AstSymbolNode)
parent_sym_tab.def_insert(
node=needed_sym.defn[0],
access_spec=needed_sym.access,
force_overwrite=True,
)
if imp_node.import_all:
for symbol in sym_tab.tab.values():
if symbol.sym_type == SymbolType.MODULE:
continue
self.__import_from_sym_table_add_symbols(symbol, parent_sym_tab)
else:
for i in imp_node.items.items:
assert isinstance(i, ast.ModuleItem)
needed_sym = sym_tab.lookup(i.name.sym_name)

if needed_sym.fetch_sym_tab:
msg = f"\tAdding SymbolTable:{needed_sym.fetch_sym_tab.name} into "
msg += f"SymbolTable:{parent_sym_tab.name} kids"
self.__debug_print(msg)
parent_sym_tab.kid.append(needed_sym.fetch_sym_tab)
elif needed_sym.sym_type != SymbolType.VAR:
raise AssertionError(
"Unexpected symbol type that doesn't have a symbl table"
if needed_sym and needed_sym.defn[0].parent:
self.__import_from_sym_table_add_symbols(
needed_sym, parent_sym_tab
)
else:
self.__debug_print(
f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}"
)

else:
self.__debug_print(
f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}"
)
def __import_from_sym_table_add_symbols(
self, sym: Symbol, sym_table: SymbolTable
) -> None:
self.__debug_print(
f"\tAdding {sym.sym_type}:{sym.sym_name} into {sym_table.name}"
)
assert isinstance(sym.defn[0], ast.AstSymbolNode)
sym_table.def_insert(
node=sym.defn[0],
access_spec=sym.access,
force_overwrite=True,
)

if sym.fetch_sym_tab:
msg = f"\tAdding SymbolTable:{sym.fetch_sym_tab.name} into "
msg += f"SymbolTable:{sym_table.name} kids"
self.__debug_print(msg)
sym_table.kid.append(sym.fetch_sym_tab)
elif sym.sym_type not in (SymbolType.VAR, SymbolType.MOD_VAR):
raise AssertionError(
f"Unexpected symbol type '{sym.sym_type}' that doesn't have a symbl table"
)

def __process_import(self, imp_node: ast.Import) -> None:
"""Process the imports in form of `import X`."""
Expand Down Expand Up @@ -379,14 +391,34 @@ def __process_import(self, imp_node: ast.Import) -> None:
f"\tIgnoring attaching builtins {imp_node.loc.mod_path} {imp_node.loc}"
)
return

self.__debug_print(
f"\tAttaching {imported_mod.name} into {ast.Module.get_href_path(imp_node)}"
)
self.attach_mod_to_node(imported_item, imported_mod)
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
SymTabBuildPass(input_ir=imported_mod, prior=self)

if imp_node.import_all:
msg = f"\tRegistering module:{imported_mod.name} to "
msg += f"import_from (import all) handling with {imp_node.loc.mod_path}:{imp_node.loc}"
self.__debug_print(msg)

self.import_from_build_list.append((imp_node, imported_mod))
if imported_mod._sym_tab is None:
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
else:
self.__debug_print(
f"\tRefreshing symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
PyInspectSymTabBuildPass(input_ir=imported_mod, prior=self)
DefUsePass(input_ir=imported_mod, prior=self)

else:
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
SymTabBuildPass(input_ir=imported_mod, prior=self)

def __import_py_module(
self,
Expand Down
1 change: 1 addition & 0 deletions jac/jaclang/compiler/passes/main/pyast_load_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,7 @@ class ImportFrom(stmt):
items=path_in,
is_absorb=True,
kid=[pytag, path_in],
import_all=True,
)
return ret
ret = ast.Import(
Expand Down
7 changes: 7 additions & 0 deletions jac/jaclang/tests/fixtures/import_all.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import:py import_all_py;

with entry {
print(import_all_py.custom_func(11));
print(import_all_py.pi);
print(import_all_py.floor(0.5));
}
8 changes: 8 additions & 0 deletions jac/jaclang/tests/fixtures/import_all_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Module used for testing from x import * in jac."""

from math import * # noqa


def custom_func(x: int) -> str:
"""Dummy custom function for testing purposes.""" # noqa
return str(x)
23 changes: 23 additions & 0 deletions jac/jaclang/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,29 @@ def test_builtins_loading(self) -> None:
r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append",
)

def test_import_all(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('import_all.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"6\:25 - 6\:30.*Name - floor -.*SymbolPath: import_all.import_all_py.floor",
)
self.assertRegex(
stdout_value,
r"5\:25 - 5\:27.*Name - pi -.*SymbolPath: import_all.import_all_py.pi",
)

def test_expr_types(self) -> None:
"""Testing for print AstTool."""
captured_output = io.StringIO()
Expand Down

0 comments on commit 9d8536a

Please sign in to comment.