Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial agent search implementation #3486

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backend/onyx/agent_search/answer_question/edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.answer_question.states import AnswerQuestionInput
from onyx.agent_search.core_state import extract_core_fields
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput


def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    """Route the incoming (sub-)question into the expanded-retrieval subgraph.

    Bundles the shared core fields plus the question into an
    ExpandedRetrievalInput and dispatches it to the
    "decomped_expanded_retrieval" node via a langgraph Send.
    """
    retrieval_input = ExpandedRetrievalInput(
        **extract_core_fields(state),
        question=state["question"],
    )
    return Send("decomped_expanded_retrieval", retrieval_input)
106 changes: 106 additions & 0 deletions backend/onyx/agent_search/answer_question/graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.answer_question.edges import send_to_expanded_retrieval
from onyx.agent_search.answer_question.nodes.answer_check import answer_check
from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation
from onyx.agent_search.answer_question.nodes.format_answer import format_answer
from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval
from onyx.agent_search.answer_question.states import AnswerQuestionInput
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)


def answer_query_graph_builder() -> StateGraph:
    """Build the answer-question subgraph.

    Flow:
        START -(send_to_expanded_retrieval)-> decomped_expanded_retrieval
        -> ingest_retrieval -> answer_generation -> answer_check
        -> format_answer -> END
    """
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    # The expanded retrieval step is itself a compiled subgraph used as a node.
    node_actions = {
        "decomped_expanded_retrieval": expanded_retrieval_graph_builder().compile(),
        "answer_check": answer_check,
        "answer_generation": answer_generation,
        "format_answer": format_answer,
        "ingest_retrieval": ingest_retrieval,
    }
    for node_name, node_action in node_actions.items():
        graph.add_node(node=node_name, action=node_action)

    ### Add edges ###

    # Entry: the conditional edge fans the input out to expanded retrieval.
    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["decomped_expanded_retrieval"],
    )
    # The remainder of the graph is a straight pipeline.
    pipeline = [
        "decomped_expanded_retrieval",
        "ingest_retrieval",
        "answer_generation",
        "answer_check",
        "format_answer",
        END,
    ]
    for start_key, end_key in zip(pipeline, pipeline[1:]):
        graph.add_edge(start_key=start_key, end_key=end_key)

    return graph


if __name__ == "__main__":
    # Ad-hoc manual test harness: runs this subgraph end-to-end against the
    # default LLMs and a live DB session. Imports are kept local so they are
    # only paid when running the module directly.
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    # One session is shared for the whole run (read-only per CoreState's note).
    with get_session_context_manager() as db_session:
        inputs = AnswerQuestionInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
            question="what can you do with onyx?",
        )
        # Stream intermediate state updates and print each as it arrives.
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            # subgraphs=True,
        ):
            print(thing)
        # output = compiled_graph.invoke(inputs)
        # print(output)
30 changes: 30 additions & 0 deletions backend/onyx/agent_search/answer_question/nodes/answer_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QACheckUpdate
from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT


def answer_check(state: AnswerQuestionState) -> QACheckUpdate:
    """Judge whether the generated answer addresses the question.

    Streams the SUB_CHECK_PROMPT (filled with the question and the answer)
    through the fast LLM and records the merged response text as the
    answer-quality verdict.
    """
    prompt_content = SUB_CHECK_PROMPT.format(
        question=state["question"],
        base_answer=state["answer"],
    )
    messages = [HumanMessage(content=prompt_content)]

    chunks = list(state["fast_llm"].stream(prompt=messages))

    # Stitch the streamed chunks back into a single message and take its text.
    merged = merge_message_runs(chunks, chunk_separator="")
    return QACheckUpdate(answer_quality=merged[0].content)
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QAGenerationUpdate
from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_docs


def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
    """Generate an answer to the (sub-)question from the retrieved documents.

    Formats the question and documents into the BASE_RAG_PROMPT, streams it
    through the fast LLM, and returns the merged response text as the answer.
    """
    question = state["question"]
    docs = state["documents"]

    # Fix: use logging instead of a bare print so the message goes through the
    # application's log handlers (and can be silenced) rather than stdout.
    logging.getLogger(__name__).debug(
        "Number of verified retrieval docs: %d", len(docs)
    )

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
        )
    ]

    fast_llm = state["fast_llm"]
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    # Merge streamed chunks into one message and extract its text content.
    answer_str = merge_message_runs(response, chunk_separator="")[0].content
    return QAGenerationUpdate(
        answer=answer_str,
    )
17 changes: 17 additions & 0 deletions backend/onyx/agent_search/answer_question/nodes/format_answer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QuestionAnswerResults


def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    """Package the finished Q/A run into the subgraph's output state.

    The single result is wrapped in a one-element list so that parallel
    invocations of this subgraph can be merged via the add-annotated
    answer_results field.
    """
    result = QuestionAnswerResults(
        question=state["question"],
        quality=state["answer_quality"],
        answer=state["answer"],
        expanded_retrieval_results=state["expanded_retrieval_results"],
        documents=state["documents"],
    )
    return AnswerQuestionOutput(answer_results=[result])
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput


def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
    """Copy the expanded-retrieval subgraph's output into this graph's state."""
    retrieval_result = state["expanded_retrieval_result"]
    return RetrievalIngestionUpdate(
        expanded_retrieval_results=retrieval_result.expanded_queries_results,
        documents=retrieval_result.all_documents,
    )
71 changes: 71 additions & 0 deletions backend/onyx/agent_search/answer_question/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from pydantic import BaseModel

from onyx.agent_search.core_state import CoreState
from onyx.agent_search.expanded_retrieval.states import QueryResult
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


### Models ###


class QuestionAnswerResults(BaseModel):
    """The full result of answering one (sub-)question in this subgraph."""

    # The question that was answered
    question: str
    # The generated answer text
    answer: str
    # LLM verdict (from answer_check) on whether the answer fits the question
    quality: str
    # Per-query retrieval results coming from the expanded retrieval subgraph
    expanded_retrieval_results: list[QueryResult]
    # The retrieved sections the answer was generated from
    documents: list[InferenceSection]


### States ###

## Update States


class QACheckUpdate(TypedDict):
    # Verdict text produced by the answer_check node
    answer_quality: str


class QAGenerationUpdate(TypedDict):
    # Answer text produced by the answer_generation node
    answer: str


class RetrievalIngestionUpdate(TypedDict):
    # Per-query results copied from the expanded retrieval subgraph's output
    expanded_retrieval_results: list[QueryResult]
    # Retrieved sections; concurrent updates merge via dedup_inference_sections
    documents: Annotated[list[InferenceSection], dedup_inference_sections]


## Graph Input State


class AnswerQuestionInput(CoreState):
    # The (sub-)question this subgraph should answer, on top of the core fields
    question: str


## Graph State


class AnswerQuestionState(
    AnswerQuestionInput,
    QAGenerationUpdate,
    QACheckUpdate,
    RetrievalIngestionUpdate,
):
    # Full internal graph state: the input plus every node's update state.
    pass


## Graph Output State


class AnswerQuestionOutput(TypedDict):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    # `add` concatenates the one-element lists produced by each parallel run
    answer_results: Annotated[list[QuestionAnswerResults], add]
29 changes: 29 additions & 0 deletions backend/onyx/agent_search/core_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import TypedDict
from typing import TypeVar

from sqlalchemy.orm import Session

from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM


class CoreState(TypedDict, total=False):
    """
    This is the core state that is shared across all subgraphs.
    """

    # The search request the agent run is answering
    search_request: SearchRequest
    # Two LLMs; the answer generation/check nodes stream through fast_llm
    primary_llm: LLM
    fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    db_session: Session


# This ensures that the state passed in extends the CoreState
T = TypeVar("T", bound=CoreState)


def extract_core_fields(state: T) -> CoreState:
    """Project *state* down to just the keys declared on CoreState."""
    core_keys = CoreState.__annotations__
    return CoreState(**{key: value for key, value in state.items() if key in core_keys})  # type: ignore
Empty file.
Empty file.
Loading