diff --git a/jac/jaclang/compiler/passes/main/import_pass.py b/jac/jaclang/compiler/passes/main/import_pass.py index 9cff91ba45..ebbfbb1fce 100644 --- a/jac/jaclang/compiler/passes/main/import_pass.py +++ b/jac/jaclang/compiler/passes/main/import_pass.py @@ -35,13 +35,21 @@ def enter_module(self, node: ast.Module) -> None: self.annex_impl(node) self.terminate() # Turns off auto traversal for deliberate traversal self.run_again = True + all_mods = [node] + new_mods = all_mods while self.run_again: self.run_again = False - all_imports = self.get_all_sub_nodes(node, ast.ModulePath) + all_imports = [] + for i in new_mods: + all_imports.extend(i.get_all_sub_nodes(ast.ModulePath)) + new_mods = [] for i in all_imports: - self.process_import(i) + mods = self.process_import(i) + if mods: + new_mods.extend(mods) + self.enter_module_path(i) - SubNodeTabPass(prior=self, input_ir=node) + SubNodeTabPass(prior=self, input_ir=node) node.mod_deps.update(self.import_table) @@ -49,11 +57,13 @@ def process_import(self, i: ast.ModulePath) -> None: """Process an import.""" imp_node = i.parent_of_type(ast.Import) if imp_node.is_jac and not i.sub_module: - self.import_jac_module(node=i) + mod = self.import_jac_module(node=i) + if mod: + return mod def attach_mod_to_node( self, node: ast.ModulePath | ast.ModuleItem, mod: ast.Module | None - ) -> None: + ) -> ast.Module | None: """Attach a module to a node.""" if mod: self.run_again = True @@ -61,6 +71,7 @@ def attach_mod_to_node( self.annex_impl(mod) node.add_kids_right([mod], pos_update=False) mod.parent = node + return mod def annex_impl(self, node: ast.Module) -> None: """Annex impl and test modules.""" @@ -126,10 +137,13 @@ def enter_module_path(self, node: ast.ModulePath) -> None: def import_jac_module(self, node: ast.ModulePath) -> None: """Import a module.""" self.cur_node = node # impacts error reporting + new_mods = [] target = node.resolve_relative_path() # If the module is a package (dir) if os.path.isdir(target): - self.attach_mod_to_node(node, self.import_jac_mod_from_dir(target)) + mod = self.attach_mod_to_node(node, self.import_jac_mod_from_dir(target)) + if mod: + new_mods.append(mod) import_node = node.parent_of_type(ast.Import) # And the import is a from import and I am the from module if node == import_node.from_loc: @@ -139,16 +153,23 @@ def import_jac_module(self, node: ast.ModulePath) -> None: from_mod_target = node.resolve_relative_path(i.name.value) # If package if os.path.isdir(from_mod_target): - self.attach_mod_to_node( + mod = self.attach_mod_to_node( i, self.import_jac_mod_from_dir(from_mod_target) ) + if mod: + new_mods.append(mod) # Else module else: - self.attach_mod_to_node( + mod = self.attach_mod_to_node( i, self.import_jac_mod_from_file(from_mod_target) ) + if mod: + new_mods.append(mod) else: - self.attach_mod_to_node(node, self.import_jac_mod_from_file(target)) + mod = self.attach_mod_to_node(node, self.import_jac_mod_from_file(target)) + if mod: + new_mods.append(mod) + return new_mods def import_jac_mod_from_dir(self, target: str) -> ast.Module | None: """Import a module from a directory.""" @@ -220,23 +241,27 @@ def process_import(self, i: ast.ModulePath) -> None: # from a import (c, d, e) # Solution to that is to get the import node and check the from loc then # handle it based on if there a from loc or not + new_mod_list = [] imp_node = i.parent_of_type(ast.Import) if imp_node.is_py and not i.sub_module: if imp_node.from_loc: for j in imp_node.items.items: assert isinstance(j, ast.ModuleItem) mod_path = f"{imp_node.from_loc.dot_path_str}.{j.name.sym_name}" - self.import_py_module( + mod = self.import_py_module( parent_node=j, mod_path=mod_path, imported_mod_name=( j.name.sym_name if not j.alias else j.alias.sym_name ), ) + if mod: + new_mod_list.append(mod) + return new_mod_list else: for j in imp_node.items.items: assert isinstance(j, ast.ModulePath) - self.import_py_module( + mod = self.import_py_module( parent_node=j, mod_path=j.dot_path_str, imported_mod_name=( @@ -245,6 +270,9 @@ def process_import(self, i: ast.ModulePath) -> None: else j.alias.sym_name ), ) + if mod: + new_mod_list.append(mod) + return new_mod_list def import_py_module( self,