diff --git a/jac/examples/reference/architypes.py b/jac/examples/reference/architypes.py index b41e3ee0b1..3350a4c4de 100644 --- a/jac/examples/reference/architypes.py +++ b/jac/examples/reference/architypes.py @@ -6,8 +6,17 @@ def print_base_classes(cls: type) -> type: return cls +# Since the Animal class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @jac.make_obj(on_entry=[], on_exit=[]) -class Animal: +class Animal(Base): pass diff --git a/jac/examples/reference/data_spatial_walker_statements.py b/jac/examples/reference/data_spatial_walker_statements.py index 44591a51bb..49ff89abb0 100644 --- a/jac/examples/reference/data_spatial_walker_statements.py +++ b/jac/examples/reference/data_spatial_walker_statements.py @@ -2,8 +2,17 @@ from jaclang.plugin.feature import JacFeature as _Jac +# Since the Jac class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @_Jac.make_walker(on_entry=[_Jac.DSFunc("self_destruct", None)], on_exit=[]) -class Visitor: +class Visitor(Base): def self_destruct(self, _jac_here_) -> None: print("get's here") _Jac.disengage(self) diff --git a/jac/examples/reference/disengage_statements.py b/jac/examples/reference/disengage_statements.py index a8b2266d93..a458567979 100644 --- a/jac/examples/reference/disengage_statements.py +++ b/jac/examples/reference/disengage_statements.py @@ -2,8 +2,17 @@ from jaclang.plugin.feature import JacFeature as _Jac +# Since the Jac class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @_Jac.make_walker(on_entry=[_Jac.DSFunc("travel", _Jac.get_root_type())], on_exit=[]) -class Visitor: +class Visitor(Base): def travel(self, _jac_here_: _Jac.get_root_type()) -> None: if _Jac.visit_node( self, _Jac.edge_ref(_jac_here_, None, _Jac.EdgeDir.OUT, None, None) @@ -14,7 +23,7 @@ def travel(self, _jac_here_: _Jac.get_root_type()) -> None: @_Jac.make_node(on_entry=[_Jac.DSFunc("speak", Visitor)], on_exit=[]) -class item: +class item(Base): def speak(self, _jac_here_: Visitor) -> None: print("Hey There!!!") _Jac.disengage(_jac_here_) diff --git a/jac/examples/reference/match_class_patterns.py b/jac/examples/reference/match_class_patterns.py index e1128d7719..6646e28fb0 100644 --- a/jac/examples/reference/match_class_patterns.py +++ b/jac/examples/reference/match_class_patterns.py @@ -3,9 +3,18 @@ from dataclasses import dataclass as dataclass +# Since the Jac class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @Jac.make_obj(on_entry=[], on_exit=[]) @dataclass(eq=False) -class Point: +class Point(Base): x: float y: float diff --git a/jac/examples/reference/special_comprehensions.py b/jac/examples/reference/special_comprehensions.py index 10472bca99..7b5e731673 100644 --- a/jac/examples/reference/special_comprehensions.py +++ b/jac/examples/reference/special_comprehensions.py @@ -4,9 +4,18 @@ import random +# Since the Jac class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @Jac.make_obj(on_entry=[], on_exit=[]) @dataclass(eq=False) -class TestObj: +class TestObj(Base): x: int = Jac.has_instance_default(gen_func=lambda: random.randint(0, 15)) y: int = Jac.has_instance_default(gen_func=lambda: random.randint(0, 15)) z: int = Jac.has_instance_default(gen_func=lambda: random.randint(0, 15)) @@ -23,7 +32,7 @@ class TestObj: @Jac.make_obj(on_entry=[], on_exit=[]) @dataclass(eq=False) -class MyObj: +class MyObj(Base): apple: int = Jac.has_instance_default(gen_func=lambda: 0) banana: int = Jac.has_instance_default(gen_func=lambda: 0) diff --git a/jac/examples/reference/visit_statements.py b/jac/examples/reference/visit_statements.py index 7a7a53363f..b48d6f9dc6 100644 --- a/jac/examples/reference/visit_statements.py +++ b/jac/examples/reference/visit_statements.py @@ -2,8 +2,17 @@ from jaclang.plugin.feature import JacFeature as _Jac +# Since the Jac class cannot be inherit from object, (cause the base class will be changed at run time) +# we need a base class. +# +# reference: https://stackoverflow.com/a/9639512/10846399 +# +class Base: + pass + + @_Jac.make_walker(on_entry=[_Jac.DSFunc("travel", _Jac.get_root_type())], on_exit=[]) -class Visitor: +class Visitor(Base): def travel(self, _jac_here_: _Jac.get_root_type()) -> None: if _Jac.visit_node( self, _Jac.edge_ref(_jac_here_, None, _Jac.EdgeDir.OUT, None, None) @@ -14,7 +23,7 @@ def travel(self, _jac_here_: _Jac.get_root_type()) -> None: @_Jac.make_node(on_entry=[_Jac.DSFunc("speak", Visitor)], on_exit=[]) -class item: +class item(Base): def speak(self, _jac_here_: Visitor) -> None: print("Hey There!!!") diff --git a/jac/jaclang/__init__.py b/jac/jaclang/__init__.py index 157fe1ab21..d02801a955 100644 --- a/jac/jaclang/__init__.py +++ b/jac/jaclang/__init__.py @@ -1,12 +1,239 @@ """The Jac Programming Language.""" -from jaclang.plugin.default import JacFeatureImpl -from jaclang.plugin.feature import JacFeature, plugin_manager +__all__ = [ + "JacObj", + "JacWalker", + "JacNode", + "JacEdge", + "with_entry", + "with_exit", + "abstractmethod", + "jac_import", + "root", +] + +import inspect +import types +import typing +from abc import ABC, ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple, Type -jac_import = JacFeature.jac_import +from jaclang.plugin.default import JacFeatureImpl +from jaclang.plugin.feature import JacFeature as _Jac, plugin_manager +# ---------------------------------------------------------------------------- +# Plugin Initialization. +# ---------------------------------------------------------------------------- plugin_manager.register(JacFeatureImpl) plugin_manager.load_setuptools_entrypoints("jac") -__all__ = ["jac_import"] + +# ---------------------------------------------------------------------------- +# Meta classes. +# ---------------------------------------------------------------------------- + + +# https://stackoverflow.com/a/9639512/10846399 +class _JacArchiTypeBase: + pass + + +class JacMetaCommon(ABCMeta): + """Common metaclass for Jac types.""" + + def __new__( + cls, + name: str, + bases: Tuple[Type, ...], + dct: Dict[str, Any], + make_func: Callable[[list, list], Callable[[type], type]], + ) -> "JacMetaCommon": + + # We have added this "__init__" to the jac base class just to make the type checkers happy. + # Actually the dataclass decorator will create an __init__ function and assign it here bellow. + if bases == (_JacArchiTypeBase,) and "__init__" in dct: + del dct["__init__"] + + on_entry, on_exit = [], [] + for value in dct.values(): + if hasattr(value, "__jac_entry"): + entry_node = getattr(value, "__jac_entry") # noqa: B009 + on_entry.append(_Jac.DSFunc(value.__name__, entry_node)) + if hasattr(value, "__jac_exit"): + exit_node = getattr(value, "__jac_exit") # noqa: B009 + on_exit.append(_Jac.DSFunc(value.__name__, exit_node)) + + inst = super().__new__(cls, name, bases, dct) + inst = dataclass(eq=False)(inst) # type: ignore [arg-type, assignment] + inst = make_func(on_entry, on_exit)(inst) # type: ignore [assignment] + return inst + + +class JacMetaObj(JacMetaCommon, ABC): + def __new__( + cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any] + ) -> "JacMetaCommon": + return super().__new__(cls, name, bases, dct, _Jac.make_obj) + + +class JacMetaWalker(JacMetaCommon, ABC): + def __new__( + cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any] + ) -> "JacMetaCommon": + return super().__new__(cls, name, bases, dct, _Jac.make_walker) + + +class JacMetaNode(JacMetaCommon, ABC): + def __new__( + cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any] + ) -> "JacMetaCommon": + return super().__new__(cls, name, bases, dct, _Jac.make_node) + + +class JacMetaEdge(JacMetaCommon, ABC): + def __new__( + cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any] + ) -> "JacMetaCommon": + return super().__new__(cls, name, bases, dct, _Jac.make_edge) + + +# ---------------------------------------------------------------------------- +# Base classes. +# ---------------------------------------------------------------------------- + + +class JacObj(_JacArchiTypeBase, metaclass=JacMetaObj): + """Base class for all the jac object types.""" + + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003 + """Initialize Jac architype base.""" + + +class JacWalker(_JacArchiTypeBase, metaclass=JacMetaWalker): + """Base class for all the jac walker types.""" + + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003 + """Initialize Jac architype base.""" + + def spawn(self, node: "JacNode") -> None: # REVIEW: What should it return? + """Spawn a new node from the walker.""" + _Jac.spawn_call(self, node) # type: ignore [arg-type] + + def visit( + self, + expr: ( + list["JacNode" | "JacEdge"] + | list["JacNode"] + | list["JacEdge"] + | "JacNode" + | "JacEdge" + ), + ) -> None: # noqa: ANN401 + """Visit nodes.""" + _Jac.visit_node(self, expr) # type: ignore [arg-type] + + +class JacNode(_JacArchiTypeBase, metaclass=JacMetaNode): + """Base class for all the jac node types.""" + + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003 + """Initialize Jac architype base.""" + + def spawn(self, walker: JacWalker) -> None: # REVIEW: What should it return? + """Spawn a new node from the walker.""" + _Jac.spawn_call(self, walker) # type: ignore [arg-type] + + def connect( + self, + node: "JacNode", + edge: "type[JacEdge] | JacEdge | None" = None, + unidir: bool = False, + ) -> "JacNode": + """Connect the current node to another node.""" + # TODO: The above edge type should be reviewed, as the bellow can also take None, Edge, type[Edge]. + _Jac.connect( + left=self, # type: ignore [arg-type] + right=node, # type: ignore [arg-type] + edge_spec=_Jac.build_edge( + is_undirected=unidir, conn_type=edge, conn_assign=None # type: ignore [arg-type] + ), + ) + return node + + def connected( + self, + dir: _Jac.EdgeDir, + filter_edge: Optional["JacEdge"], + target: Optional["JacNode"], + edges_only: bool = False, + ) -> list["JacNode"] | list["JacEdge"]: + """Return connected nodes or edges.""" + filter_func = lambda x: ( + [i for i in x if isinstance(i, filter_edge)] if filter_edge else None + ) + return _Jac.edge_ref( + node_obj=self, + target_obj=target, + dir=dir, + filter_func=filter_func, + edges_only=edges_only, + ) + + +class JacEdge(_JacArchiTypeBase, metaclass=JacMetaEdge): + """Base class for all the jac edge types.""" + + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003 + """Initialize Jac architype base.""" + + +# ---------------------------------------------------------------------------- +# Decorators. +# ---------------------------------------------------------------------------- + + +def with_entry(func: Callable) -> Callable: + """Mark a method as jac entry with this decorator.""" + # Ensure the functioin has 2 parameters (self, here). + sig = inspect.signature(func, eval_str=True) + param_count = len(sig.parameters) + if param_count != 2: + raise ValueError("Jac entry function must have exactly 2 parameters.") + + # Get the entry node from the type hints. + second_param_name = list(sig.parameters.keys())[1] + entry_node = typing.get_type_hints(func).get(second_param_name) + + # Mark the function as jac entry. + setattr(func, "__jac_entry", entry_node) # noqa: B010 + return func + + +def with_exit(func: Callable) -> Callable: + """Mark a method as jac exit with this decorator.""" + # Ensure the functioin has 2 parameters (self, here). + sig = inspect.signature(func, eval_str=True) + param_count = len(sig.parameters) + if param_count != 2: + raise ValueError("Jac exit function must have exactly 2 parameters.") + + # Get the entry node from the type hints. + second_param_name = list(sig.parameters.keys())[1] + exit_node = typing.get_type_hints(func).get(second_param_name) + + # Mark the function as jac entry. + setattr(func, "__jac_exit", exit_node) # noqa: B010 + return func + + +# ---------------------------------------------------------------------------- +# Functions. +# ---------------------------------------------------------------------------- + +jac_import = _Jac.jac_import +root = _Jac.get_root() + +root.spawn = types.MethodType(JacNode.spawn, root) +root.connect = types.MethodType(JacNode.connect, root) diff --git a/jac/jaclang/plugin/default.py b/jac/jaclang/plugin/default.py index eb58e94a01..46d93fb3b2 100644 --- a/jac/jaclang/plugin/default.py +++ b/jac/jaclang/plugin/default.py @@ -597,11 +597,15 @@ def make_architype( if not hasattr(cls, "_jac_entry_funcs_") or not hasattr( cls, "_jac_exit_funcs_" ): - # Saving the module path and reassign it after creating cls - # So the jac modules are part of the correct module - cur_module = cls.__module__ - cls = type(cls.__name__, (cls, arch_base), {}) - cls.__module__ = cur_module + # If a class only inherit from object (ie. Doesn't inherit from a class), we cannot modify + # the __bases__ property of it, so it's necessary to make sure the class is not a direct child of object. + assert cls.__bases__ != (object,) + bases = ( + (cls.__bases__ + (arch_base,)) + if arch_base not in cls.__bases__ + else cls.__bases__ + ) + cls.__bases__ = bases cls._jac_entry_funcs_ = on_entry # type: ignore cls._jac_exit_funcs_ = on_exit # type: ignore else: diff --git a/jac/jaclang/runtimelib/architype.py b/jac/jaclang/runtimelib/architype.py index 082065fc9f..d01db2cf00 100644 --- a/jac/jaclang/runtimelib/architype.py +++ b/jac/jaclang/runtimelib/architype.py @@ -7,7 +7,7 @@ from enum import IntEnum from logging import getLogger from pickle import dumps -from types import UnionType +from types import MethodType, UnionType from typing import Any, Callable, ClassVar, Optional, TypeVar from uuid import UUID, uuid4 @@ -279,12 +279,19 @@ class GenericEdge(EdgeArchitype): class Root(NodeArchitype): """Generic Root Node.""" + # We define the 'spawn' and 'connect' here which will be added to the root instance + # as method bound. This slots definition here will allow the type checker to + # assign dynamic attributes. + __slots__ = ("__jac__", "spawn", "connect") + _jac_entry_funcs_: ClassVar[list[DSFunc]] = [] _jac_exit_funcs_: ClassVar[list[DSFunc]] = [] def __init__(self) -> None: """Create root node.""" self.__jac__ = NodeAnchor(architype=self, persistent=True, edges=[]) + self.spawn: MethodType | None = None + self.connect: MethodType | None = None @dataclass(eq=False) @@ -304,7 +311,17 @@ def get_funcparam_annotations( """Get function parameter annotations.""" if not func: return None + + sig = inspect.signature(func, eval_str=True) + param_count = len(sig.parameters) + + if param_count < 2: + return None + + second_param_name = list(sig.parameters.keys())[1] # "_jac_here_" annotation = ( - inspect.signature(func, eval_str=True).parameters["_jac_here_"].annotation + inspect.signature(func, eval_str=True) + .parameters[second_param_name] + .annotation ) return annotation if annotation != inspect._empty else None