diff --git a/jac/jaclang/compiler/passes/main/import_pass.py b/jac/jaclang/compiler/passes/main/import_pass.py index 6471440da..618f987a2 100644 --- a/jac/jaclang/compiler/passes/main/import_pass.py +++ b/jac/jaclang/compiler/passes/main/import_pass.py @@ -16,6 +16,7 @@ from jaclang.compiler.passes import Pass from jaclang.compiler.passes.main import DefUsePass, SubNodeTabPass, SymTabBuildPass from jaclang.compiler.passes.main.sym_tab_build_pass import PyInspectSymTabBuildPass +from jaclang.compiler.symtable import Symbol, SymbolTable from jaclang.settings import settings from jaclang.utils.log import logging @@ -295,7 +296,6 @@ def __import_from_symbol_table_build(self) -> None: is_symbol_tabled_refreshed: list[str] = [] self.import_from_build_list.reverse() for imp_node, imported_mod in self.import_from_build_list: - # Need to build the symbol tables again to make sure that the # complete symbol table is built. # @@ -319,35 +319,47 @@ def __import_from_symbol_table_build(self) -> None: sym_tab = imported_mod.sym_tab parent_sym_tab = imp_node.parent_of_type(ast.Module).sym_tab - for i in imp_node.items.items: - assert isinstance(i, ast.ModuleItem) - needed_sym = sym_tab.lookup(i.name.sym_name) - - if needed_sym and needed_sym.defn[0].parent: - self.__debug_print( - f"\tAdding {needed_sym.sym_type}:{needed_sym.sym_name} into {parent_sym_tab.name}" - ) - assert isinstance(needed_sym.defn[0], ast.AstSymbolNode) - parent_sym_tab.def_insert( - node=needed_sym.defn[0], - access_spec=needed_sym.access, - force_overwrite=True, - ) + if imp_node.is_absorb: + for symbol in sym_tab.tab.values(): + if symbol.sym_type == SymbolType.MODULE: + continue + self.__import_from_sym_table_add_symbols(symbol, parent_sym_tab) + else: + for i in imp_node.items.items: + assert isinstance(i, ast.ModuleItem) + needed_sym = sym_tab.lookup(i.name.sym_name) - if needed_sym.fetch_sym_tab: - msg = f"\tAdding SymbolTable:{needed_sym.fetch_sym_tab.name} into " - msg += f"SymbolTable:{parent_sym_tab.name} kids" - self.__debug_print(msg) - parent_sym_tab.kid.append(needed_sym.fetch_sym_tab) - elif needed_sym.sym_type != SymbolType.VAR: - raise AssertionError( - "Unexpected symbol type that doesn't have a symbl table" + if needed_sym and needed_sym.defn[0].parent: + self.__import_from_sym_table_add_symbols( + needed_sym, parent_sym_tab + ) + else: + self.__debug_print( + f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}" ) - else: - self.__debug_print( - f"Can't find a symbol matching {i.name.sym_name} in {sym_tab.name}" - ) + def __import_from_sym_table_add_symbols( + self, sym: Symbol, sym_table: SymbolTable + ) -> None: + self.__debug_print( + f"\tAdding {sym.sym_type}:{sym.sym_name} into {sym_table.name}" + ) + assert isinstance(sym.defn[0], ast.AstSymbolNode) + sym_table.def_insert( + node=sym.defn[0], + access_spec=sym.access, + force_overwrite=True, + ) + + if sym.fetch_sym_tab: + msg = f"\tAdding SymbolTable:{sym.fetch_sym_tab.name} into " + msg += f"SymbolTable:{sym_table.name} kids" + self.__debug_print(msg) + sym_table.kid.append(sym.fetch_sym_tab) + elif sym.sym_type not in (SymbolType.VAR, SymbolType.MOD_VAR): + raise AssertionError( + f"Unexpected symbol type '{sym.sym_type}' that doesn't have a symbl table" + ) def __process_import(self, imp_node: ast.Import) -> None: """Process the imports in form of `import X`.""" @@ -379,14 +391,34 @@ def __process_import(self, imp_node: ast.Import) -> None: f"\tIgnoring attaching builtins {imp_node.loc.mod_path} {imp_node.loc}" ) return + self.__debug_print( f"\tAttaching {imported_mod.name} into {ast.Module.get_href_path(imp_node)}" ) self.attach_mod_to_node(imported_item, imported_mod) - self.__debug_print( - f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}" - ) - SymTabBuildPass(input_ir=imported_mod, prior=self) + + if imp_node.is_absorb: + msg = f"\tRegistering module:{imported_mod.name} to " + msg += f"import_from (import all) handling with {imp_node.loc.mod_path}:{imp_node.loc}" + self.__debug_print(msg) + + self.import_from_build_list.append((imp_node, imported_mod)) + if imported_mod._sym_tab is None: + self.__debug_print( + f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}" + ) + else: + self.__debug_print( + f"\tRefreshing symbol table for module:{ast.Module.get_href_path(imported_mod)}" + ) + PyInspectSymTabBuildPass(input_ir=imported_mod, prior=self) + DefUsePass(input_ir=imported_mod, prior=self) + + else: + self.__debug_print( + f"\tBuilding symbol table for module:{ast.Module.get_href_path(imported_mod)}" + ) + SymTabBuildPass(input_ir=imported_mod, prior=self) def __import_py_module( self, diff --git a/jac/jaclang/tests/fixtures/import_all.jac b/jac/jaclang/tests/fixtures/import_all.jac new file mode 100644 index 000000000..e1b6706e1 --- /dev/null +++ b/jac/jaclang/tests/fixtures/import_all.jac @@ -0,0 +1,7 @@ +import:py import_all_py; + +with entry { + print(import_all_py.custom_func(11)); + print(import_all_py.pi); + print(import_all_py.floor(0.5)); +} diff --git a/jac/jaclang/tests/fixtures/import_all_py.py b/jac/jaclang/tests/fixtures/import_all_py.py new file mode 100644 index 000000000..cd2fd4460 --- /dev/null +++ b/jac/jaclang/tests/fixtures/import_all_py.py @@ -0,0 +1,8 @@ +"""Module used for testing from x import * in jac.""" + +from math import * # noqa + + +def custom_func(x: int) -> str: + """Dummy custom function for testing purposes.""" # noqa + return str(x) diff --git a/jac/jaclang/tests/test_cli.py b/jac/jaclang/tests/test_cli.py index 0a9739111..f75a5989a 100644 --- a/jac/jaclang/tests/test_cli.py +++ b/jac/jaclang/tests/test_cli.py @@ -233,6 +233,29 @@ def test_builtins_loading(self) -> None: r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append", ) + def test_import_all(self) -> None: + """Testing for print AstTool.""" + from jaclang.settings import settings + + settings.ast_symbol_info_detailed = True + captured_output = io.StringIO() + sys.stdout = captured_output + + cli.tool("ir", ["ast", f"{self.fixture_abs_path('import_all.jac')}"]) + + sys.stdout = sys.__stdout__ + stdout_value = captured_output.getvalue() + settings.ast_symbol_info_detailed = False + + self.assertRegex( + stdout_value, + r"6\:25 - 6\:30.*Name - floor -.*SymbolPath: import_all.import_all_py.floor", + ) + self.assertRegex( + stdout_value, + r"5\:25 - 5\:27.*Name - pi -.*SymbolPath: import_all.import_all_py.pi", + ) + def test_sub_class_symbol_table_fix_1(self) -> None: """Testing for print AstTool.""" from jaclang.settings import settings