Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #1345 fix] Sub class symtab fix #1488

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,14 +697,16 @@ def unparse(self) -> str:
def get_href_path(node: AstNode) -> str:
"""Return the full path of the module that contains this node."""
parent = node.find_parent_of_type(Module)
mod_list = []
if isinstance(node, Module):
mod_list: list[Module | Architype] = []
if isinstance(node, (Module, Architype)):
mod_list.append(node)
while parent is not None:
mod_list.append(parent)
parent = parent.find_parent_of_type(Module)
mod_list.reverse()
return ".".join(p.name for p in mod_list)
return ".".join(
p.name if isinstance(p, Module) else p.name.sym_name for p in mod_list
)


class GlobalVars(ElementStmt, AstAccessNode):
Expand Down
103 changes: 103 additions & 0 deletions jac/jaclang/compiler/passes/main/inheritance_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Pass used to add the inherited symbols for architypes."""

from __future__ import annotations

from typing import Optional

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.compiler.symtable import Symbol, SymbolTable
from jaclang.settings import settings


class InheritancePass(Pass):
"""Add inherited abilities in the target symbol tables."""

def __debug_print(self, msg: str) -> None:
if settings.inherit_pass_debug:
self.log_info("[PyImportPass] " + msg)

def __lookup(self, name: str, sym_table: SymbolTable) -> Optional[Symbol]:
symbol = sym_table.lookup(name)
if symbol is None:
# Check if the needed symbol in builtins
builtins_symtable = self.ir.sym_tab.find_scope("builtins")
assert builtins_symtable is not None
symbol = builtins_symtable.lookup(name)
return symbol

def enter_architype(self, node: ast.Architype) -> None:
"""Fill architype symbol tables with abilities from parent architypes."""
if node.base_classes is None:
return

for item in node.base_classes.items:
# The assumption is that the base class can only be a name node
# or an atom trailer only.
assert isinstance(item, (ast.Name, ast.AtomTrailer))

# In case of name node, then get the symbol table that contains
# the current class and lookup for that name after that use the
# symbol to get the symbol table of the base class
if isinstance(item, ast.Name):
assert node.sym_tab.parent is not None
base_class_symbol = self.__lookup(item.sym_name, node.sym_tab.parent)
if base_class_symbol is None:
msg = "Missing symbol for base class "
msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
continue
base_class_symbol_table = base_class_symbol.fetch_sym_tab
if (
base_class_symbol_table is None
and base_class_symbol.defn[0]
.parent_of_type(ast.Module)
.py_info.is_raised_from_py
):
msg = "Missing symbol table for python base class "
msg += f"{ast.Module.get_href_path(item)}.{item.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
continue
assert base_class_symbol_table is not None
node.sym_tab.inherit_sym_tab(base_class_symbol_table)

# In case of atom trailer, unwind it and use each name node to
# as the code above to lookup for the base class
elif isinstance(item, ast.AtomTrailer):
current_sym_table = node.sym_tab.parent
not_found: bool = False
assert current_sym_table is not None
for name in item.as_attr_list:
sym = self.__lookup(name.sym_name, current_sym_table)
if sym is None:
msg = "Missing symbol for base class "
msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
not_found = True
break
current_sym_table = sym.fetch_sym_tab

# In case of python nodes, the base class may not be
# raised so ignore these classes for now
# TODO Do we need to import these classes?
if (
sym.defn[0].parent_of_type(ast.Module).py_info.is_raised_from_py
and current_sym_table is None
):
msg = "Missing symbol table for python base class "
msg += f"{ast.Module.get_href_path(name)}.{name.sym_name}"
msg += f" needed for {ast.Module.get_href_path(node)}"
self.__debug_print(msg)
not_found = True
break

assert current_sym_table is not None

if not_found:
continue

assert current_sym_table is not None
node.sym_tab.inherit_sym_tab(current_sym_table)
2 changes: 2 additions & 0 deletions jac/jaclang/compiler/passes/main/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .registry_pass import RegistryPass # noqa: I100
from .access_modifier_pass import AccessCheckPass # noqa: I100
from .py_collect_dep_pass import PyCollectDepsPass # noqa: I100
from .inheritance_pass import InheritancePass # noqa: I100

py_code_gen = [
SubNodeTabPass,
Expand All @@ -38,6 +39,7 @@
PyCollectDepsPass,
PyImportPass,
DefUsePass,
InheritancePass,
FuseTypeInfoPass,
AccessCheckPass,
]
Expand Down
4 changes: 2 additions & 2 deletions jac/jaclang/langserve/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_completion(self) -> None:
"doubleinner",
"apply_red",
],
8,
11,
),
(
lspt.Position(65, 23),
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_completion(self) -> None:
"doubleinner",
"apply_red",
],
8,
11,
),
(
lspt.Position(73, 22),
Expand Down
1 change: 1 addition & 0 deletions jac/jaclang/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Settings:
collect_py_dep_debug: bool = False
print_py_raised_ast: bool = False
py_import_pass_debug: bool = False
inherit_pass_debug: bool = False

# Compiler configuration
disable_mtllm: bool = False
Expand Down
11 changes: 11 additions & 0 deletions jac/jaclang/tests/fixtures/base_class1.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import:py test_py;

class B :test_py.A: {}

with entry {
a = test_py.A();
b = B();

a.start();
b.start();
}
11 changes: 11 additions & 0 deletions jac/jaclang/tests/fixtures/base_class2.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import:py from test_py { A }

class B :A: {}

with entry {
a = A();
b = B();

a.start();
b.start();
}
12 changes: 12 additions & 0 deletions jac/jaclang/tests/fixtures/test_py.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Test file for subclass issue."""

p = 5
g = 6


class A:
"""Dummy class to test the base class issue."""

def start(self) -> int:
"""Return 0."""
return 0
38 changes: 38 additions & 0 deletions jac/jaclang/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,44 @@ def test_builtins_loading(self) -> None:
r"13\:12 \- 13\:18.*Name - append - .*SymbolPath: builtins_test.builtins.list.append",
)

def test_sub_class_symbol_table_fix_1(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class1.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class1.B.start",
)

def test_sub_class_symbol_table_fix_2(self) -> None:
"""Testing for print AstTool."""
from jaclang.settings import settings

settings.ast_symbol_info_detailed = True
captured_output = io.StringIO()
sys.stdout = captured_output

cli.tool("ir", ["ast", f"{self.fixture_abs_path('base_class2.jac')}"])

sys.stdout = sys.__stdout__
stdout_value = captured_output.getvalue()
settings.ast_symbol_info_detailed = False

self.assertRegex(
stdout_value,
r"10:7 - 10:12.*Name - start - Type.*SymbolPath: base_class2.B.start",
)

def test_expr_types(self) -> None:
"""Testing for print AstTool."""
captured_output = io.StringIO()
Expand Down
Loading