diff --git a/jac/jaclang/lib.py b/jac/jaclang/lib.py new file mode 100644 index 0000000000..f6fb82b3ee --- /dev/null +++ b/jac/jaclang/lib.py @@ -0,0 +1,123 @@ +"""Jac python interface.""" + +import inspect +import typing + +from .plugin.feature import JacFeature as _Jac +from dataclasses import dataclass as __jac_dataclass__ + + +# ---------------------------------------------------------------------------- +# Meta classes. +# ---------------------------------------------------------------------------- + +class JacMetaCommon(type): + """Common metaclass for Jac types.""" + + def __new__(cls, name, bases, dct, make_func): + + on_entry, on_exit = [], [] + for name, value in dct.items(): + if hasattr(value, "__jac_entry"): + entry_node = getattr(value, "__jac_entry") + on_entry.append(_Jac.DSFunc(value.__name__, entry_node)) + if hasattr(value, "__jac_exit"): + exit_node = getattr(value, "__jac_exit") + on_exit.append(_Jac.DSFunc(value.__name__, exit_node)) + + cls = super().__new__(cls, name, bases, dct) + cls = __jac_dataclass__(eq=False)(cls) + cls = make_func(on_entry=on_entry, on_exit=on_exit)(cls) + return cls + +class JacMetaObj(JacMetaCommon): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, _Jac.make_obj) + +class JacMetaWalker(JacMetaCommon): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, _Jac.make_walker) + +class JacMetaNode(JacMetaCommon): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, _Jac.make_node) + +class JacMetaEdge(JacMetaCommon): + def __new__(cls, name, bases, dct): + return super().__new__(cls, name, bases, dct, _Jac.make_edge) + +# ---------------------------------------------------------------------------- +# Base classes. +# ---------------------------------------------------------------------------- + +# https://stackoverflow.com/a/9639512/10846399 +class _JacArchiTypeBase: + pass + +class JacObj(_JacArchiTypeBase, metaclass=JacMetaObj): + pass + +class JacWalker(_JacArchiTypeBase, metaclass=JacMetaWalker): + def spawn(self, node): + _Jac.spawn_call(self, node) + +class JacNode(_JacArchiTypeBase, metaclass=JacMetaNode): + def spawn(self, walker): + _Jac.spawn_call(self, walker) + +class JacEdge(_JacArchiTypeBase, metaclass=JacMetaEdge): + pass + +# ---------------------------------------------------------------------------- +# Decorators. +# ---------------------------------------------------------------------------- + +def with_entry(func): + """Decorator to mark a method as jac entry.""" + + # Ensure the functioin has 2 parameters (self, here). + sig = inspect.signature(func) + param_count = len(sig.parameters) + if param_count != 2: + raise ValueError("Jac entry function must have exactly 2 parameters.") + + # Get the entry node from the type hints. + second_param_name = list(sig.parameters.keys())[1] + entry_node = typing.get_type_hints(func).get(second_param_name, None) + + # Mark the function as jac entry. + setattr(func, "__jac_entry", entry_node) + return func + + +def with_exit(func): + """Decorator to mark a method as jac exit.""" + + # Ensure the functioin has 2 parameters (self, here). + sig = inspect.signature(func) + param_count = len(sig.parameters) + if param_count != 2: + raise ValueError("Jac exit function must have exactly 2 parameters.") + + # Get the entry node from the type hints. + second_param_name = list(sig.parameters.keys())[1] + exit_node = typing.get_type_hints(func).get(second_param_name, None) + + # Mark the function as jac entry. + setattr(func, "__jac_exit", exit_node) + return func + + +# ---------------------------------------------------------------------------- +# +# ---------------------------------------------------------------------------- + +spawn = _Jac.spawn_call + +def connect(node_from, node_to, edge=None, unidir=False): + return _Jac.connect( + left=node_from, + right=node_to, + edge_spec=_Jac.build_edge(is_undirected=unidir, + conn_type=edge, + conn_assign=None)) diff --git a/jac/jaclang/plugin/default.py b/jac/jaclang/plugin/default.py index 4a65eff210..3b1ca05873 100644 --- a/jac/jaclang/plugin/default.py +++ b/jac/jaclang/plugin/default.py @@ -591,11 +591,12 @@ def make_architype( if not hasattr(cls, "_jac_entry_funcs_") or not hasattr( cls, "_jac_exit_funcs_" ): - # Saving the module path and reassign it after creating cls - # So the jac modules are part of the correct module - cur_module = cls.__module__ - cls = type(cls.__name__, (cls, arch_base), {}) - cls.__module__ = cur_module + bases = ( + (cls.__bases__ + (arch_base,)) + if arch_base not in cls.__bases__ + else cls.__bases__ + ) + cls.__bases__ = bases cls._jac_entry_funcs_ = on_entry # type: ignore cls._jac_exit_funcs_ = on_exit # type: ignore else: diff --git a/jac/main.py b/jac/main.py new file mode 100644 index 0000000000..99ca37e763 --- /dev/null +++ b/jac/main.py @@ -0,0 +1,118 @@ + +from __future__ import annotations +from jaclang.lib import * + + + +from jaclang.plugin.feature import JacFeature as _Jac +from jaclang.plugin.builtin import * +from dataclasses import dataclass as __jac_dataclass__ + +""" + +class Foo(JacObj): + a1: int = 1 + a2: int = 2 + + def do_something(self): + print("Doing something") + +class MyNode(JacNode): + pass + +class MyWalker(JacWalker): + + @with_entry + def do_entry(self, here): + print("here = ", here) + +a_walker = MyWalker() +a_node = MyNode() + + +a_walker.spawn(a_node) +spawn(a_walker, a_node) + +# _Jac.spawn_call(w, n) + +# @_Jac.make_walker(on_entry=[_Jac.DSFunc('do_entry', MyNode)], on_exit=[]) +# @__jac_dataclass__(eq=False) +# class MyWalker(_Jac.Walker): + +# def do_entry(self, _jac_here_: MyNode) -> None: +# print('here =', _jac_here_) + +# print(Foo()) +""" + + + +class MyNode(JacNode): + val: int = 0 + +class a(JacEdge): + pass + +class b(JacEdge): + pass + +class c(JacEdge): + pass + +Start = MyNode(5) + +root = _Jac.get_root() + +connect(root, Start, a) # _Jac.connect(left=_Jac.get_root(), right=Start, edge_spec=_Jac.build_edge(is_undirected=False, conn_type=a, conn_assign=None)) + +i1 = MyNode(10) +connect(Start, i1, b) +connect(i1, MyNode(15), c) + +i2 = MyNode(20) +connect(Start, i2, b) +connect(i2, MyNode(25), a) + +print(_Jac.edge_ref(_Jac.get_root(), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=None, edges_only=False)) +print(_Jac.edge_ref(_Jac.get_root(), target_obj=None, dir=_Jac.EdgeDir.IN, filter_func=None, edges_only=False)) +print(_Jac.edge_ref(_Jac.get_root(), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, a)], edges_only=False)) +print(_Jac.edge_ref(_Jac.edge_ref(_Jac.get_root(), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, a)], edges_only=False), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, b)], edges_only=False)) +print(_Jac.edge_ref(_Jac.edge_ref(_Jac.edge_ref(_Jac.get_root(), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, a)], edges_only=False), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, b)], edges_only=False), target_obj=None, dir=_Jac.EdgeDir.OUT, filter_func=lambda x: [i for i in x if isinstance(i, c)], edges_only=False)) + + +# from __future__ import annotations +# from jaclang.plugin.feature import JacFeature as _Jac +# from jaclang.plugin.builtin import * +# from dataclasses import dataclass as __jac_dataclass__ +# +# +# class Meta(type): +# +# def __new__(cls, name, bases, dct): +# cls = super().__new__(cls, name, bases, dct) +# cls = __jac_dataclass__(eq=False)(cls) +# cls = _Jac.make_node(on_entry=[], on_exit=[])(cls) +# return cls +# +# +# class Base(metaclass=Meta): +# pass +# +# +# class Foo(Base): +# pass +# +# print(Foo()) + +# @_Jac.make_node(on_entry=[], on_exit=[]) +# @__jac_dataclass__(eq=False) +# class MyNode(_Jac.Node): +# pass +# +# # Print all the parent class names of MyNode +# def print_parents(node): +# print(node.__class__.__name__) +# if hasattr(node, 'parent'): +# print_parents(node.parent) +# +# print_parents(MyNode())