From 699ba81531eaa7ddb062f0c43fc248643e479b9a Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 26 Dec 2024 11:29:07 -0800 Subject: [PATCH] fixed basic issues in async run_graph --- backend/onyx/agent_search/run_graph.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 20373baa81f..eed5401c51a 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -3,12 +3,13 @@ from collections.abc import Iterable from langchain_core.runnables.schema import StreamEvent -from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph from onyx.agent_search.main.graph_builder import main_graph_builder from onyx.agent_search.main.states import MainInput from onyx.chat.answer import AnswerStream from onyx.chat.models import AnswerQuestionPossibleReturn +from onyx.chat.models import OnyxAnswerPiece from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager from onyx.llm.interfaces import LLM @@ -28,17 +29,16 @@ def _parse_agent_event( return ToolCallKickoff(**event["data"]) elif event_type == "tool_response": return ToolResponse(**event["data"]) - elif event_type == "answer_question_possible": - return AnswerQuestionPossibleReturn(**event["data"]) + elif event_type == "on_chat_model_stream": + return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content) return None def _manage_async_event_streaming( - graph: StateGraph, + compiled_graph: CompiledStateGraph, graph_input: MainInput, ) -> Iterable[StreamEvent]: async def _run_async_event_stream() -> AsyncIterable[StreamEvent]: - compiled_graph = graph.compile() async for event in compiled_graph.astream_events( input=graph_input, # indicating v2 here deserves further scrutiny @@ -70,12 +70,11 @@ def _yield_async_to_sync() -> Iterable[StreamEvent]: def run_graph( + compiled_graph: CompiledStateGraph, search_request: SearchRequest, primary_llm: LLM, fast_llm: LLM, ) -> AnswerStream: - graph = main_graph_builder() - with get_session_context_manager() as db_session: input = MainInput( search_request=search_request, @@ -83,9 +82,8 @@ def run_graph( fast_llm=fast_llm, db_session=db_session, ) - compiled_graph = graph.compile() for event in _manage_async_event_streaming( - graph=compiled_graph, graph_input=input + compiled_graph=compiled_graph, graph_input=input ): if parsed_object := _parse_agent_event(event): yield parsed_object @@ -101,5 +99,5 @@ def run_graph( search_request = SearchRequest( query="what can you do with onyx or danswer?", ) - for output in run_graph(search_request, primary_llm, fast_llm): + for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm): print(output)