diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py
index dc09000435c..536a79e1e8c 100644
--- a/backend/onyx/agent_search/main/graph_builder.py
+++ b/backend/onyx/agent_search/main/graph_builder.py
@@ -93,6 +93,35 @@ def main_graph_builder() -> StateGraph:
     return graph
 
 
+# from langgraph_sdk import get_client
+
+# client = get_client(url="http://localhost:8000")
+# # Using the graph deployed with the name "agent"
+# assistant_id = "agent"
+# # create thread
+# thread = await client.threads.create()
+# print(thread)
+# # create input
+# input = {
+#     "messages": [
+#         {
+#             "role": "user",
+#             "content": "What's the weather in SF?",
+#         }
+#     ]
+# }
+
+# # stream events
+# async for chunk in client.runs.stream(
+#     thread_id=thread["thread_id"],
+#     assistant_id=assistant_id,
+#     input=input,
+#     stream_mode="events",
+# ):
+#     print(f"Receiving new event of type: {chunk.event}...")
+#     print(chunk.data)
+#     print("\n\n")
+
 if __name__ == "__main__":
     from onyx.db.engine import get_session_context_manager
     from onyx.llm.factory import get_default_llms
@@ -113,9 +142,9 @@ def main_graph_builder() -> StateGraph:
         )
         for thing in compiled_graph.stream(
             input=inputs,
-            # stream_mode="debug",
-            # debug=True,
-            subgraphs=True,
+            stream_mode="messages",
+            # subgraphs=True,
         ):
-            # print(thing)
-            print()
+            print(thing)
+            # print()
+            print("done")
diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py
index 9a93dbba646..eed5401c51a 100644
--- a/backend/onyx/agent_search/run_graph.py
+++ b/backend/onyx/agent_search/run_graph.py
@@ -1,27 +1,103 @@
+import asyncio
+from collections.abc import AsyncIterable
+from collections.abc import Iterable
+
+from langchain_core.runnables.schema import StreamEvent
+from langgraph.graph.state import CompiledStateGraph
+
 from onyx.agent_search.main.graph_builder import main_graph_builder
+from onyx.agent_search.main.states import MainInput
 from onyx.chat.answer import AnswerStream
+from onyx.chat.models import AnswerQuestionPossibleReturn
+from onyx.chat.models import OnyxAnswerPiece
+from onyx.context.search.models import SearchRequest
+from onyx.db.engine import get_session_context_manager
 from onyx.llm.interfaces import LLM
-from onyx.tools.tool import Tool
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool_runner import ToolCallKickoff
+
+
+def _parse_agent_event(
+    event: StreamEvent,
+) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None:
+    """
+    Parse the event into a typed object.
+    Return None if we are not interested in the event.
+    """
+    event_type = event["event"]
+    if event_type == "tool_call_kickoff":
+        return ToolCallKickoff(**event["data"])
+    elif event_type == "tool_response":
+        return ToolResponse(**event["data"])
+    elif event_type == "on_chat_model_stream":
+        return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content)
+    return None
+
+
+def _manage_async_event_streaming(
+    compiled_graph: CompiledStateGraph,
+    graph_input: MainInput,
+) -> Iterable[StreamEvent]:
+    async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
+        async for event in compiled_graph.astream_events(
+            input=graph_input,
+            # indicating v2 here deserves further scrutiny
+            version="v2",
+        ):
+            yield event
+
+    # This might be able to be simplified
+    def _yield_async_to_sync() -> Iterable[StreamEvent]:
+        loop = asyncio.new_event_loop()
+        try:
+            # Get the async generator
+            async_gen = _run_async_event_stream()
+            # Convert to AsyncIterator
+            async_iter = async_gen.__aiter__()
+            while True:
+                try:
+                    # Create a coroutine by calling anext with the async iterator
+                    next_coro = anext(async_iter)
+                    # Run the coroutine to get the next event
+                    event = loop.run_until_complete(next_coro)
+                    yield event
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.close()
+
+    return _yield_async_to_sync()
 
 
 def run_graph(
-    query: str,
-    llm: LLM,
-    tools: list[Tool],
+    compiled_graph: CompiledStateGraph,
+    search_request: SearchRequest,
+    primary_llm: LLM,
+    fast_llm: LLM,
 ) -> AnswerStream:
-    graph = main_graph_builder()
-
-    inputs = {
-        "original_query": query,
-        "messages": [],
-        "tools": tools,
-        "llm": llm,
-    }
-    compiled_graph = graph.compile()
-    output = compiled_graph.invoke(input=inputs)
-    yield from output
+    with get_session_context_manager() as db_session:
+        input = MainInput(
+            search_request=search_request,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            db_session=db_session,
+        )
+        for event in _manage_async_event_streaming(
+            compiled_graph=compiled_graph, graph_input=input
+        ):
+            if parsed_object := _parse_agent_event(event):
+                yield parsed_object
 
 
 if __name__ == "__main__":
-    pass
-    # run_graph("What is the capital of France?", llm, [])
+    from onyx.llm.factory import get_default_llms
+    from onyx.context.search.models import SearchRequest
+
+    graph = main_graph_builder()
+    compiled_graph = graph.compile()
+    primary_llm, fast_llm = get_default_llms()
+    search_request = SearchRequest(
+        query="what can you do with onyx or danswer?",
+    )
+    for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm):
+        print(output)
diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py
index 44973446f5b..84633c3f49d 100644
--- a/backend/onyx/chat/models.py
+++ b/backend/onyx/chat/models.py
@@ -1,5 +1,3 @@
-from collections.abc import Callable
-from collections.abc import Iterator
 from datetime import datetime
 from enum import Enum
 from typing import Any
@@ -215,17 +213,11 @@ class PersonaOverrideConfig(BaseModel):
     )
 
 
-AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
-
-
class LLMMetricsContainer(BaseModel):
     prompt_tokens: int
     response_tokens: int
 
 
-StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
-
-
 class DocumentPruningConfig(BaseModel):
     max_chunks: int | None = None
     max_window_percentage: float | None = None