Skip to content

Commit

Permalink
async streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Dec 23, 2024
1 parent ce16a88 commit 20a1e14
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 126 deletions.
74 changes: 61 additions & 13 deletions backend/onyx/agent_search/run_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Any
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable

from langchain_core.runnables.schema import StreamEvent
from langgraph.graph import StateGraph

from onyx.agent_search.main.graph_builder import main_graph_builder
from onyx.agent_search.main.states import MainInput
Expand All @@ -11,12 +16,57 @@
from onyx.tools.tool_runner import ToolCallKickoff


def _parse_agent_output(
output: dict[str, Any] | Any
) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse:
if isinstance(output, dict):
return output
return output.model_dump()
def _parse_agent_event(
event: StreamEvent,
) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
if event_type == "tool_call_kickoff":
return ToolCallKickoff(**event["data"])
elif event_type == "tool_response":
return ToolResponse(**event["data"])
elif event_type == "answer_question_possible":
return AnswerQuestionPossibleReturn(**event["data"])
return None


def _manage_async_event_streaming(
graph: StateGraph,
graph_input: MainInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
compiled_graph = graph.compile()
async for event in compiled_graph.astream_events(
input=graph_input,
# indicating v2 here deserves further scrutiny
version="v2",
):
yield event

# This might be able to be simplified
def _yield_async_to_sync() -> Iterable[StreamEvent]:
loop = asyncio.new_event_loop()
try:
# Get the async generator
async_gen = _run_async_event_stream()
# Convert to AsyncIterator
async_iter = async_gen.__aiter__()
while True:
try:
# Create a coroutine by calling anext with the async iterator
next_coro = anext(async_iter)
# Run the coroutine to get the next event
event = loop.run_until_complete(next_coro)
yield event
except StopAsyncIteration:
break
finally:
loop.close()

return _yield_async_to_sync()


def run_graph(
Expand All @@ -34,13 +84,11 @@ def run_graph(
db_session=db_session,
)
compiled_graph = graph.compile()
for output in compiled_graph.stream(
input=input,
stream_mode="values",
subgraphs=True,
for event in _manage_async_event_streaming(
graph=compiled_graph, graph_input=input
):
parsed_object = _parse_agent_output(output)
yield parsed_object
if parsed_object := _parse_agent_event(event):
yield parsed_object


if __name__ == "__main__":
Expand Down
113 changes: 0 additions & 113 deletions main.py

This file was deleted.

0 comments on commit 20a1e14

Please sign in to comment.