Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node write cache support and implement for Postgres checkpointer #2786

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Sequence, Union
from typing import Any, Iterator, Optional, Sequence, Union, List, Dict

from langchain_core.runnables import RunnableConfig
from psycopg import Capabilities, Connection, Cursor, Pipeline
Expand Down Expand Up @@ -180,6 +180,38 @@ def list(
self._load_writes(value["pending_writes"]),
)

def get_writes(self, task_ids: List[str]) -> Dict[str, List[Any]]:
    """Retrieve pending writes for a list of task ids.

    Used for retrieving pending writes for cached nodes: each task id is
    usually generated from a cache key for a task deriving from a cached node.

    Args:
        task_ids: task identifiers to look up.

    Returns:
        A dict mapping each task_id that has recorded writes to its list of
        decoded pending writes, ordered by write index. Task ids with no
        recorded writes are absent from the result.
    """
    # Nothing to look up: skip the database round-trip entirely.
    if not task_ids:
        return {}

    # task_id is constant within each group, so ordering by idx alone is
    # sufficient to preserve write order.
    query = """
    SELECT
        task_id,
        array_agg(array[task_id::text::bytea, channel::bytea, type::bytea, blob] ORDER BY idx) AS pending_writes
    FROM checkpoint_writes
    WHERE task_id = ANY(%s)
    GROUP BY task_id
    """

    args = (task_ids,)

    results: Dict[str, List[Any]] = {}
    with self._cursor() as cur:
        # binary=True so blob columns round-trip without text encoding.
        cur.execute(query, args, binary=True)

        for row in cur:
            results[row["task_id"]] = self._load_writes(row["pending_writes"])
    return results
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the database.

Expand Down
15 changes: 15 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def get(self, config: RunnableConfig) -> Optional[Checkpoint]:
"""
if value := self.get_tuple(config):
return value.checkpoint


def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Fetch a checkpoint tuple using the given configuration.
Expand All @@ -252,6 +253,20 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
NotImplementedError: Implement this method in your custom checkpoint saver.
"""
raise NotImplementedError

def get_writes(self, task_id: str, ttl: Optional[int] = None) -> Optional[list[Any]]:
    """Retrieve pending writes for a task_id, optionally bounded by a ttl.

    Used for retrieving pending writes for cached nodes. If ttl is
    specified, only retrieves pending writes whose corresponding checkpoint
    timestamp is within ttl seconds of the query.

    Args:
        task_id: a task identifier, usually generated from a cache key for
            a task deriving from a cached node.
        ttl: the time to live of the cached writes, in seconds.

    Returns:
        A list of decoded pending writes for the specified task id,
        optionally recorded within ttl seconds of this query being made.

    Raises:
        NotImplementedError: Implement this method in your custom
            checkpoint saver.
    """
    raise NotImplementedError

def list(
self,
Expand Down
5 changes: 4 additions & 1 deletion libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type, List

from langchain_core.runnables import RunnableConfig

Expand Down Expand Up @@ -105,6 +105,9 @@ async def __aexit__(
__traceback: Optional[TracebackType],
) -> Optional[bool]:
return self.stack.__exit__(__exc_type, __exc_value, __traceback)

def get_writes(self, task_id: str, ttl: Optional[int] = None) -> Dict[str, List[Any]]:
    """Retrieve cached pending writes for a task id.

    The in-memory saver does not index writes by task_id alone, so no
    cached writes can be served. Returning an empty mapping (rather than
    None, as the previous ``pass`` body did) keeps callers that iterate
    the result — e.g. ``for tid, writes in result.items()`` — from
    raising AttributeError.

    Args:
        task_id: a task identifier derived from a cache key.
        ttl: accepted for interface compatibility; ignored here.

    Returns:
        An empty dict: no cached writes are available.
    """
    return {}

def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
"""Get a checkpoint tuple from the in-memory storage.
Expand Down
8 changes: 7 additions & 1 deletion libs/langgraph/langgraph/graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from langgraph.pregel.read import ChannelRead, PregelNode
from langgraph.pregel.write import SKIP_WRITE, ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy
from langgraph.types import _DC_KWARGS, All, Checkpointer, Command, N, RetryPolicy, CachePolicy
from langgraph.utils.fields import get_field_default
from langgraph.utils.pydantic import create_model
from langgraph.utils.runnable import RunnableCallable, coerce_to_runnable
Expand Down Expand Up @@ -100,6 +100,7 @@ class StateNodeSpec(NamedTuple):
metadata: Optional[dict[str, Any]]
input: Type[Any]
retry_policy: Optional[RetryPolicy]
cache_policy: Optional[CachePolicy]
ends: Optional[tuple[str, ...]] = EMPTY_SEQ


Expand Down Expand Up @@ -241,6 +242,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> Self:
"""Adds a new node to the state graph.
Will take the name of the function/runnable as the node name.
Expand All @@ -265,6 +267,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> Self:
"""Adds a new node to the state graph.

Expand All @@ -288,6 +291,7 @@ def add_node(
metadata: Optional[dict[str, Any]] = None,
input: Optional[Type[Any]] = None,
retry: Optional[RetryPolicy] = None,
cache: Optional[CachePolicy] = None,
) -> Self:
"""Adds a new node to the state graph.

Expand Down Expand Up @@ -402,6 +406,7 @@ def add_node(
metadata,
input=input or self.schema,
retry_policy=retry,
cache_policy=cache,
ends=ends,
)
return self
Expand Down Expand Up @@ -700,6 +705,7 @@ def _get_state_key(input: Union[None, dict, Any], *, key: str) -> Any:
],
metadata=node.metadata,
retry_policy=node.retry_policy,
cache_policy=node.cache_policy,
bound=node.runnable,
)
else:
Expand Down
134 changes: 107 additions & 27 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
overload,
)
from uuid import UUID
import time
import inspect

from langchain_core.callbacks.manager import AsyncParentRunManager, ParentRunManager
from langchain_core.runnables.config import RunnableConfig
Expand Down Expand Up @@ -475,19 +477,46 @@ def prepare_single_task(
f"Ignoring unknown node name {packet.node} in pending sends"
)
return

# create task id
# -- if task is enabled for caching, create task_id based on node identifier and cache key
triggers = [PUSH]
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
str(idx),
)
proc = processes[packet.node]

if proc.cache_policy:
# cache policy fn can take state and config or just state
cache_key_fn = proc.cache_policy.cache_key
if len(inspect.signature(cache_key_fn).parameters) == 2:
cache_key = proc.cache_policy.cache_key(packet.arg, config)
else:
cache_key = proc.cache_policy.cache_key(packet.arg)

if ttl := proc.cache_policy.ttl:
ttl_str = str(time.time() // ttl)
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
ttl_str
)
else:
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
str(idx),
)
elif len(task_path) == 4:
# new PUSH tasks, executed in superstep n
# (PUSH, parent task path, idx of PUSH write, id of parent task)
Expand All @@ -509,20 +538,46 @@ def prepare_single_task(
f"Ignoring unknown node name {packet.node} in pending writes"
)
return

# create task id
# -- if task is enabled for caching, create task_id based on node identifier and cache key
triggers = [PUSH]
checkpoint_ns = (
f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node
)
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
proc = processes[packet.node]
if proc.cache_policy:
# cache policy fn can take state and config or just state
cache_key_fn = proc.cache_policy.cache_key
if len(inspect.signature(cache_key_fn).parameters) == 2:
cache_key = proc.cache_policy.cache_key(packet.arg, config)
else:
cache_key = proc.cache_policy.cache_key(packet.arg)

if ttl := proc.cache_policy.ttl:
ttl_str = str(time.time() // ttl)
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
ttl_str
)
else:
task_id = _uuid5_str(
b"",
packet.node,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
packet.node,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
else:
logger.warning(f"Ignoring invalid PUSH task path {task_path}")
return
Expand Down Expand Up @@ -599,7 +654,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down Expand Up @@ -635,15 +690,40 @@ def prepare_single_task(
return

# create task id
# -- if task is enabled for caching, create task_id based on node identifier and cache key
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PULL,
*triggers,
)

if proc.cache_policy:
# cache policy fn can take state and config or just state
cache_key_fn = proc.cache_policy.cache_key
if len(inspect.signature(cache_key_fn).parameters) == 2:
cache_key = proc.cache_policy.cache_key(val, config)
else:
cache_key = proc.cache_policy.cache_key(val)

if ttl := proc.cache_policy.ttl:
ttl_str = str(time.time() // ttl)
task_id = _uuid5_str(
b"",
name,
cache_key,
ttl_str
)
else:
task_id = _uuid5_str(
b"",
name,
cache_key,
)
else:
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PULL,
*triggers,
)
task_checkpoint_ns = f"{checkpoint_ns}{NS_END}{task_id}"
metadata = {
"langgraph_step": step,
Expand Down Expand Up @@ -717,7 +797,7 @@ def prepare_single_task(
),
triggers,
proc.retry_policy,
None,
proc.cache_policy,
task_id,
task_path,
)
Expand Down
12 changes: 12 additions & 0 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ def __init__(
)
self.prev_checkpoint_config = None

self.exists_cached_node = any([node.cache_policy for node in self.nodes.values()])

def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
if not writes:
Expand Down Expand Up @@ -466,6 +468,16 @@ def tick(
"input": None,
}
)

if self.exists_cached_node and self.checkpointer:
cached_tids = [tid for tid, task in self.tasks.items() if task.cache_policy]
cached_writes_map = self.checkpointer.get_writes(cached_tids)
for tid, cached_task_writes in cached_writes_map.items():
task = self.tasks[tid]
task.writes.extend(
[(channel, value) for _, channel, value in cached_task_writes]
)
self._output_writes(tid, task.writes, cached=True)

# if there are pending writes from a previous loop, apply them
if self.skip_done_tasks and self.checkpoint_pending_writes:
Expand Down
Loading