Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Refactor Import Pass to Optimize get_all_subnodes Implementation #1455

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 39 additions & 11 deletions jac/jaclang/compiler/passes/main/import_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,32 +35,43 @@ def enter_module(self, node: ast.Module) -> None:
self.annex_impl(node)
self.terminate() # Turns off auto traversal for deliberate traversal
self.run_again = True
all_mods = [node]
new_mods = all_mods
while self.run_again:
self.run_again = False
all_imports = self.get_all_sub_nodes(node, ast.ModulePath)
all_imports = []
for i in new_mods:
all_imports.extend(i.get_all_sub_nodes(ast.ModulePath))
new_mods = []
for i in all_imports:
self.process_import(i)
mods = self.process_import(i)
if mods:
new_mods.extend(mods)

self.enter_module_path(i)
SubNodeTabPass(prior=self, input_ir=node)
SubNodeTabPass(prior=self, input_ir=node)

node.mod_deps.update(self.import_table)

def process_import(self, i: ast.ModulePath) -> None:
"""Process an import."""
imp_node = i.parent_of_type(ast.Import)
if imp_node.is_jac and not i.sub_module:
self.import_jac_module(node=i)
mod = self.import_jac_module(node=i)
if mod:
return mod

def attach_mod_to_node(
self, node: ast.ModulePath | ast.ModuleItem, mod: ast.Module | None
) -> None:
) -> ast.Module | None:
"""Attach a module to a node."""
if mod:
self.run_again = True
node.sub_module = mod
self.annex_impl(mod)
node.add_kids_right([mod], pos_update=False)
mod.parent = node
return mod

def annex_impl(self, node: ast.Module) -> None:
"""Annex impl and test modules."""
Expand Down Expand Up @@ -126,10 +137,13 @@ def enter_module_path(self, node: ast.ModulePath) -> None:
def import_jac_module(self, node: ast.ModulePath) -> None:
"""Import a module."""
self.cur_node = node # impacts error reporting
new_mods = []
target = node.resolve_relative_path()
# If the module is a package (dir)
if os.path.isdir(target):
self.attach_mod_to_node(node, self.import_jac_mod_from_dir(target))
mod = self.attach_mod_to_node(node, self.import_jac_mod_from_dir(target))
if mod:
new_mods.append(mod)
import_node = node.parent_of_type(ast.Import)
# And the import is a from import and I am the from module
if node == import_node.from_loc:
Expand All @@ -139,16 +153,23 @@ def import_jac_module(self, node: ast.ModulePath) -> None:
from_mod_target = node.resolve_relative_path(i.name.value)
# If package
if os.path.isdir(from_mod_target):
self.attach_mod_to_node(
mod = self.attach_mod_to_node(
i, self.import_jac_mod_from_dir(from_mod_target)
)
if mod:
new_mods.append(mod)
# Else module
else:
self.attach_mod_to_node(
mod = self.attach_mod_to_node(
i, self.import_jac_mod_from_file(from_mod_target)
)
if mod:
new_mods.append(mod)
else:
self.attach_mod_to_node(node, self.import_jac_mod_from_file(target))
mod = self.attach_mod_to_node(node, self.import_jac_mod_from_file(target))
if mod:
new_mods.append(mod)
return new_mods

def import_jac_mod_from_dir(self, target: str) -> ast.Module | None:
"""Import a module from a directory."""
Expand Down Expand Up @@ -220,23 +241,27 @@ def process_import(self, i: ast.ModulePath) -> None:
# from a import (c, d, e)
# Solution to that is to get the import node and check the from loc then
# handle it based on if there a from loc or not
new_mod_list = []
imp_node = i.parent_of_type(ast.Import)
if imp_node.is_py and not i.sub_module:
if imp_node.from_loc:
for j in imp_node.items.items:
assert isinstance(j, ast.ModuleItem)
mod_path = f"{imp_node.from_loc.dot_path_str}.{j.name.sym_name}"
self.import_py_module(
mod = self.import_py_module(
parent_node=j,
mod_path=mod_path,
imported_mod_name=(
j.name.sym_name if not j.alias else j.alias.sym_name
),
)
if mod:
new_mod_list.append(mod)
return new_mod_list
else:
for j in imp_node.items.items:
assert isinstance(j, ast.ModulePath)
self.import_py_module(
mod = self.import_py_module(
parent_node=j,
mod_path=j.dot_path_str,
imported_mod_name=(
Expand All @@ -245,6 +270,9 @@ def process_import(self, i: ast.ModulePath) -> None:
else j.alias.sym_name
),
)
if mod:
new_mod_list.append(mod)
return new_mod_list

def import_py_module(
self,
Expand Down
Loading