From 20a1e144f6a444a06478a27358f3145466a7be6c Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Mon, 23 Dec 2024 10:40:24 -0800
Subject: [PATCH] async streaming

---
 backend/onyx/agent_search/run_graph.py |  74 +++++++++++++---
 main.py                                | 113 -------------------------
 2 files changed, 61 insertions(+), 126 deletions(-)
 delete mode 100644 main.py

diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py
index 8bf9a5d8a90..20373baa81f 100644
--- a/backend/onyx/agent_search/run_graph.py
+++ b/backend/onyx/agent_search/run_graph.py
@@ -1,4 +1,9 @@
-from typing import Any
+import asyncio
+from collections.abc import AsyncIterable
+from collections.abc import Iterable
+
+from langchain_core.runnables.schema import StreamEvent
+from langgraph.graph import StateGraph
 
 from onyx.agent_search.main.graph_builder import main_graph_builder
 from onyx.agent_search.main.states import MainInput
@@ -11,12 +16,57 @@
 from onyx.tools.tool_runner import ToolCallKickoff
 
 
-def _parse_agent_output(
-    output: dict[str, Any] | Any
-) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse:
-    if isinstance(output, dict):
-        return output
-    return output.model_dump()
+def _parse_agent_event(
+    event: StreamEvent,
+) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None:
+    """
+    Parse the event into a typed object.
+    Return None if we are not interested in the event.
+    """
+    event_type = event["event"]
+    if event_type == "tool_call_kickoff":
+        return ToolCallKickoff(**event["data"])
+    elif event_type == "tool_response":
+        return ToolResponse(**event["data"])
+    elif event_type == "answer_question_possible":
+        return AnswerQuestionPossibleReturn(**event["data"])
+    return None
+
+
+def _manage_async_event_streaming(
+    graph: StateGraph,
+    graph_input: MainInput,
+) -> Iterable[StreamEvent]:
+    async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
+        compiled_graph = graph.compile()
+        async for event in compiled_graph.astream_events(
+            input=graph_input,
+            # indicating v2 here deserves further scrutiny
+            version="v2",
+        ):
+            yield event
+
+    # This might be able to be simplified
+    def _yield_async_to_sync() -> Iterable[StreamEvent]:
+        loop = asyncio.new_event_loop()
+        try:
+            # Get the async generator
+            async_gen = _run_async_event_stream()
+            # Convert to an AsyncIterator
+            async_iter = async_gen.__aiter__()
+            while True:
+                try:
+                    # anext() returns a coroutine that resolves to the next event
+                    next_coro = anext(async_iter)
+                    # Run the coroutine to get the next event
+                    event = loop.run_until_complete(next_coro)
+                    yield event
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.close()
+
+    return _yield_async_to_sync()
 
 
 def run_graph(
@@ -34,13 +84,11 @@ def run_graph(
         db_session=db_session,
     )
     compiled_graph = graph.compile()
-    for output in compiled_graph.stream(
-        input=input,
-        stream_mode="values",
-        subgraphs=True,
+    for event in _manage_async_event_streaming(
+        graph=graph, graph_input=input
     ):
-        parsed_object = _parse_agent_output(output)
-        yield parsed_object
+        if parsed_object := _parse_agent_event(event):
+            yield parsed_object
 
 
 if __name__ == "__main__":
diff --git a/main.py b/main.py
deleted file mode 100644
index fc7c0388b74..00000000000
--- a/main.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from operator import add
-from typing import Annotated
-from typing import TypedDict
-
-from langgraph.graph import END
-from langgraph.graph import START
-from langgraph.graph import StateGraph
-from langgraph.types import Send
-
-
-class MainState(TypedDict):
-    foo: Annotated[str, add]
-    bar: str
-
-
-class SubState(TypedDict):
-    foo: str
-    bar: str
-
-
-class SubStateInput(TypedDict):
-    foo: str
-    num: int
-
-
-class SubStateOutput(TypedDict):
-    foo: str
-
-
-def node_1(state: MainState):
-    print(f"node_1: {state}")
-    return {
-        "foo": " name",
-        "bar": "bar",
-    }
-
-
-def node_2(state: SubStateInput):
-    print(f"node_2: {state}")
-    return SubState(
-        foo=" more foo" + str(state["num"]),
-        bar="barty hard" + str(state["num"]),
-    )
-
-
-def node_3(state: SubState):
-    print(f"node_3: {state}")
-    return SubStateOutput(
-        foo=state["foo"] + " more foo",
-    )
-
-
-def node_4(state: SubStateOutput):
-    print(f"node_4: {state}")
-    return MainState(
-        foo="",
-    )
-    return MainState(
-        foo=state["foo"],
-    )
-
-
-def send_to_sub_graph(state: MainState):
-    return [
-        Send(
-            "sub_graph",
-            SubStateInput(
-                foo=state["foo"],
-                num=num,
-            ),
-        )
-        for num in range(3)
-    ]
-
-
-def build_sub_graph():
-    sub_graph = StateGraph(
-        state_schema=SubState,
-        input=SubStateInput,
-        output=SubStateOutput,
-    )
-    sub_graph.add_node(node="node_2", action=node_2)
-    sub_graph.add_node(node="node_3", action=node_3)
-    sub_graph.add_edge(start_key=START, end_key="node_2")
-    sub_graph.add_edge(start_key="node_2", end_key="node_3")
-    sub_graph.add_edge(start_key="node_3", end_key=END)
-    return sub_graph
-
-
-def build_main_graph():
-    graph = StateGraph(
-        state_schema=MainState,
-    )
-    graph.add_node(node="node_1", action=node_1)
-
-    sub_graph = build_sub_graph().compile()
-    graph.add_node(node="sub_graph", action=sub_graph)
-    graph.add_node(node="node_4", action=node_4)
-    graph.add_edge(start_key=START, end_key="node_1")
-    # graph.add_edge(start_key="node_1", end_key="sub_graph")
-    graph.add_conditional_edges(source="node_1", path=send_to_sub_graph)
-    graph.add_edge(start_key="sub_graph", end_key="node_4")
-    graph.add_edge(start_key="node_4", end_key=END)
-    return graph
-
-
-graph = build_main_graph().compile()
-output = graph.invoke(
-    {
-        "foo": "",
-    },
-)
-print(output)
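
Reviewer note on the async-to-sync bridge in _yield_async_to_sync: the pattern can be exercised in isolation with the minimal sketch below. It assumes Python 3.10+ (for the anext() builtin) and that no event loop is already running in the calling thread; fake_event_stream and iter_events_sync are hypothetical stand-ins for compiled_graph.astream_events(..., version="v2"), not names from the Onyx codebase.

import asyncio
from collections.abc import AsyncIterator
from collections.abc import Iterator


async def fake_event_stream() -> AsyncIterator[dict]:
    # Hypothetical stand-in for compiled_graph.astream_events(..., version="v2").
    for i in range(3):
        await asyncio.sleep(0)  # pretend to wait on the graph
        yield {"event": "tool_response", "data": {"step": i}}


def iter_events_sync(events: AsyncIterator[dict]) -> Iterator[dict]:
    # Same idea as _yield_async_to_sync: drive the async iterator from sync code
    # by running each anext() coroutine to completion on a private event loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(anext(events))
            except StopAsyncIteration:
                break
    finally:
        loop.close()


if __name__ == "__main__":
    for event in iter_events_sync(fake_event_stream()):
        print(event)

If the caller is itself inside a running event loop, run_until_complete raises RuntimeError; a background thread feeding a queue, or asyncio.Runner on Python 3.11+, are common alternatives in that situation.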