Skip to content

Commit

Permalink
fixed basic issues in async run_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
evan-danswer committed Dec 26, 2024
1 parent 20a1e14 commit 699ba81
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions backend/onyx/agent_search/run_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from collections.abc import Iterable

from langchain_core.runnables.schema import StreamEvent
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph

from onyx.agent_search.main.graph_builder import main_graph_builder
from onyx.agent_search.main.states import MainInput
from onyx.chat.answer import AnswerStream
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.chat.models import OnyxAnswerPiece
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.interfaces import LLM
Expand All @@ -28,17 +29,16 @@ def _parse_agent_event(
return ToolCallKickoff(**event["data"])
elif event_type == "tool_response":
return ToolResponse(**event["data"])
elif event_type == "answer_question_possible":
return AnswerQuestionPossibleReturn(**event["data"])
elif event_type == "on_chat_model_stream":
return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content)
return None


def _manage_async_event_streaming(
graph: StateGraph,
compiled_graph: CompiledStateGraph,
graph_input: MainInput,
) -> Iterable[StreamEvent]:
async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
compiled_graph = graph.compile()
async for event in compiled_graph.astream_events(
input=graph_input,
# indicating v2 here deserves further scrutiny
Expand Down Expand Up @@ -70,22 +70,20 @@ def _yield_async_to_sync() -> Iterable[StreamEvent]:


def run_graph(
compiled_graph: CompiledStateGraph,
search_request: SearchRequest,
primary_llm: LLM,
fast_llm: LLM,
) -> AnswerStream:
graph = main_graph_builder()

with get_session_context_manager() as db_session:
input = MainInput(
search_request=search_request,
primary_llm=primary_llm,
fast_llm=fast_llm,
db_session=db_session,
)
compiled_graph = graph.compile()
for event in _manage_async_event_streaming(
graph=compiled_graph, graph_input=input
compiled_graph=compiled_graph, graph_input=input
):
if parsed_object := _parse_agent_event(event):
yield parsed_object
Expand All @@ -101,5 +99,5 @@ def run_graph(
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
for output in run_graph(search_request, primary_llm, fast_llm):
for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm):
print(output)

0 comments on commit 699ba81

Please sign in to comment.