Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #1452 fix] Support from x import * #1484

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,12 +886,14 @@ def __init__(
is_absorb: bool, # For includes
kid: Sequence[AstNode],
doc: Optional[String] = None,
import_all: bool = False,
mgtm98 marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Initialize import node."""
self.hint = hint
self.from_loc = from_loc
self.items = items
self.is_absorb = is_absorb
self.import_all = import_all
AstNode.__init__(self, kid=kid)
AstDocNode.__init__(self, doc=doc)

Expand Down
94 changes: 63 additions & 31 deletions jac/jaclang/compiler/passes/main/import_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.main import DefUsePass, SubNodeTabPass, SymTabBuildPass
from jaclang.compiler.passes.main.sym_tab_build_pass import PyInspectSymTabBuildPass
from jaclang.compiler.symtable import Symbol, SymbolTable
from jaclang.settings import settings
from jaclang.utils.log import logging

Expand Down Expand Up @@ -295,7 +296,6 @@ def __import_from_symbol_table_build(self) -> None:
is_symbol_tabled_refreshed: list[str] = []
self.import_from_build_list.reverse()
for imp_node, imported_mod in self.import_from_build_list:

# Need to build the symbol tables again to make sure that the
# complete symbol table is built.
#
Expand All @@ -319,35 +319,47 @@ def __import_from_symbol_table_build(self) -> None:
sym_tab = imported_mod.sym_tab
parent_sym_tab = imp_node.parent_of_type(ast.Module).sym_tab

for i in imp_node.items.items:
assert isinstance(i, ast.ModuleItem)
needed_sym = sym_tab.lookup(i.name.sym_name)

if needed_sym and needed_sym.defn[0].parent:
self.__debug_print(
f"\tAdding {needed_sym.sym_type}:{needed_sym.sym_name} into {parent_sym_tab.name}"
)
assert isinstance(needed_sym.defn[0], ast.AstSymbolNode)
parent_sym_tab.def_insert(
node=needed_sym.defn[0],
access_spec=needed_sym.access,
force_overwrite=True,
)
if imp_node.import_all:
for symbol in sym_tab.tab.values():
if symbol.sym_type == SymbolType.MODULE:
continue
self.__import_from_sym_table_add_symbols(symbol, parent_sym_tab)
else:
for i in imp_node.items.items:
assert isinstance(i, ast.ModuleItem)
needed_sym = sym_tab.lookup(i.name.sym_name)

if needed_sym.fetch_sym_tab:
msg = f"\tAdding SymbolTable:{needed_sym.fetch_sym_tab.name} into "
msg += f"SymbolTable:{parent_sym_tab.name} kids"
self.__debug_print(msg)
parent_sym_tab.kid.append(needed_sym.fetch_sym_tab)
elif needed_sym.sym_type != SymbolType.VAR:
raise AssertionError(
"Unexpected symbol type that doesn't have a symbl table"
if needed_sym and needed_sym.defn[0].parent:
self.__import_from_sym_table_add_symbols(
needed_sym, parent_sym_tab
)
else:
self.__debug_print(
f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}"
)

else:
self.__debug_print(
f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}"
)
def __import_from_sym_table_add_symbols(
self, sym: Symbol, sym_table: SymbolTable
) -> None:
self.__debug_print(
f"\tAdding {sym.sym_type}:{sym.sym_name} into {sym_table.name}"
)
assert isinstance(sym.defn[0], ast.AstSymbolNode)
sym_table.def_insert(
node=sym.defn[0],
access_spec=sym.access,
force_overwrite=True,
)

if sym.fetch_sym_tab:
msg = f"\tAdding SymbolTable:{sym.fetch_sym_tab.name} into "
msg += f"SymbolTable:{sym_table.name} kids"
self.__debug_print(msg)
sym_table.kid.append(sym.fetch_sym_tab)
elif sym.sym_type not in (SymbolType.VAR, SymbolType.MOD_VAR):
raise AssertionError(
f"Unexpected symbol type '{sym.sym_type}' that doesn't have a symbl table"
)

def __process_import(self, imp_node: ast.Import) -> None:
"""Process the imports in form of `import X`."""
Expand Down Expand Up @@ -379,14 +391,34 @@ def __process_import(self, imp_node: ast.Import) -> None:
f"\tIgnoring attaching builtins {imp_node.loc.mod_path} {imp_node.loc}"
)
return

self.__debug_print(
f"\tAttaching {imported_mod.name} into {ast.Module.get_href_path(imp_node)}"
)
self.attach_mod_to_node(imported_item, imported_mod)
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
SymTabBuildPass(input_ir=imported_mod, prior=self)

if imp_node.import_all:
msg = f"\tRegistering module:{imported_mod.name} to "
msg += f"import_from (import all) handling with {imp_node.loc.mod_path}:{imp_node.loc}"
self.__debug_print(msg)

self.import_from_build_list.append((imp_node, imported_mod))
if imported_mod._sym_tab is None:
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
else:
self.__debug_print(
f"\tRefreshing symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
PyInspectSymTabBuildPass(input_ir=imported_mod, prior=self)
DefUsePass(input_ir=imported_mod, prior=self)

else:
self.__debug_print(
f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}"
)
SymTabBuildPass(input_ir=imported_mod, prior=self)

def __import_py_module(
self,
Expand Down
1 change: 1 addition & 0 deletions jac/jaclang/compiler/passes/main/pyast_load_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,7 @@ class ImportFrom(stmt):
items=path_in,
is_absorb=True,
kid=[pytag, path_in],
import_all=True,
)
return ret
ret = ast.Import(
Expand Down
7 changes: 7 additions & 0 deletions jac/jaclang/tests/fixtures/import_all.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import:py import_all_py;

with entry {
print(import_all_py.custom_func(11));
print(import_all_py.pi);
print(import_all_py.floor(0.5));
}
8 changes: 8 additions & 0 deletions jac/jaclang/tests/fixtures/import_all_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Module used for testing from x import * in jac."""

from math import * # noqa


def custom_func(x: int) -> str:
"""Dummy custom function for testing purposes.""" # noqa
return str(x)
23 changes: 23 additions & 0 deletions jac/jaclang/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,29 @@ def test_builtins_loading(self) -> None:
r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append",
)

def test_import_all(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('import_all.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"6\:25 - 6\:30.*Name - floor -.*SymbolPath: import_all.import_all_py.floor",
)
self.assertRegex(
stdout_value,
r"5\:25 - 5\:27.*Name - pi -.*SymbolPath: import_all.import_all_py.pi",
)

def test_expr_types(self) -> None:
"""Testing for print AstTool."""
captured_output = io.StringIO()
Expand Down
Loading