Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEATURE]: JID #1473

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 137 additions & 170 deletions jac-cloud/jac_cloud/core/architype.py

Large diffs are not rendered by default.

26 changes: 16 additions & 10 deletions jac-cloud/jac_cloud/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
Anchor,
AnchorState,
BaseArchitype,
JID,
JacCloudJID,
NodeAnchor,
Permission,
Root,
Expand All @@ -29,8 +31,8 @@

SUPER_ROOT_ID = ObjectId("000000000000000000000000")
PUBLIC_ROOT_ID = ObjectId("000000000000000000000001")
SUPER_ROOT = NodeAnchor.ref(f"n::{SUPER_ROOT_ID}")
PUBLIC_ROOT = NodeAnchor.ref(f"n::{PUBLIC_ROOT_ID}")
SUPER_ROOT_JID = JacCloudJID[NodeAnchor](f"n::{SUPER_ROOT_ID}")
PUBLIC_ROOT_JID = JacCloudJID[NodeAnchor](f"n::{PUBLIC_ROOT_ID}")

RT = TypeVar("RT")

Expand Down Expand Up @@ -69,7 +71,7 @@ def close(self) -> None:
self.mem.close()

@staticmethod
def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext": # type: ignore[override]
def create(request: Request, entry: str | None = None) -> "JaseciContext": # type: ignore[override]
"""Create JacContext."""
ctx = JaseciContext()
ctx.base = ExecutionContext.get()
Expand All @@ -78,7 +80,9 @@ def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext"
ctx.reports = []
ctx.status = 200

if not isinstance(system_root := ctx.mem.find_by_id(SUPER_ROOT), NodeAnchor):
if not isinstance(
system_root := ctx.mem.find_by_id(SUPER_ROOT_JID), NodeAnchor
):
system_root = NodeAnchor(
architype=object.__new__(Root),
id=SUPER_ROOT_ID,
Expand All @@ -90,16 +94,16 @@ def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext"
system_root.architype.__jac__ = system_root
NodeAnchor.Collection.insert_one(system_root.serialize())
system_root.sync_hash()
ctx.mem.set(system_root.id, system_root)
ctx.mem.set(system_root.jid, system_root)

ctx.system_root = system_root

if _root := getattr(request, "_root", None):
ctx.root = _root
ctx.mem.set(_root.id, _root)
ctx.mem.set(_root.jid, _root)
else:
if not isinstance(
public_root := ctx.mem.find_by_id(PUBLIC_ROOT), NodeAnchor
public_root := ctx.mem.find_by_id(PUBLIC_ROOT_JID), NodeAnchor
):
public_root = NodeAnchor(
architype=object.__new__(Root),
Expand All @@ -110,13 +114,13 @@ def create(request: Request, entry: NodeAnchor | None = None) -> "JaseciContext"
edges=[],
)
public_root.architype.__jac__ = public_root
ctx.mem.set(public_root.id, public_root)
ctx.mem.set(public_root.jid, public_root)

ctx.root = public_root

if entry:
if not isinstance(entry_node := ctx.mem.find_by_id(entry), NodeAnchor):
raise ValueError(f"Invalid anchor id {entry.ref_id} !")
if not (entry_node := ctx.mem.find_by_id(JacCloudJID[NodeAnchor](entry))):
raise ValueError(f"Invalid anchor id {entry} !")
ctx.entry_node = entry_node
else:
ctx.entry_node = ctx.root
Expand Down Expand Up @@ -167,6 +171,8 @@ def clean_response(
case dict():
for key, dval in val.items():
self.clean_response(key, dval, val)
case JID():
cast(dict, obj)[key] = str(val)
case Anchor():
cast(dict, obj)[key] = val.report()
case BaseArchitype():
Expand Down
123 changes: 68 additions & 55 deletions jac-cloud/jac_cloud/core/memory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Memory abstraction for jaseci plugin."""

from dataclasses import dataclass
from dataclasses import dataclass, field
from os import getenv
from typing import Callable, Generator, Iterable, TypeVar, cast
from typing import Callable, Generator, Iterable, TypeVar

from bson import ObjectId

Expand All @@ -14,10 +14,9 @@
from pymongo.client_session import ClientSession

from .architype import (
Anchor,
BaseAnchor,
BulkWrite,
EdgeAnchor,
JacCloudJID,
NodeAnchor,
ObjectAnchor,
Root,
Expand All @@ -27,78 +26,91 @@

DISABLE_AUTO_CLEANUP = getenv("DISABLE_AUTO_CLEANUP") == "true"
SINGLE_QUERY = getenv("SINGLE_QUERY") == "true"
IDS = ObjectId | Iterable[ObjectId]
BA = TypeVar("BA", bound="BaseAnchor")

_ANCHOR = TypeVar("_ANCHOR", NodeAnchor, EdgeAnchor, WalkerAnchor, ObjectAnchor)


@dataclass
class MongoDB(Memory[ObjectId, BaseAnchor | Anchor]):
class MongoDB(Memory):
"""Shelf Handler."""

__mem__: dict[
JacCloudJID, NodeAnchor | EdgeAnchor | WalkerAnchor | ObjectAnchor
] = field(
default_factory=dict
) # type: ignore[assignment]
__gc__: set[JacCloudJID] = field(default_factory=set) # type: ignore[assignment]
__session__: ClientSession | None = None

def populate_data(self, edges: Iterable[EdgeAnchor]) -> None:
def populate_data(self, edges: Iterable[JacCloudJID[EdgeAnchor]]) -> None:
"""Populate data to avoid multiple query."""
if not SINGLE_QUERY:
nodes: set[NodeAnchor] = set()
nodes: set[JacCloudJID] = set()
for edge in self.find(edges):
if edge.source:
nodes.add(edge.source)
if edge.target:
nodes.add(edge.target)
nodes.add(edge.source)
nodes.add(edge.target)
self.find(nodes)

def find( # type: ignore[override]
self,
anchors: BA | Iterable[BA],
filter: Callable[[Anchor], Anchor] | None = None,
ids: JacCloudJID[_ANCHOR] | Iterable[JacCloudJID[_ANCHOR]],
filter: Callable[[_ANCHOR], _ANCHOR] | None = None,
session: ClientSession | None = None,
) -> Generator[BA, None, None]:
) -> Generator[_ANCHOR, None, None]:
"""Find anchors from datasource by ids with filter."""
if not isinstance(anchors, Iterable):
anchors = [anchors]

collections: dict[type[Collection[BaseAnchor]], list[ObjectId]] = {}
for anchor in anchors:
if anchor.id not in self.__mem__ and anchor not in self.__gc__:
coll = collections.get(anchor.Collection)
if not isinstance(ids, Iterable):
ids = [ids]

collections: dict[
type[
Collection[NodeAnchor]
| Collection[EdgeAnchor]
| Collection[WalkerAnchor]
| Collection[ObjectAnchor]
],
list[ObjectId],
] = {}
for jid in ids:
if jid not in self.__mem__ and jid not in self.__gc__:
coll = collections.get(jid.type.Collection)
if coll is None:
coll = collections[anchor.Collection] = []
coll = collections[jid.type.Collection] = []

coll.append(anchor.id)
coll.append(jid.id)

for cl, ids in collections.items():
for cl, oids in collections.items():
for anch_db in cl.find(
{
"_id": {"$in": ids},
"_id": {"$in": oids},
},
session=session or self.__session__,
):
self.__mem__[anch_db.id] = anch_db
self.__mem__[anch_db.jid] = anch_db

for anchor in anchors:
for jid in ids:
if (
anchor not in self.__gc__
and (anch_mem := self.__mem__.get(anchor.id))
and (not filter or filter(anch_mem)) # type: ignore[arg-type]
jid not in self.__gc__
and (anch_mem := self.__mem__.get(jid))
and isinstance(anch_mem, jid.type)
and (not filter or filter(anch_mem))
):
yield cast(BA, anch_mem)
yield anch_mem

def find_one( # type: ignore[override]
self,
anchors: BA | Iterable[BA],
filter: Callable[[Anchor], Anchor] | None = None,
ids: JacCloudJID[_ANCHOR] | Iterable[JacCloudJID[_ANCHOR]],
filter: Callable[[_ANCHOR], _ANCHOR] | None = None,
session: ClientSession | None = None,
) -> BA | None:
) -> _ANCHOR | None:
"""Find one anchor from memory by ids with filter."""
return next(self.find(anchors, filter, session), None)
return next(self.find(ids, filter, session), None)

def find_by_id(self, anchor: BA) -> BA | None:
def find_by_id(self, id: JacCloudJID[_ANCHOR]) -> _ANCHOR | None: # type: ignore[override]
"""Find one by id."""
data = super().find_by_id(anchor.id)
data = super().find_by_id(id)

if not data and (data := anchor.Collection.find_by_id(anchor.id)):
self.__mem__[data.id] = data
if not data and (data := id.type.Collection.find_by_id(id.id)):
self.__mem__[data.jid] = data

return data

Expand All @@ -115,7 +127,9 @@ def close(self) -> None:

super().close()

def sync_mem_to_db(self, bulk_write: BulkWrite, keys: Iterable[ObjectId]) -> None:
def sync_mem_to_db(
self, bulk_write: BulkWrite, keys: Iterable[JacCloudJID]
) -> None:
"""Manually sync memory to db."""
for key in keys:
if (
Expand Down Expand Up @@ -146,19 +160,18 @@ def sync_mem_to_db(self, bulk_write: BulkWrite, keys: Iterable[ObjectId]) -> Non
def get_bulk_write(self) -> BulkWrite:
"""Sync memory to database."""
bulk_write = BulkWrite()

for anchor in self.__gc__:
match anchor:
case NodeAnchor():
bulk_write.del_node(anchor.id)
case EdgeAnchor():
bulk_write.del_edge(anchor.id)
case WalkerAnchor():
bulk_write.del_walker(anchor.id)
case ObjectAnchor():
bulk_write.del_object(anchor.id)
case _:
pass
for jid in self.__gc__:
self.__mem__.pop(jid, None)
# match case doesn't work yet with
# type checking for type (not instance)
if jid.type is NodeAnchor:
bulk_write.del_node(jid.id)
elif jid.type is EdgeAnchor:
bulk_write.del_edge(jid.id)
elif jid.type is WalkerAnchor:
bulk_write.del_walker(jid.id)
elif jid.type is ObjectAnchor:
bulk_write.del_object(jid.id)

keys = set(self.__mem__.keys())

Expand Down
Loading
Loading