-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add node write cache support and implement for Postgres checkpointer #2786
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks really good!
@@ -180,6 +180,69 @@ def list( | |||
self._load_writes(value["pending_writes"]), | |||
) | |||
|
|||
def get_writes(self, task_id: str, ttl: int = None) -> List[Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should accept a list of task ids ideally, as when we do these queries we might have multiple tasks to get cached writes for (and will be more efficient to run a single query for all)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
proc = processes[packet.node] | ||
|
||
if proc.cache_policy: | ||
cache_key = proc.cache_policy.cache_key(packet.arg, config) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nit: might be nice to let users also write cache key functions that accept only the first arg (in which case here we'd conditionally call with 1 or 2 args depending on signature
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
if self.exists_cached_node and self.checkpointer: | ||
for tid, task in self.tasks.items(): | ||
if task.cache_policy: | ||
cached_writes = self.checkpointer.get_writes(tid, task.cache_policy.ttl) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this will actually need to be slightly more complicated (possibly out of scope for this pr) in reality as when using AsyncPregelLoop this call cannot block
libs/langgraph/tests/test_cache.py
Outdated
"stop_condition": input["stop_condition"] + 1, | ||
"dependent_field_1": (input["dependent_field_1"] + 1) % 2, | ||
"dependent_field_2": (input["dependent_field_2"] + 1) % 4, | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit formatting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
… used as filter in postgres query. formatting. also get_writes now finds writes for list of task_ids.
Hey @mhk197 would this new feature mean that I can decide which intermidiete super-steps would be persisted and which would be cached? |
Summary:
Added caching support for individual nodes, and implemented in Postgres. Added tests for caching.
Details:
cache
parameter toadd_node
method inlanggraph/graph/state.py
. Takes aCachePolicy
object.CachePolicy
type inlanggraph/types.py
with two parameters: requiredcache_key
and optionalttl
(in seconds).cache_key
is a function that generates a hash based on (1) the current state of the graph during runtime and (2) the config.task_id
based on thecache_key
function for nodes that have a definedCachePolicy
, inlanggraph/pregel/algo.py
.cache_key
will generate a hash based on the current input and config.CachePolicy
inlanggraph/pregel/loop.py
.get_writes(task_id, ttl)
method to base checkpointer incheckpoint/langgraph/checkpoint/base/__init__.py
.get_writes()
method for Postgres checkpointer incheckpoint-postgres/langgraph/checkpoint/postgres/__init__.py
.langgraph/tests/test_cache.py
.