Parse streaming objects from the agentic search #3521

Open · wants to merge 3 commits into base: search_2_0_initial_test

39 changes: 34 additions & 5 deletions backend/onyx/agent_search/main/graph_builder.py
@@ -93,6 +93,35 @@ def main_graph_builder() -> StateGraph:
     return graph
 
 
+# from langgraph_sdk import get_client
+
+# client = get_client(url="http://localhost:8000")
+# # Using the graph deployed with the name "agent"
+# assistant_id = "agent"
+# # create thread
+# thread = await client.threads.create()
+# print(thread)
+# # create input
+# input = {
+#     "messages": [
+#         {
+#             "role": "user",
+#             "content": "What's the weather in SF?",
+#         }
+#     ]
+# }
+
+# # stream events
+# async for chunk in client.runs.stream(
+#     thread_id=thread["thread_id"],
+#     assistant_id=assistant_id,
+#     input=input,
+#     stream_mode="events",
+# ):
+#     print(f"Receiving new event of type: {chunk.event}...")
+#     print(chunk.data)
+#     print("\n\n")
+
 if __name__ == "__main__":
     from onyx.db.engine import get_session_context_manager
     from onyx.llm.factory import get_default_llms
@@ -113,9 +142,9 @@ def main_graph_builder() -> StateGraph:
     )
     for thing in compiled_graph.stream(
         input=inputs,
-        # stream_mode="debug",
-        # debug=True,
-        subgraphs=True,
+        stream_mode="messages",
+        # subgraphs=True,
     ):
-        # print(thing)
-        print()
+        print(thing)
+        # print()
     print("done")
110 changes: 93 additions & 17 deletions backend/onyx/agent_search/run_graph.py
@@ -1,27 +1,103 @@
+import asyncio
+from collections.abc import AsyncIterable
+from collections.abc import Iterable
+
+from langchain_core.runnables.schema import StreamEvent
+from langgraph.graph.state import CompiledStateGraph
+
 from onyx.agent_search.main.graph_builder import main_graph_builder
+from onyx.agent_search.main.states import MainInput
 from onyx.chat.answer import AnswerStream
+from onyx.chat.models import AnswerQuestionPossibleReturn
+from onyx.chat.models import OnyxAnswerPiece
 from onyx.context.search.models import SearchRequest
+from onyx.db.engine import get_session_context_manager
 from onyx.llm.interfaces import LLM
-from onyx.tools.tool import Tool
+from onyx.tools.models import ToolResponse
+from onyx.tools.tool_runner import ToolCallKickoff
+
+
+def _parse_agent_event(
+    event: StreamEvent,
+) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None:
+    """
+    Parse the event into a typed object.
+    Return None if we are not interested in the event.
+    """
+    event_type = event["event"]
+    if event_type == "tool_call_kickoff":
+        return ToolCallKickoff(**event["data"])
+    elif event_type == "tool_response":
+        return ToolResponse(**event["data"])
+    elif event_type == "on_chat_model_stream":
+        return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content)
+    return None
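+
+# Note: "tool_call_kickoff" / "tool_response" appear to be custom events emitted
+# from within the graph nodes, while "on_chat_model_stream" is a standard
+# astream_events(version="v2") event whose data["chunk"] is an AIMessageChunk.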
+
+
+def _manage_async_event_streaming(
+    compiled_graph: CompiledStateGraph,
+    graph_input: MainInput,
+) -> Iterable[StreamEvent]:
+    async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
+        async for event in compiled_graph.astream_events(
+            input=graph_input,
+            # indicating v2 here deserves further scrutiny
+            version="v2",
+        ):
+            yield event
+
+    # This might be able to be simplified
+    def _yield_async_to_sync() -> Iterable[StreamEvent]:
+        loop = asyncio.new_event_loop()
+        try:
+            # Get the async generator
+            async_gen = _run_async_event_stream()
+            # Convert to AsyncIterator
+            async_iter = async_gen.__aiter__()
+            while True:
+                try:
+                    # Create a coroutine by calling anext with the async iterator
+                    next_coro = anext(async_iter)
+                    # Run the coroutine to get the next event
+                    event = loop.run_until_complete(next_coro)
+                    yield event
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.close()
+
+    return _yield_async_to_sync()
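+
+# Note: _yield_async_to_sync bridges the async event stream into a plain
+# synchronous iterator, so callers can consume graph events without managing
+# an event loop themselves; the anext() builtin requires Python 3.10+.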
 
 
 def run_graph(
-    query: str,
-    llm: LLM,
-    tools: list[Tool],
+    compiled_graph: CompiledStateGraph,
+    search_request: SearchRequest,
+    primary_llm: LLM,
+    fast_llm: LLM,
 ) -> AnswerStream:
-    graph = main_graph_builder()
-
-    inputs = {
-        "original_query": query,
-        "messages": [],
-        "tools": tools,
-        "llm": llm,
-    }
-    compiled_graph = graph.compile()
-    output = compiled_graph.invoke(input=inputs)
-    yield from output
+    with get_session_context_manager() as db_session:
+        input = MainInput(
+            search_request=search_request,
+            primary_llm=primary_llm,
+            fast_llm=fast_llm,
+            db_session=db_session,
+        )
+        for event in _manage_async_event_streaming(
+            compiled_graph=compiled_graph, graph_input=input
+        ):
+            if parsed_object := _parse_agent_event(event):
+                yield parsed_object
 
 
 if __name__ == "__main__":
-    pass
-    # run_graph("What is the capital of France?", llm, [])
+    from onyx.llm.factory import get_default_llms
+    from onyx.context.search.models import SearchRequest
+
+    graph = main_graph_builder()
+    compiled_graph = graph.compile()
+    primary_llm, fast_llm = get_default_llms()
+    search_request = SearchRequest(
+        query="what can you do with onyx or danswer?",
+    )
+    for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm):
+        print(output)
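
As the inline comment concedes, the async-to-sync bridge in
_manage_async_event_streaming might be simplified. A minimal sketch of one
equivalent formulation (an illustration, not part of this diff; anext() can be
called on the async generator directly, without an explicit __aiter__):

    def _yield_async_to_sync() -> Iterable[StreamEvent]:
        loop = asyncio.new_event_loop()
        try:
            async_iter = _run_async_event_stream()
            while True:
                try:
                    # anext() returns a coroutine; running it to completion
                    # pulls the next event out of the async generator.
                    yield loop.run_until_complete(anext(async_iter))
                except StopAsyncIteration:
                    break
        finally:
            loop.close()
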
8 changes: 0 additions & 8 deletions backend/onyx/chat/models.py
@@ -1,5 +1,3 @@
-from collections.abc import Callable
-from collections.abc import Iterator
 from datetime import datetime
 from enum import Enum
 from typing import Any
@@ -215,17 +213,11 @@ class PersonaOverrideConfig(BaseModel):
     )
 
 
-AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
-
-
 class LLMMetricsContainer(BaseModel):
     prompt_tokens: int
     response_tokens: int
 
 
-StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
-
-
 class DocumentPruningConfig(BaseModel):
     max_chunks: int | None = None
     max_window_percentage: float | None = None