Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
[MINOR]: Address PR Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amadolid committed Jul 18, 2024
1 parent bf54383 commit 410bebb
Show file tree
Hide file tree
Showing 16 changed files with 52 additions and 1,620 deletions.
81 changes: 38 additions & 43 deletions jaclang/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import ast as ast3
import importlib
import inspect
import marshal
import os
import pickle
Expand All @@ -18,7 +17,7 @@
from jaclang.compiler.passes.main.pyast_load_pass import PyastBuildPass
from jaclang.compiler.passes.main.schedules import py_code_gen_typed
from jaclang.compiler.passes.tool.schedules import format_pass
from jaclang.core.constructs import Anchor, NodeAnchor
from jaclang.core.constructs import Anchor, NodeAnchor, WalkerArchitype
from jaclang.plugin.builtin import dotgen
from jaclang.plugin.feature import JacCmd as Cmd
from jaclang.plugin.feature import JacFeature as Jac
Expand Down Expand Up @@ -66,13 +65,7 @@ def format_file(filename: str) -> None:

@cmd_registry.register
def run(
filename: str,
session: str = "",
main: bool = True,
cache: bool = True,
walker: str = "",
node: str = "",
root: str = "",
filename: str, session: str = "", main: bool = True, cache: bool = True
) -> None:
"""Run the specified .jac file."""
# if no session specified, check if it was defined when starting the command shell
Expand All @@ -86,58 +79,32 @@ def run(
else ""
)

jctx = Jac.context(
{
"session": session,
"root": NodeAnchor.ref(root),
"entry": NodeAnchor.ref(node),
}
)
jctx = Jac.context({"session": session})

base, mod = os.path.split(filename)
base = base if base else "./"
mod = mod[:-4]
if filename.endswith(".jac"):
ret_module = jac_import(
jac_import(
target=mod,
base_path=base,
cachable=cache,
override_name="__main__" if main else None,
)
if ret_module is None:
loaded_mod = None
else:
(loaded_mod,) = ret_module
elif filename.endswith(".jir"):
with open(filename, "rb") as f:
ir = pickle.load(f)
ret_module = jac_import(
jac_import(
target=mod,
base_path=base,
cachable=cache,
override_name="__main__" if main else None,
mod_bundle=ir,
)
if ret_module is None:
loaded_mod = None
else:
(loaded_mod,) = ret_module
else:
print("Not a .jac file.")
return

# TODO: handle no override name
if walker:
walker_module = dict(inspect.getmembers(loaded_mod)).get(walker)
if (
walker_module
and jctx.validate_access()
and (architype := jctx.entry.architype)
):
Jac.spawn_call(architype, walker_module())
else:
print(f"Walker {walker} not found.")

jctx.close()


Expand Down Expand Up @@ -207,13 +174,29 @@ def lsp() -> None:


@cmd_registry.register
def enter(filename: str, entrypoint: str, args: list) -> None:
"""Run the specified entrypoint function in the given .jac file.
def enter(
filename: str,
session: str = "",
walker: str = "",
node: str = "",
root: str = "",
args: Optional[list] = None,
) -> None:
"""
Run the specified entrypoint function in the given .jac file.
:param filename: The path to the .jac file.
:param entrypoint: The name of the entrypoint function.
:param args: Arguments to pass to the entrypoint function.
"""
jctx = Jac.context(
{
"session": session,
"root": NodeAnchor.ref(root),
"entry": NodeAnchor.ref(node),
}
)

if filename.endswith(".jac"):
base, mod_name = os.path.split(filename)
base = base if base else "./"
Expand All @@ -223,10 +206,24 @@ def enter(filename: str, entrypoint: str, args: list) -> None:
print("Errors occurred while importing the module.")
return
else:
getattr(mod, entrypoint)(*args)
walker_architype: WalkerArchitype | None = getattr(mod[0], walker)(
*args or []
)
if (
walker_architype
and jctx.validate_access()
and (architype := jctx.entry.architype)
):
Jac.spawn_call(architype, walker_architype)
else:
print(f"Invalid Walker {walker} execution.")

getattr(mod[0], walker)(*args or [])
else:
print("Not a .jac file.")

jctx.close()


@cmd_registry.register
def test(
Expand Down Expand Up @@ -444,8 +441,6 @@ def start_cli() -> None:
args_dict = vars(args)
args_dict.pop("command")
args_dict.pop("version", None)
if command.func.__name__ != "run":
args_dict.pop("session")
ret = command.call(**args_dict)
if ret:
print(ret)
Expand Down
21 changes: 11 additions & 10 deletions jaclang/core/architype.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from jaclang.compiler.constant import EdgeDir
from jaclang.core.utils import collect_node_connections
from jaclang.vendor.orjson import dumps

from orjson import dumps

GENERIC_ID_REGEX = compile(
r"^(g|n|e|w):([^:]*):([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})$",
Expand Down Expand Up @@ -196,7 +197,7 @@ def sync(self, node: Optional["NodeAnchor"] = None) -> Optional[Architype]:

from .context import ExecutionContext

jsrc = ExecutionContext.get().datasource
jsrc = ExecutionContext.get_or_create().datasource
anchor = jsrc.find_one(self.id)

if anchor and (node or self).has_read_access(anchor):
Expand All @@ -208,7 +209,7 @@ def allocate(self) -> None:
"""Allocate hashes and memory."""
from .context import ExecutionContext

jctx = ExecutionContext.get()
jctx = ExecutionContext.get_or_create()
self.root = jctx.root.id
jctx.datasource.set(self, True)

Expand All @@ -228,7 +229,7 @@ def access_level(self, to: Anchor) -> int:
"""Access validation."""
from .context import ExecutionContext

jctx = ExecutionContext.get()
jctx = ExecutionContext.get_or_create()
jroot = jctx.root
to.current_access_level = -1

Expand Down Expand Up @@ -347,7 +348,7 @@ def ref(cls, ref_id: str) -> Optional[NodeAnchor]:
def _save(self) -> None:
from .context import ExecutionContext

jsrc = ExecutionContext.get().datasource
jsrc = ExecutionContext.get_or_create().datasource

for edge in self.edges:
edge.save()
Expand All @@ -359,7 +360,7 @@ def destroy(self) -> None:
if self.architype and self.current_access_level > 1:
from .context import ExecutionContext

jsrc = ExecutionContext.get().datasource
jsrc = ExecutionContext.get_or_create().datasource
for edge in self.edges:
edge.destroy()

Expand Down Expand Up @@ -507,7 +508,7 @@ def ref(cls, ref_id: str) -> Optional[EdgeAnchor]:
def _save(self) -> None:
from .context import ExecutionContext

jsrc = ExecutionContext.get().datasource
jsrc = ExecutionContext.get_or_create().datasource

if source := self.source:
source.save()
Expand All @@ -522,7 +523,7 @@ def destroy(self) -> None:
if self.architype and self.current_access_level == 1:
from .context import ExecutionContext

jsrc = ExecutionContext.get().datasource
jsrc = ExecutionContext.get_or_create().datasource

source = self.source
target = self.target
Expand Down Expand Up @@ -601,14 +602,14 @@ def ref(cls, ref_id: str) -> Optional[WalkerAnchor]:
def _save(self) -> None:
from .context import ExecutionContext

ExecutionContext.get().datasource.set(self)
ExecutionContext.get_or_create().datasource.set(self)

def destroy(self) -> None:
"""Delete Anchor."""
if self.architype and self.current_access_level > 1:
from .context import ExecutionContext

ExecutionContext.get().datasource.remove(self)
ExecutionContext.get_or_create().datasource.remove(self)

def sync(self, node: Optional["NodeAnchor"] = None) -> Optional[WalkerArchitype]:
"""Retrieve the Architype from db and return."""
Expand Down
2 changes: 1 addition & 1 deletion jaclang/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def validate_access(self) -> bool:
return self.root.has_read_access(self.entry)

@staticmethod
def get(options: Optional[dict[str, Any]] = None) -> ExecutionContext:
def get_or_create(options: Optional[dict[str, Any]] = None) -> ExecutionContext:
"""Get or create execution context."""
if not isinstance(ctx := EXECUTION_CONTEXT.get(None), ExecutionContext):
EXECUTION_CONTEXT.set(ctx := ExecutionContext(**options or {}))
Expand Down
2 changes: 1 addition & 1 deletion jaclang/plugin/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class JacFeatureDefaults:
@hookimpl
def context(options: Optional[dict[str, Any]]) -> ExecutionContext:
"""Get the execution context."""
return ExecutionContext.get(options)
return ExecutionContext.get_or_create(options)

@staticmethod
@hookimpl
Expand Down
1 change: 0 additions & 1 deletion jaclang/vendor/orjson-3.10.6.dist-info/INSTALLER

This file was deleted.

Loading

0 comments on commit 410bebb

Please sign in to comment.