From 6311b70cc6750dd43ea7aacc5bc676de385f5aab Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Mon, 16 Dec 2024 11:23:01 -0800 Subject: [PATCH 01/19] initial onyx changes --- .../answer_query/graph_builder.py | 100 ++++ .../answer_query/nodes/answer_check.py | 30 ++ .../answer_query/nodes/answer_generation.py | 32 ++ .../answer_query/nodes/format_answer.py | 16 + .../onyx/agent_search/answer_query/states.py | 45 ++ backend/onyx/agent_search/core_state.py | 15 + .../onyx/agent_search/deep_answer/edges.py | 0 .../agent_search/deep_answer/graph_builder.py | 0 .../deep_answer/nodes/answer_generation.py | 114 +++++ .../deep_answer/nodes/deep_decomp.py | 78 ++++ .../nodes/entity_term_extraction.py | 40 ++ .../nodes/sub_qa_level_aggregator.py | 30 ++ .../deep_answer/nodes/sub_qa_manager.py | 19 + .../onyx/agent_search/deep_answer/states.py | 0 .../agent_search/expanded_retrieval/edges.py | 44 ++ .../expanded_retrieval/graph_builder.py | 88 ++++ .../expanded_retrieval/nodes/doc_reranking.py | 11 + .../expanded_retrieval/nodes/doc_retrieval.py | 47 ++ .../nodes/doc_verification.py | 60 +++ .../nodes/verification_kickoff.py | 27 ++ .../expanded_retrieval/prompts.py | 0 .../agent_search/expanded_retrieval/states.py | 36 ++ backend/onyx/agent_search/main/edges.py | 61 +++ .../onyx/agent_search/main/graph_builder.py | 98 ++++ .../agent_search/main/nodes/base_decomp.py | 31 ++ .../main/nodes/generate_initial_answer.py | 53 +++ backend/onyx/agent_search/main/states.py | 37 ++ backend/onyx/agent_search/run_graph.py | 27 ++ .../agent_search/shared_graph_utils/models.py | 12 + .../shared_graph_utils/operators.py | 9 + .../shared_graph_utils/prompts.py | 427 ++++++++++++++++++ .../agent_search/shared_graph_utils/utils.py | 101 +++++ backend/requirements/default.txt | 13 +- 33 files changed, 1697 insertions(+), 4 deletions(-) create mode 100644 backend/onyx/agent_search/answer_query/graph_builder.py create mode 100644 backend/onyx/agent_search/answer_query/nodes/answer_check.py 
create mode 100644 backend/onyx/agent_search/answer_query/nodes/answer_generation.py create mode 100644 backend/onyx/agent_search/answer_query/nodes/format_answer.py create mode 100644 backend/onyx/agent_search/answer_query/states.py create mode 100644 backend/onyx/agent_search/core_state.py create mode 100644 backend/onyx/agent_search/deep_answer/edges.py create mode 100644 backend/onyx/agent_search/deep_answer/graph_builder.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/answer_generation.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py create mode 100644 backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py create mode 100644 backend/onyx/agent_search/deep_answer/states.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/edges.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/graph_builder.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/prompts.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/states.py create mode 100644 backend/onyx/agent_search/main/edges.py create mode 100644 backend/onyx/agent_search/main/graph_builder.py create mode 100644 backend/onyx/agent_search/main/nodes/base_decomp.py create mode 100644 backend/onyx/agent_search/main/nodes/generate_initial_answer.py create mode 100644 backend/onyx/agent_search/main/states.py create mode 100644 backend/onyx/agent_search/run_graph.py create mode 100644 
backend/onyx/agent_search/shared_graph_utils/models.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/operators.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/prompts.py create mode 100644 backend/onyx/agent_search/shared_graph_utils/utils.py diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py new file mode 100644 index 00000000000..e52bfe28d69 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -0,0 +1,100 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.answer_query.nodes.answer_check import answer_check +from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation +from onyx.agent_search.answer_query.nodes.format_answer import format_answer +from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.expanded_retrieval.graph_builder import ( + expanded_retrieval_graph_builder, +) + + +def answer_query_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=AnswerQueryState, + input=AnswerQueryInput, + output=AnswerQueryOutput, + ) + + ### Add nodes ### + + expanded_retrieval = expanded_retrieval_graph_builder().compile() + graph.add_node( + node="expanded_retrieval_for_initial_decomp", + action=expanded_retrieval, + ) + graph.add_node( + node="answer_check", + action=answer_check, + ) + graph.add_node( + node="answer_generation", + action=answer_generation, + ) + graph.add_node( + node="format_answer", + action=format_answer, + ) + + ### Add edges ### + + graph.add_edge( + start_key=START, + end_key="expanded_retrieval_for_initial_decomp", + ) + graph.add_edge( + start_key="expanded_retrieval_for_initial_decomp", + 
end_key="answer_generation", + ) + graph.add_edge( + start_key="answer_generation", + end_key="answer_check", + ) + graph.add_edge( + start_key="answer_check", + end_key="format_answer", + ) + graph.add_edge( + start_key="format_answer", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = answer_query_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = AnswerQueryInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_answer="Who made Excel?", + ) + output = compiled_graph.invoke( + input=inputs, + # debug=True, + # subgraphs=True, + ) + print(output) + # for namespace, chunk in compiled_graph.stream( + # input=inputs, + # # debug=True, + # subgraphs=True, + # ): + # print(namespace) + # print(chunk) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py new file mode 100644 index 00000000000..8b58129c47b --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -0,0 +1,30 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import QACheckOutput +from onyx.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT + + +def answer_check(state: AnswerQueryState) -> QACheckOutput: + msg = [ + HumanMessage( + content=BASE_CHECK_PROMPT.format( + question=state["search_request"].query, + base_answer=state["answer"], + ) + ) + ] + + 
fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + response_str = merge_message_runs(response, chunk_separator="")[0].content + + return QACheckOutput( + answer_quality=response_str, + ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py new file mode 100644 index 00000000000..c23f77ee706 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -0,0 +1,32 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import QAGenerationOutput +from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: + query = state["query_to_answer"] + docs = state["reordered_documents"] + + print(f"Number of verified retrieval docs: {len(docs)}") + + msg = [ + HumanMessage( + content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs)) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + answer_str = merge_message_runs(response, chunk_separator="")[0].content + return QAGenerationOutput( + answer=answer_str, + ) diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py new file mode 100644 index 00000000000..8359baec9b4 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -0,0 +1,16 @@ +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_query.states import SearchAnswerResults + + +def format_answer(state: 
AnswerQueryState) -> AnswerQueryOutput: + return AnswerQueryOutput( + decomp_answer_results=[ + SearchAnswerResults( + query=state["query_to_answer"], + quality=state["answer_quality"], + answer=state["answer"], + documents=state["reordered_documents"], + ) + ], + ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py new file mode 100644 index 00000000000..9f8fe12ab61 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/states.py @@ -0,0 +1,45 @@ +from typing import Annotated +from typing import TypedDict + +from pydantic import BaseModel + +from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +class SearchAnswerResults(BaseModel): + query: str + answer: str + quality: str + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class QACheckOutput(TypedDict, total=False): + answer_quality: str + + +class QAGenerationOutput(TypedDict, total=False): + answer: str + + +class ExpandedRetrievalOutput(TypedDict): + reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class AnswerQueryState( + PrimaryState, + QACheckOutput, + QAGenerationOutput, + ExpandedRetrievalOutput, + total=True, +): + query_to_answer: str + + +class AnswerQueryInput(PrimaryState, total=True): + query_to_answer: str + + +class AnswerQueryOutput(TypedDict): + decomp_answer_results: list[SearchAnswerResults] diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py new file mode 100644 index 00000000000..fcd8bddf3ec --- /dev/null +++ b/backend/onyx/agent_search/core_state.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from sqlalchemy.orm import Session + +from onyx.context.search.models import SearchRequest +from onyx.llm.interfaces import LLM + + +class PrimaryState(TypedDict, total=False): + 
search_request: SearchRequest + primary_llm: LLM + fast_llm: LLM + # a single session for the entire agent search + # is fine if we are only reading + db_session: Session diff --git a/backend/onyx/agent_search/deep_answer/edges.py b/backend/onyx/agent_search/deep_answer/edges.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/onyx/agent_search/deep_answer/graph_builder.py b/backend/onyx/agent_search/deep_answer/graph_builder.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py new file mode 100644 index 00000000000..f0a94b398ad --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/answer_generation.py @@ -0,0 +1,114 @@ +from typing import Any + +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT +from onyx.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs +from onyx.agent_search.shared_graph_utils.utils import normalize_whitespace + + +# aggregate sub questions and answers +def deep_answer_generation(state: MainState) -> dict[str, Any]: + """ + Generate answer + + Args: + state (messages): The current state + + Returns: + dict: The updated state with re-phrased question + """ + print("---DEEP GENERATE---") + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + deep_answer_context = state["core_answer_dynamic_context"] + + print(f"Number of verified retrieval docs - deep: {len(docs)}") + + combined_context = normalize_whitespace( + COMBINED_CONTEXT.format( + deep_answer_context=deep_answer_context, formated_docs=format_docs(docs) + ) + ) + + msg = [ + HumanMessage( + content=MODIFIED_RAG_PROMPT.format( + question=question, 
combined_context=combined_context + ) + ) + ] + + # Grader + model = state["fast_llm"] + response = model.invoke(msg) + + return { + "deep_answer": response.content, + } + + +def final_stuff(state: MainState) -> dict[str, Any]: + """ + Invokes the agent model to generate a response based on the current state. Given + the question, it will decide to retrieve using the retriever tool, or simply end. + + Args: + state (messages): The current state + + Returns: + dict: The updated state with the agent response appended to messages + """ + print("---FINAL---") + + messages = state["log_messages"] + time_ordered_messages = [x.pretty_repr() for x in messages] + time_ordered_messages.sort() + + print("Message Log:") + print("\n".join(time_ordered_messages)) + + initial_sub_qas = state["initial_sub_qas"] + initial_sub_qa_list = [] + for initial_sub_qa in initial_sub_qas: + if initial_sub_qa["sub_answer_check"] == "yes": + initial_sub_qa_list.append( + f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----' + ) + + initial_sub_qa_context = "\n".join(initial_sub_qa_list) + + base_answer = state["base_answer"] + + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}") + print("--------------------------------") + + if not state.get("deep_answer"): + print("No Deep Answer was required") + return {} + + deep_answer = state["deep_answer"] + sub_qas = state["sub_qas"] + sub_qa_list = [] + for sub_qa in sub_qas: + if sub_qa["sub_answer_check"] == "yes": + sub_qa_list.append( + f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----' + ) + + sub_qa_context = "\n".join(sub_qa_list) + + print(f"Final Base Answer:\n{base_answer}") + print("--------------------------------") + print(f"Final Deep Answer:\n{deep_answer}") + print("--------------------------------") + print("Sub Questions and Answers:") + 
print(sub_qa_context) + + return {} diff --git a/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py b/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py new file mode 100644 index 00000000000..786b2774fc6 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py @@ -0,0 +1,78 @@ +import json +import re +from datetime import datetime +from typing import Any + +from langchain_core.messages import HumanMessage + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction +from onyx.agent_search.shared_graph_utils.utils import generate_log_message + + +def decompose(state: MainState) -> dict[str, Any]: + """Ask the fast LLM for follow-up sub-questions and index them by position.""" + + node_start_time = datetime.now() + + question = state["original_question"] + base_answer = state["base_answer"] + + # get the entity term extraction dict and properly format it + entity_term_extraction_dict = state["retrieved_entities_relationships"][ + "retrieved_entities_relationships" + ] + + entity_term_extraction_str = format_entity_term_extraction( + entity_term_extraction_dict + ) + + initial_question_answers = state["initial_sub_qas"] + + addressed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "yes" + ] + failed_question_list = [ + x["sub_question"] + for x in initial_question_answers + if x["sub_answer_check"] == "no" + ] + + msg = [ + HumanMessage( + content=DEEP_DECOMPOSE_PROMPT.format( + question=question, + entity_term_extraction_str=entity_term_extraction_str, + base_answer=base_answer, + answered_sub_questions="\n - ".join(addressed_question_list), + failed_sub_questions="\n - ".join(failed_question_list), + ), + ) + ] + + # Parse the raw message content; pretty_repr() adds a banner that is not JSON + model = state["fast_llm"] + response = model.invoke(msg) + + cleaned_response = re.sub(r"```json\n|\n```", "", response.content) + parsed_response =
json.loads(cleaned_response) + + sub_questions_dict = {} + for sub_question_nr, sub_question_dict in enumerate( + parsed_response["sub_questions"] + ): + sub_question_dict["answered"] = False + sub_question_dict["verified"] = False + sub_questions_dict[sub_question_nr] = sub_question_dict + + return { + "decomposed_sub_questions_dict": sub_questions_dict, + "log_messages": generate_log_message( + message="deep - decompose", + node_start_time=node_start_time, + graph_start_time=state["graph_start_time"], + ), + } diff --git a/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py new file mode 100644 index 00000000000..865a78f0a75 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/entity_term_extraction.py @@ -0,0 +1,40 @@ +import json +import re +from typing import Any + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.main.states import MainState +from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT +from onyx.agent_search.shared_graph_utils.utils import format_docs + + +def entity_term_extraction(state: MainState) -> dict[str, Any]: + """Extract entities and terms from the question and context""" + + question = state["original_question"] + docs = state["deduped_retrieval_docs"] + + doc_context = format_docs(docs) + + msg = [ + HumanMessage( + content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context), + ) + ] + fast_llm = state["fast_llm"] + # Grader + llm_response_list = list( + fast_llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + cleaned_response = re.sub(r"```json\n|\n```", "", llm_response) + parsed_response = json.loads(cleaned_response) + + return { + "retrieved_entities_relationships": parsed_response, + } diff --git
a/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py new file mode 100644 index 00000000000..5805b3c6324 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_level_aggregator.py @@ -0,0 +1,30 @@ +from typing import Any + +from onyx.agent_search.main.states import MainState + + +# aggregate sub questions and answers +def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]: + sub_qas = state["sub_qas"] + + dynamic_context_list = [ + "Below you will find useful information to answer the original question:" + ] + checked_sub_qas = [] + + for core_answer_sub_qa in sub_qas: + question = core_answer_sub_qa["sub_question"] + answer = core_answer_sub_qa["sub_answer"] + verified = core_answer_sub_qa["sub_answer_check"] + + if verified == "yes": + dynamic_context_list.append( + f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n" + ) + checked_sub_qas.append({"sub_question": question, "sub_answer": answer}) + dynamic_context = "\n".join(dynamic_context_list) + + return { + "core_answer_dynamic_context": dynamic_context, + "checked_sub_qas": checked_sub_qas, + } diff --git a/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py new file mode 100644 index 00000000000..58b4262cdc5 --- /dev/null +++ b/backend/onyx/agent_search/deep_answer/nodes/sub_qa_manager.py @@ -0,0 +1,19 @@ +from typing import Any + +from onyx.agent_search.main.states import MainState + + +def sub_qa_manager(state: MainState) -> dict[str, Any]: + """ """ + + sub_questions_dict = state["decomposed_sub_questions_dict"] + + sub_questions = {} + + for sub_question_nr, sub_question_dict in sub_questions_dict.items(): + sub_questions[sub_question_nr] = sub_question_dict["sub_question"] + + return { + "sub_questions": sub_questions, + "num_new_question_iterations": 0, + } diff --git 
a/backend/onyx/agent_search/deep_answer/states.py b/backend/onyx/agent_search/deep_answer/states.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py new file mode 100644 index 00000000000..2c63125bb9c --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -0,0 +1,44 @@ +from collections.abc import Hashable + +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs +from langgraph.types import Send + +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from onyx.llm.interfaces import LLM + + +def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]: + print(f"parallel_retrieval_edge state: {state.keys()}") + + # This should be better... 
+ question = state.get("query_to_answer") or state["search_request"].query + llm: LLM = state["fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + print(f"llm_response: {llm_response}") + + rewritten_queries = llm_response.split("\n") + + print(f"rewritten_queries: {rewritten_queries}") + + return [ + Send( + "doc_retrieval", + RetrieveInput(query_to_retrieve=query, **state), + ) + for query in rewritten_queries + ] diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py new file mode 100644 index 00000000000..1928e93450c --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -0,0 +1,88 @@ +from langgraph.graph import END +from langgraph.graph import START +from langgraph.graph import StateGraph + +from onyx.agent_search.expanded_retrieval.edges import parallel_retrieval_edge +from onyx.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval +from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( + doc_verification, +) +from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( + verification_kickoff, +) +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def expanded_retrieval_graph_builder() -> StateGraph: + graph = StateGraph( + state_schema=ExpandedRetrievalState, + input=ExpandedRetrievalInput, + output=ExpandedRetrievalOutput, + ) + + ### Add nodes ### + + graph.add_node( + node="doc_retrieval", + action=doc_retrieval, 
+ ) + graph.add_node( + node="verification_kickoff", + action=verification_kickoff, + ) + graph.add_node( + node="doc_verification", + action=doc_verification, + ) + graph.add_node( + node="doc_reranking", + action=doc_reranking, + ) + + ### Add edges ### + + graph.add_conditional_edges( + source=START, + path=parallel_retrieval_edge, + path_map=["doc_retrieval"], + ) + graph.add_edge( + start_key="doc_retrieval", + end_key="verification_kickoff", + ) + graph.add_edge( + start_key="doc_verification", + end_key="doc_reranking", + ) + graph.add_edge( + start_key="doc_reranking", + end_key=END, + ) + + return graph + + +if __name__ == "__main__": + from onyx.db.engine import get_session_context_manager + from onyx.llm.factory import get_default_llms + from onyx.context.search.models import SearchRequest + + graph = expanded_retrieval_graph_builder() + compiled_graph = graph.compile() + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + query="Who made Excel and what other products did they make?", + ) + with get_session_context_manager() as db_session: + inputs = ExpandedRetrievalInput( + search_request=search_request, + primary_llm=primary_llm, + fast_llm=fast_llm, + db_session=db_session, + query_to_answer="Who made Excel?", + ) + for thing in compiled_graph.stream(inputs, debug=True): + print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py new file mode 100644 index 00000000000..1ac36203518 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -0,0 +1,11 @@ +from onyx.agent_search.expanded_retrieval.states import DocRerankingOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput: + print(f"doc_reranking state: {state.keys()}") + + verified_documents = state["verified_documents"] + 
reranked_documents = verified_documents + + return DocRerankingOutput(reranked_documents=reranked_documents) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py new file mode 100644 index 00000000000..8d612499483 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -0,0 +1,47 @@ +from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.context.search.models import InferenceSection +from onyx.context.search.models import SearchRequest +from onyx.context.search.pipeline import SearchPipeline +from onyx.db.engine import get_session_context_manager + + +class RetrieveInput(ExpandedRetrievalState): + query_to_retrieve: str + + +def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: + # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: + """ + Retrieve documents for the query carried in state["query_to_retrieve"] + + Args: + state (RetrieveInput): The current graph state plus this branch's query + + Returns: + DocRetrievalOutput: state update whose retrieved_documents holds the search results + """ + print(f"doc_retrieval state: {state.keys()}") + + query_to_retrieve = state["query_to_retrieve"] + + documents: list[InferenceSection] = [] + llm = state["primary_llm"] + fast_llm = state["fast_llm"] + # db_session = state["db_session"] + # search with this branch's query rather than search_request.query + with get_session_context_manager() as db_session1: + documents = SearchPipeline( + search_request=SearchRequest( + query=query_to_retrieve, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=db_session1, + ).reranked_sections + + print(f"retrieved documents: {len(documents)}") + return DocRetrievalOutput( + retrieved_documents=documents, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py new
file mode 100644 index 00000000000..f3f993e87b7 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -0,0 +1,60 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.expanded_retrieval.states import DocVerificationOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.shared_graph_utils.models import BinaryDecision +from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT +from onyx.context.search.models import InferenceSection + + +class DocVerificationInput(ExpandedRetrievalState, total=True): + doc_to_verify: InferenceSection + + +def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: + """ + Check whether the document is relevant for the original user question + + Args: + state (DocVerificationInput): The current state plus the doc_to_verify + + Returns: + DocVerificationOutput: verified_documents holds the doc only on a "yes" verdict + """ + + print(f"doc_verification state: {state.keys()}") + + original_query = state["search_request"].query + doc_to_verify = state["doc_to_verify"] + document_content = doc_to_verify.combined_content + + msg = [ + HumanMessage( + content=VERIFIER_PROMPT.format( + question=original_query, document_content=document_content + ) + ) + ] + + fast_llm = state["fast_llm"] + response = list( + fast_llm.stream( + prompt=msg, + ) + ) + + response_string = merge_message_runs(response, chunk_separator="")[0].content + # Normalize whitespace and case so a reply like "Yes\n" still compares equal to "yes" + decision_dict = {"decision": response_string.strip().lower()} + formatted_response = BinaryDecision.model_validate(decision_dict) + + print(f"Verdict: {formatted_response.decision}") + + verified_documents = [] + if formatted_response.decision == "yes": + verified_documents.append(doc_to_verify) + + return DocVerificationOutput( + verified_documents=verified_documents, + ) diff --git
a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py new file mode 100644 index 00000000000..d40bf6f0dae --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -0,0 +1,27 @@ +from typing import Literal + +from langgraph.types import Command +from langgraph.types import Send + +from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( + DocVerificationInput, +) +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState + + +def verification_kickoff( + state: ExpandedRetrievalState, +) -> Command[Literal["doc_verification"]]: + print(f"verification_kickoff state: {state.keys()}") + + documents = state["retrieved_documents"] + return Command( + update={}, + goto=[ + Send( + node="doc_verification", + arg=DocVerificationInput(doc_to_verify=doc, **state), + ) + for doc in documents + ], + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/prompts.py b/backend/onyx/agent_search/expanded_retrieval/prompts.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py new file mode 100644 index 00000000000..a0f726b7f8b --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -0,0 +1,36 @@ +from typing import Annotated +from typing import TypedDict + +from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections +from onyx.context.search.models import InferenceSection + + +class DocRetrievalOutput(TypedDict, total=False): + retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class DocVerificationOutput(TypedDict, total=False): + verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +class 
def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
    """
    Create one ``answer_query`` dispatch per decomposed sub-query.

    Every dispatch receives the full current state plus the sub-query under
    ``query_to_answer``; the sub-queries are read from
    ``state["initial_decomp_queries"]``.
    """
    dispatches: list[Send | Hashable] = []
    for sub_query in state["initial_decomp_queries"]:
        branch_input = AnswerQueryInput(
            **state,
            query_to_answer=sub_query,
        )
        dispatches.append(Send("answer_query", branch_input))
    return dispatches
def main_graph_builder() -> StateGraph:
    """
    Assemble the (uncompiled) top-level agent-search graph.

    Nodes: ``base_decomp``, the compiled ``answer_query`` and
    ``expanded_retrieval`` subgraphs, and ``generate_initial_answer``.

    Edges: START leads to both ``expanded_retrieval`` and ``base_decomp``;
    ``base_decomp`` routes through ``parallelize_decompozed_answer_queries``
    to ``answer_query``; ``answer_query`` and ``expanded_retrieval`` jointly
    lead to ``generate_initial_answer``, which leads to END.
    """
    builder = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    # --- nodes (subgraphs are compiled right where they are registered) ---
    builder.add_node(node="base_decomp", action=main_decomp_base)
    builder.add_node(
        node="answer_query",
        action=answer_query_graph_builder().compile(),
    )
    builder.add_node(
        node="expanded_retrieval",
        action=expanded_retrieval_graph_builder().compile(),
    )
    builder.add_node(
        node="generate_initial_answer",
        action=generate_initial_answer,
    )

    # --- edges ---
    for entry_node in ("expanded_retrieval", "base_decomp"):
        builder.add_edge(start_key=START, end_key=entry_node)

    builder.add_conditional_edges(
        source="base_decomp",
        path=parallelize_decompozed_answer_queries,
        path_map=["answer_query"],
    )
    builder.add_edge(
        start_key=["answer_query", "expanded_retrieval"],
        end_key="generate_initial_answer",
    )
    builder.add_edge(start_key="generate_initial_answer", end_key=END)

    return builder


if __name__ == "__main__":
    # Manual smoke run: stream the compiled graph once for a sample query.
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = main_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="If i am familiar with the function that I need, how can I type it into a cell?",
    )
    with get_session_context_manager() as db_session:
        inputs = MainInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )
        # The streamed items themselves are not printed; each one only
        # produces two blank lines as a progress marker.
        for _streamed_item in compiled_graph.stream(input=inputs, subgraphs=True):
            print()
            print()
def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
    """
    Draft the initial answer from the retrieved documents and the usable sub-answers.

    Reads ``search_request``, ``documents``, ``decomp_answer_results`` and
    ``fast_llm`` from ``state``. A sub-question result is passed to the model
    only when its quality is "yes" (case-insensitive), its answer is non-empty
    and its answer is not the literal "I don't know".

    Returns:
        ``InitialAnswerOutput`` whose ``initial_answer`` is the text content of
        the model's reply.
    """
    print("---GENERATE INITIAL---")

    question = state["search_request"].query
    docs = state["documents"]
    decomp_answer_results = state["decomp_answer_results"]

    _SUB_QUESTION_ANSWER_TEMPLATE = """
    Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
    """

    good_qa_list: list[str] = [
        _SUB_QUESTION_ANSWER_TEMPLATE.format(
            sub_question=decomp_answer_result.query,
            sub_answer=decomp_answer_result.answer,
        )
        for decomp_answer_result in decomp_answer_results
        if (
            decomp_answer_result.quality.lower() == "yes"
            and len(decomp_answer_result.answer) > 0
            and decomp_answer_result.answer != "I don't know"
        )
    ]

    sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)

    msg = [
        HumanMessage(
            content=INITIAL_RAG_PROMPT.format(
                question=question,
                context=format_docs(docs),
                answered_sub_questions=sub_question_answer_str,
            )
        )
    ]

    model = state["fast_llm"]
    response = model.invoke(msg)

    # pretty_repr() is a display helper that prepends a "==== Ai Message ===="
    # banner to the text; storing it would leak the banner into the answer.
    # Use the message content itself instead.
    content = response.content
    answer = content if isinstance(content, str) else str(content)

    print(answer)
    return InitialAnswerOutput(initial_answer=answer)
def run_graph(
    query: str,
    llm: LLM,
    tools: list[Tool],
) -> AnswerStream:
    """
    Build, compile and invoke the core graph once, then yield from its result.

    The graph input carries the query under ``original_query``, an empty
    ``messages`` list, and the given ``tools`` and ``llm``.
    """
    core_graph = build_core_graph()

    graph_input = {
        "original_query": query,
        "messages": [],
        "tools": tools,
        "llm": llm,
    }

    # NOTE(review): what iterating the invoke() result yields depends on its
    # return type, which is not visible here -- confirm it matches AnswerStream.
    yield from core_graph.compile().invoke(input=graph_input)
b/backend/onyx/agent_search/shared_graph_utils/prompts.py new file mode 100644 index 00000000000..a3eeba29fb9 --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -0,0 +1,427 @@ +REWRITE_PROMPT_MULTI_ORIGINAL = """ \n + Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a + document store. Particularly, try to think about resolving ambiguities and make the search queries more specific, + enabling the system to search more broadly. + Also, try to make the search queries not redundant, i.e. not too similar! \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """ + +REWRITE_PROMPT_MULTI = """ \n + Please create a list of 2-3 sample documents that could answer an original question. Each document + should be about as long as the original question. \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """ + +BASE_RAG_PROMPT = """ \n + You are an assistant for question-answering tasks. Use the context provided below - and only the + provided context - to answer the question. If you don't know the answer or if the provided context is + empty, just say "I don't know". Do not use your internal knowledge! + + Again, only use the provided context and do not use your internal knowledge! If you cannot answer the + question based on the context, say "I don't know". It is a matter of life and death that you do NOT + use your internal knowledge, just the provided information! + + Use three sentences maximum and keep the answer concise. 
+ answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n + \n\n + Answer:""" + +BASE_CHECK_PROMPT = """ \n + Please check whether 1) the suggested answer seems to fully address the original question AND 2)the + original question requests a simple, factual answer, and there are no ambiguities, judgements, + aggregations, or any other complications that may require extra context. (I.e., if the question is + somewhat addressed, but the answer would benefit from more context, then answer with 'no'.) + + Please only answer with 'yes' or 'no' \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Here is the proposed answer: + \n ------- \n + {base_answer} + \n ------- \n + Please answer with yes or no:""" + +VERIFIER_PROMPT = """ \n + Please check whether the document seems to be relevant for the answer of the question. Please + only answer with 'yes' or 'no' \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + Here is the document text: + \n ------- \n + {document_content} + \n ------- \n + Please answer with yes or no:""" + +INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n + Please decompose an initial user question into not more than 4 appropriate sub-questions that help to + answer the original question. The purpose for this decomposition is to isolate individulal entities + (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales + for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our + sales with company A' + 'what is our market share with company A' + 'is company A a reference customer + for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. 
\n + + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Please formulate your answer as a list of subquestions: + + Answer: + """ + +REWRITE_PROMPT_SINGLE = """ \n + Please convert an initial user question into a more appropriate search query for retrievel from a + document store. \n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Formulate the query: """ + +MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below + - and only this context - to answer the question. If you don't know the answer, just say "I don't know". + Use three sentences maximum and keep the answer concise. + Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer. + Again, only use the provided context and do not use your internal knowledge! If you cannot answer the + question based on the context, say "I don't know". It is a matter of life and death that you do NOT + use your internal knowledge, just the provided information! + + \nQuestion: {question} + \nContext: {combined_context} \n + + Answer:""" + +ORIG_DEEP_DECOMPOSE_PROMPT = """ \n + An initial user question needs to be answered. An initial answer has been provided but it wasn't quite + good enough. Also, some sub-questions had been answered and this information has been used to provide + the initial answer. Some other subquestions may have been suggested based on little knowledge, but they + were not directly answerable. Also, some entities, relationships and terms are givenm to you so that + you have an idea of how the avaiolable data looks like. 
+ + Your role is to generate 3-5 new sub-questions that would help to answer the initial question, + considering: + + 1) The initial question + 2) The initial answer that was found to be unsatisfactory + 3) The sub-questions that were answered + 4) The sub-questions that were suggested but not answered + 5) The entities, relationships and terms that were extracted from the context + + The individual questions should be answerable by a good RAG system. + So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question, but in a way that does + not duplicate questions that were already tried. + + Additional Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + resolve ambiguities, or address shortcoming of the initial answer + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that it can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?" + - For each sub-question, please provide a short explanation for why it is a good sub-question. So + generate a list of dictionaries with the following format: + [{{"sub_question": , "explanation": , "search_term": }}, ...] 
+ + \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Here is the initial sub-optimal answer: + \n ------- \n + {base_answer} + \n ------- \n + + Here are the sub-questions that were answered: + \n ------- \n + {answered_sub_questions} + \n ------- \n + + Here are the sub-questions that were suggested but not answered: + \n ------- \n + {failed_sub_questions} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Again, please find questions that are NOT overlapping too much with the already answered + sub-questions or those that already were suggested and failed. + In other words - what can we try in addition to what has been tried so far? + + Please think through it step by step and then generate the list of json dictionaries with the following + format: + + {{"sub_questions": [{{"sub_question": , + "explanation": , + "search_term": }}, + ...]}} """ + +DEEP_DECOMPOSE_PROMPT = """ \n + An initial user question needs to be answered. An initial answer has been provided but it wasn't quite + good enough. Also, some sub-questions had been answered and this information has been used to provide + the initial answer. Some other subquestions may have been suggested based on little knowledge, but they + were not directly answerable. Also, some entities, relationships and terms are givenm to you so that + you have an idea of how the avaiolable data looks like. 
+ + Your role is to generate 4-6 new sub-questions that would help to answer the initial question, + considering: + + 1) The initial question + 2) The initial answer that was found to be unsatisfactory + 3) The sub-questions that were answered + 4) The sub-questions that were suggested but not answered + 5) The entities, relationships and terms that were extracted from the context + + The individual questions should be answerable by a good RAG system. + So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question, but in a way that does + not duplicate questions that were already tried. + + Additional Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + resolve ambiguities, or address shortcoming of the initial answer + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that it can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?" + - For each sub-question, please also provide a search term that can be used to retrieve relevant + documents from a document store. 
+ \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Here is the initial sub-optimal answer: + \n ------- \n + {base_answer} + \n ------- \n + + Here are the sub-questions that were answered: + \n ------- \n + {answered_sub_questions} + \n ------- \n + + Here are the sub-questions that were suggested but not answered: + \n ------- \n + {failed_sub_questions} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Again, please find questions that are NOT overlapping too much with the already answered + sub-questions or those that already were suggested and failed. + In other words - what can we try in addition to what has been tried so far? + + Generate the list of json dictionaries with the following format: + + {{"sub_questions": [{{"sub_question": , + "search_term": }}, + ...]}} """ + +DECOMPOSE_PROMPT = """ \n + For an initial user question, please generate at 5-10 individual sub-questions whose answers would help + \n to answer the initial question. The individual questions should be answerable by a good RAG system. + So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the + question for different entities that may be involved in the original question. + + In order to arrive at meaningful sub-questions, please also consider the context retrieved from the + document store, expressed as entities, relationships and terms. 
You can also think about the types + mentioned in brackets + + Guidelines: + - The sub-questions should be specific to the question and provide richer context for the question, + and or resolve ambiguities + - Each sub-question - when answered - should be relevant for the answer to the original question + - The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any + other complications that may require extra context. + - The sub-questions MUST have the full context of the original question so that it can be executed by + a RAG system independently without the original question available + (Example: + - initial question: "What is the capital of France?" + - bad sub-question: "What is the name of the river there?" + - good sub-question: "What is the name of the river that flows through Paris?" + - For each sub-question, please provide a short explanation for why it is a good sub-question. So + generate a list of dictionaries with the following format: + [{{"sub_question": , "explanation": , "search_term": }}, ...] + + \n\n + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + And here are the entities, relationships and terms extracted from the context: + \n ------- \n + {entity_term_extraction_str} + \n ------- \n + + Please generate the list of good, fully contextualized sub-questions that would help to address the + main question. Don't be too specific unless the original question is specific. + Please think through it step by step and then generate the list of json dictionaries with the following + format: + {{"sub_questions": [{{"sub_question": , + "explanation": , + "search_term": }}, + ...]}} """ + +#### Consolidations +COMBINED_CONTEXT = """------- + Below you will find useful information to answer the original question. First, you see a number of + sub-questions with their answers. 
This information should be considered to be more focussed and + somewhat more specific to the original question as it tries to contextualized facts. + After that will see the documents that were considered to be relevant to answer the original question. + + Here are the sub-questions and their answers: + \n\n {deep_answer_context} \n\n + \n\n Here are the documents that were considered to be relevant to answer the original question: + \n\n {formated_docs} \n\n + ---------------- + """ + +SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """------- + Below you will find a question that we ultimately want to answer (the original question) and a list of + motivations in arbitrary order for generated sub-questions that are supposed to help us answering the + original question. The motivations are formatted as : . + (Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant + motivation and 2 is less relevant.) + + Please rank the motivations in order of relevance for answering the original question. Also, try to + ensure that the top questions do not duplicate too much, i.e. that they are not too similar. + Ultimately, create a list with the motivation numbers where the number of the most relevant + motivations comes first. + + Here is the original question: + \n\n {original_question} \n\n + \n\n Here is the list of sub-question motivations: + \n\n {sub_question_explanations} \n\n + ---------------- + + Please think step by step and then generate the ranked list of motivations. + + Please format your answer as a json object in the following format: + {{"reasonning": , + "ranked_motivations": }} + """ + + +INITIAL_DECOMPOSITION_PROMPT = """ \n + Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to + answer the original question. 
The purpose for this decomposition is to isolate individulal entities + (i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales + for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our + sales with company A' + 'what is our market share with company A' + 'is company A a reference customer + for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n + + For each sub-question, please also create one search term that can be used to retrieve relevant + documents from a document store. + + Here is the initial question: + \n ------- \n + {question} + \n ------- \n + + Please formulate your answer as a list of json objects with the following format: + + [{{"sub_question": , "search_term": }}, ...] + + Answer: + """ + +INITIAL_RAG_PROMPT = """ \n + You are an assistant for question-answering tasks. Use the information provided below - and only the + provided information - to answer the provided question. + + The information provided below consists of: + 1) a number of answered sub-questions - these are very important(!) and definitely should be + considered to answer the question. + 2) a number of documents that were also deemed relevant for the question. + + If you don't know the answer or if the provided information is empty or insufficient, just say + "I don't know". Do not use your internal knowledge! + + Again, only use the provided informationand do not use your internal knowledge! It is a matter of life + and death that you do NOT use your internal knowledge, just the provided information! + + Try to keep your answer concise. 
+ + And here is the question and the provided information: + \n + \nQuestion:\n {question} + + \nAnswered Sub-questions:\n {answered_sub_questions} + + \nContext:\n {context} \n\n + \n\n + + Answer:""" + +ENTITY_TERM_PROMPT = """ \n + Based on the original question and the context retieved from a dataset, please generate a list of + entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts + (e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other. + + \n\n + Here is the original question: + \n ------- \n + {question} + \n ------- \n + And here is the context retrieved: + \n ------- \n + {context} + \n ------- \n + + Please format your answer as a json object in the following format: + + {{"retrieved_entities_relationships": {{ + "entities": [{{ + "entity_name": , + "entity_type": + }}], + "relationships": [{{ + "name": , + "type": , + "entities": [, ] + }}], + "terms": [{{ + "term_name": , + "term_type": , + "similar_to": + }}] + }} + }} + """ diff --git a/backend/onyx/agent_search/shared_graph_utils/utils.py b/backend/onyx/agent_search/shared_graph_utils/utils.py new file mode 100644 index 00000000000..a435860320d --- /dev/null +++ b/backend/onyx/agent_search/shared_graph_utils/utils.py @@ -0,0 +1,101 @@ +import ast +import json +import re +from collections.abc import Sequence +from datetime import datetime +from datetime import timedelta +from typing import Any + +from onyx.context.search.models import InferenceSection + + +def normalize_whitespace(text: str) -> str: + """Normalize whitespace in text to single spaces and strip leading/trailing whitespace.""" + import re + + return re.sub(r"\s+", " ", text.strip()) + + +# Post-processing +def format_docs(docs: Sequence[InferenceSection]) -> str: + return "\n\n".join(doc.combined_content for doc in docs) + + +def clean_and_parse_list_string(json_string: str) -> list[dict]: + # Remove any prefixes/labels before the actual JSON 
content + json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL) + + # Remove markdown code block markers and any newline prefixes + cleaned_string = re.sub(r"```json\n|\n```", "", json_string) + cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") + cleaned_string = " ".join(cleaned_string.split()) + + # Try parsing with json.loads first, fall back to ast.literal_eval + try: + return json.loads(cleaned_string) + except json.JSONDecodeError: + try: + return ast.literal_eval(cleaned_string) + except (ValueError, SyntaxError) as e: + raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e + + +def clean_and_parse_json_string(json_string: str) -> dict[str, Any]: + # Remove markdown code block markers and any newline prefixes + cleaned_string = re.sub(r"```json\n|\n```", "", json_string) + cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") + cleaned_string = " ".join(cleaned_string.split()) + # Parse the cleaned string into a Python dictionary + return json.loads(cleaned_string) + + +def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str: + entities = entity_term_extraction_dict["entities"] + terms = entity_term_extraction_dict["terms"] + relationships = entity_term_extraction_dict["relationships"] + + entity_strs = ["\nEntities:\n"] + for entity in entities: + entity_str = f"{entity['entity_name']} ({entity['entity_type']})" + entity_strs.append(entity_str) + + entity_str = "\n - ".join(entity_strs) + + relationship_strs = ["\n\nRelationships:\n"] + for relationship in relationships: + relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}" + relationship_strs.append(relationship_str) + + relationship_str = "\n - ".join(relationship_strs) + + term_strs = ["\n\nTerms:\n"] + for term in terms: + term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}" + term_strs.append(term_str) + + term_str = 
"\n - ".join(term_strs) + + return "\n".join(entity_strs + relationship_strs + term_strs) + + +def _format_time_delta(time: timedelta) -> str: + seconds_from_start = f"{((time).seconds):03d}" + microseconds_from_start = f"{((time).microseconds):06d}" + return f"{seconds_from_start}.{microseconds_from_start}" + + +def generate_log_message( + message: str, + node_start_time: datetime, + graph_start_time: datetime | None = None, +) -> str: + current_time = datetime.now() + + if graph_start_time is not None: + graph_time_str = _format_time_delta(current_time - graph_start_time) + else: + graph_time_str = "N/A" + + node_time_str = _format_time_delta(current_time - node_start_time) + + return f"{graph_time_str} ({node_time_str} s): {message}" diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 3a4996d9014..01a99c975fd 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -26,10 +26,15 @@ huggingface-hub==0.20.1 jira==3.5.1 jsonref==1.1.0 trafilatura==1.12.2 -langchain==0.1.17 -langchain-core==0.1.50 -langchain-text-splitters==0.0.1 -litellm==1.54.1 +langchain==0.3.7 +langchain-core==0.3.24 +langchain-openai==0.2.9 +langchain-text-splitters==0.3.2 +langchainhub==0.1.21 +langgraph==0.2.59 +langgraph-checkpoint==2.0.5 +langgraph-sdk==0.1.44 +litellm==1.53.1 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.9.45 From 11ce2a62abf893ec1857e7a9bf4830b4a4b23fef Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 16 Dec 2024 12:24:17 -0800 Subject: [PATCH 02/19] fix: update staged changes --- backend/onyx/agent_search/run_graph.py | 6 +++--- backend/onyx/tools/message.py | 5 +++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/backend/onyx/agent_search/run_graph.py b/backend/onyx/agent_search/run_graph.py index 98ed0ff8e62..9a93dbba646 100644 --- a/backend/onyx/agent_search/run_graph.py +++ b/backend/onyx/agent_search/run_graph.py @@ -1,5 +1,5 @@ -from 
onyx.agent_search.primary_graph.graph_builder import build_core_graph -from onyx.llm.answering.answer import AnswerStream +from onyx.agent_search.main.graph_builder import main_graph_builder +from onyx.chat.answer import AnswerStream from onyx.llm.interfaces import LLM from onyx.tools.tool import Tool @@ -9,7 +9,7 @@ def run_graph( llm: LLM, tools: list[Tool], ) -> AnswerStream: - graph = build_core_graph() + graph = main_graph_builder() inputs = { "original_query": query, diff --git a/backend/onyx/tools/message.py b/backend/onyx/tools/message.py index d5590111623..659f38731e3 100644 --- a/backend/onyx/tools/message.py +++ b/backend/onyx/tools/message.py @@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1): tool_call_request: AIMessage tool_call_result: ToolMessage + # This is a workaround to allow arbitrary types in the model + # TODO: Remove this once we have a better solution + class Config: + arbitrary_types_allowed = True + def tool_call_tokens( tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer From 82914ad365cc94ec5adfae517301dce3d9987ae6 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Mon, 16 Dec 2024 13:26:09 -0800 Subject: [PATCH 03/19] fixed key issue --- .../onyx/agent_search/answer_query/nodes/answer_generation.py | 2 +- backend/onyx/agent_search/answer_query/nodes/format_answer.py | 2 +- backend/onyx/agent_search/answer_query/states.py | 2 +- backend/onyx/agent_search/expanded_retrieval/states.py | 2 +- backend/onyx/agent_search/main/graph_builder.py | 1 - 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index c23f77ee706..18c0862e23e 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -9,7 +9,7 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: query = 
state["query_to_answer"] - docs = state["reordered_documents"] + docs = state["reranked_documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 8359baec9b4..51f7dbad5b2 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -10,7 +10,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: query=state["query_to_answer"], quality=state["answer_quality"], answer=state["answer"], - documents=state["reordered_documents"], + documents=state["reranked_documents"], ) ], ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index 9f8fe12ab61..d2dd1f12c65 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -24,7 +24,7 @@ class QAGenerationOutput(TypedDict, total=False): class ExpandedRetrievalOutput(TypedDict): - reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] class AnswerQueryState( diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index a0f726b7f8b..54fa6023cc1 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -33,4 +33,4 @@ class ExpandedRetrievalInput(PrimaryState, total=True): class ExpandedRetrievalOutput(TypedDict): - reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections] + reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 449ffb89dff..930d7a745f7 
100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -95,4 +95,3 @@ def main_graph_builder() -> StateGraph: ): # print(thing) print() - print() From ff03d717f37f4dc3f3027b338b5e365f223cb112 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 12:36:28 -0800 Subject: [PATCH 04/19] brough over joachim changes --- .../answer_query/nodes/answer_check.py | 6 ++-- .../agent_search/expanded_retrieval/edges.py | 6 ++-- .../expanded_retrieval/nodes/doc_retrieval.py | 28 ++++++++----------- .../shared_graph_utils/prompts.py | 18 +++++++++++- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py index 8b58129c47b..95059673061 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -3,14 +3,14 @@ from onyx.agent_search.answer_query.states import AnswerQueryState from onyx.agent_search.answer_query.states import QACheckOutput -from onyx.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT +from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( - content=BASE_CHECK_PROMPT.format( - question=state["search_request"].query, + content=SUB_CHECK_PROMPT.format( + question=state["query_to_answer"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 2c63125bb9c..063befe85ae 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -6,7 +6,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput from onyx.agent_search.expanded_retrieval.states import 
ExpandedRetrievalInput -from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.llm.interfaces import LLM @@ -19,7 +19,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab msg = [ HumanMessage( - content=REWRITE_PROMPT_MULTI.format(question=question), + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), ) ] llm_response_list = list( @@ -31,7 +31,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"llm_response: {llm_response}") - rewritten_queries = llm_response.split("\n") + rewritten_queries = llm_response.split("--") print(f"rewritten_queries: {rewritten_queries}") diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 8d612499483..af38e5f4909 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -3,7 +3,6 @@ from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline -from onyx.db.engine import get_session_context_manager class RetrieveInput(ExpandedRetrievalState): @@ -23,25 +22,22 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: """ print(f"doc_retrieval state: {state.keys()}") - state["query_to_retrieve"] - documents: list[InferenceSection] = [] llm = state["primary_llm"] fast_llm = state["fast_llm"] - # db_session = state["db_session"] - query_to_retrieve = state["search_request"].query - with get_session_context_manager() as db_session1: - documents = SearchPipeline( - search_request=SearchRequest( - query=query_to_retrieve, - ), - user=None, - llm=llm, - fast_llm=fast_llm, - db_session=db_session1, - ).reranked_sections + 
query_to_retrieve = state["query_to_retrieve"] + + documents = SearchPipeline( + search_request=SearchRequest( + query=query_to_retrieve, + ), + user=None, + llm=llm, + fast_llm=fast_llm, + db_session=state["db_session"], + ).reranked_sections print(f"retrieved documents: {len(documents)}") return DocRetrievalOutput( - retrieved_documents=documents, + retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/shared_graph_utils/prompts.py b/backend/onyx/agent_search/shared_graph_utils/prompts.py index a3eeba29fb9..229a980762c 100644 --- a/backend/onyx/agent_search/shared_graph_utils/prompts.py +++ b/backend/onyx/agent_search/shared_graph_utils/prompts.py @@ -32,6 +32,21 @@ \n\n Answer:""" +SUB_CHECK_PROMPT = """ \n + Your task is to see whether a given answer addresses a given question. + Please do not use any internal knowledge you may have - just focus on whether the answer + as given seems to address the question as given. + Here is the question: + \n ------- \n + {question} + \n ------- \n + Here is the suggested answer: + \n ------- \n + {base_answer} + \n ------- \n + Please answer with yes or no:""" + + BASE_CHECK_PROMPT = """ \n Please check whether 1) the suggested answer seems to fully address the original question AND 2)the original question requests a simple, factual answer, and there are no ambiguities, judgements, @@ -50,7 +65,8 @@ Please answer with yes or no:""" VERIFIER_PROMPT = """ \n - Please check whether the document seems to be relevant for the answer of the question. Please + Please check whether the document provided below seems to be relevant + to get an answer to the provided question. 
Please only answer with 'yes' or 'no' \n Here is the initial question: \n ------- \n From 1f88b60abd9b5d42864a0333da303eb30a20bd4b Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 14:05:51 -0800 Subject: [PATCH 05/19] Now using result objects --- .../onyx/agent_search/answer_query/edges.py | 16 ++++++++++ .../answer_query/graph_builder.py | 14 +++++---- .../answer_query/nodes/answer_check.py | 2 +- .../answer_query/nodes/answer_generation.py | 2 +- .../answer_query/nodes/format_answer.py | 2 +- .../onyx/agent_search/answer_query/states.py | 15 ++++------ .../agent_search/expanded_retrieval/edges.py | 2 +- .../expanded_retrieval/graph_builder.py | 11 ++++++- .../nodes/format_results.py | 15 ++++++++++ .../agent_search/expanded_retrieval/states.py | 29 +++++++++++++++---- backend/onyx/agent_search/main/edges.py | 2 +- .../main/nodes/generate_initial_answer.py | 2 +- 12 files changed, 84 insertions(+), 28 deletions(-) create mode 100644 backend/onyx/agent_search/answer_query/edges.py create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py diff --git a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_query/edges.py new file mode 100644 index 00000000000..15f60f2bdf3 --- /dev/null +++ b/backend/onyx/agent_search/answer_query/edges.py @@ -0,0 +1,16 @@ +from collections.abc import Hashable + +from langgraph.types import Send + +from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput + + +def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: + return Send( + "expanded_retrieval", + ExpandedRetrievalInput( + **state, + starting_query=state["starting_query"], + ), + ) diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py index e52bfe28d69..53f647eac87 100644 --- 
a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -2,6 +2,7 @@ from langgraph.graph import START from langgraph.graph import StateGraph +from onyx.agent_search.answer_query.edges import send_to_expanded_retrieval from onyx.agent_search.answer_query.nodes.answer_check import answer_check from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation from onyx.agent_search.answer_query.nodes.format_answer import format_answer @@ -24,7 +25,7 @@ def answer_query_graph_builder() -> StateGraph: expanded_retrieval = expanded_retrieval_graph_builder().compile() graph.add_node( - node="expanded_retrieval_for_initial_decomp", + node="decomped_expanded_retrieval", action=expanded_retrieval, ) graph.add_node( @@ -42,12 +43,13 @@ def answer_query_graph_builder() -> StateGraph: ### Add edges ### - graph.add_edge( - start_key=START, - end_key="expanded_retrieval_for_initial_decomp", + graph.add_conditional_edges( + source=START, + path=send_to_expanded_retrieval, + path_map=["decomped_expanded_retrieval"], ) graph.add_edge( - start_key="expanded_retrieval_for_initial_decomp", + start_key="decomped_expanded_retrieval", end_key="answer_generation", ) graph.add_edge( @@ -83,7 +85,7 @@ def answer_query_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - query_to_answer="Who made Excel?", + question_to_answer="Who made Excel?", ) output = compiled_graph.invoke( input=inputs, diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py index 95059673061..f06b2071f96 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -10,7 +10,7 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( - question=state["query_to_answer"], 
+ question=state["question_to_answer"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index 18c0862e23e..3de9c403d2b 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,7 +8,7 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["query_to_answer"] + query = state["question_to_answer"] docs = state["reranked_documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 51f7dbad5b2..061000701b9 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -7,7 +7,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( decomp_answer_results=[ SearchAnswerResults( - query=state["query_to_answer"], + question=state["question_to_answer"], quality=state["answer_quality"], answer=state["answer"], documents=state["reranked_documents"], diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index d2dd1f12c65..8c24623ee05 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -4,14 +4,16 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.expanded_retrieval.states import RetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection class SearchAnswerResults(BaseModel): - query: str + question: str answer: str quality: str + retrieval_results: list[RetrievalResult] 
documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -23,23 +25,18 @@ class QAGenerationOutput(TypedDict, total=False): answer: str -class ExpandedRetrievalOutput(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] - - class AnswerQueryState( PrimaryState, QACheckOutput, QAGenerationOutput, - ExpandedRetrievalOutput, total=True, ): - query_to_answer: str + question: str class AnswerQueryInput(PrimaryState, total=True): - query_to_answer: str + question: str class AnswerQueryOutput(TypedDict): - decomp_answer_results: list[SearchAnswerResults] + answer_results: list[SearchAnswerResults] diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 063befe85ae..085479a4e43 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -14,7 +14,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"parallel_retrieval_edge state: {state.keys()}") # This should be better... 
- question = state.get("query_to_answer") or state["search_request"].query + question = state.get("question_to_answer") or state["search_request"].query llm: LLM = state["fast_llm"] msg = [ diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 1928e93450c..7f94c1ef78a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -8,6 +8,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( doc_verification, ) +from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( verification_kickoff, ) @@ -41,6 +42,10 @@ def expanded_retrieval_graph_builder() -> StateGraph: node="doc_reranking", action=doc_reranking, ) + graph.add_node( + node="format_results", + action=format_results, + ) ### Add edges ### @@ -59,6 +64,10 @@ def expanded_retrieval_graph_builder() -> StateGraph: ) graph.add_edge( start_key="doc_reranking", + end_key="format_results", + ) + graph.add_edge( + start_key="format_results", end_key=END, ) @@ -82,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - query_to_answer="Who made Excel?", + question_to_answer="Who made Excel?", ) for thing in compiled_graph.stream(inputs, debug=True): print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py new file mode 100644 index 00000000000..36883eb6bd7 --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -0,0 +1,15 @@ +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from 
onyx.agent_search.expanded_retrieval.states import RetrievalResult + + +def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: + return ExpandedRetrievalOutput( + retrieval_results=[ + RetrievalResult( + starting_query=state["starting_query"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["reranked_documents"], + ) + ], + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 54fa6023cc1..697639cc4ae 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -1,12 +1,29 @@ +from operator import add from typing import Annotated from typing import TypedDict +from pydantic import BaseModel + from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection +class ExpandedRetrievalResult(BaseModel): + expanded_query: str + expanded_retrieval_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] + + +class RetrievalResult(BaseModel): + starting_query: str + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + class DocRetrievalOutput(TypedDict, total=False): + expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -18,6 +35,10 @@ class DocRerankingOutput(TypedDict, total=False): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class ExpandedRetrievalOutput(TypedDict): + retrieval_results: Annotated[list[RetrievalResult], add] + + class ExpandedRetrievalState( PrimaryState, DocRetrievalOutput, @@ -25,12 +46,8 @@ class ExpandedRetrievalState( DocRerankingOutput, total=True, ): - query_to_answer: str + 
starting_query: str class ExpandedRetrievalInput(PrimaryState, total=True): - query_to_answer: str - - -class ExpandedRetrievalOutput(TypedDict): - reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] + starting_query: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 953b0a96275..0ec4c0f4a67 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - query_to_answer=query, + question_to_answer=query, ), ) for query in state["initial_decomp_queries"] diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index 5671b2352fa..a6476477ae4 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -27,7 +27,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: ): good_qa_list.append( _SUB_QUESTION_ANSWER_TEMPLATE.format( - sub_question=decomp_answer_result.query, + sub_question=decomp_answer_result.question, sub_answer=decomp_answer_result.answer, ) ) From 2f2b9a862ace2d637192fadbb1bb30deba1c29cb Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 15:11:54 -0800 Subject: [PATCH 06/19] fixed expanded retrieval subgraph --- backend/onyx/agent_search/answer_query/edges.py | 4 ++-- .../onyx/agent_search/answer_query/graph_builder.py | 4 ++-- .../agent_search/answer_query/nodes/answer_check.py | 2 +- .../answer_query/nodes/answer_generation.py | 4 ++-- .../agent_search/answer_query/nodes/format_answer.py | 5 +++-- backend/onyx/agent_search/answer_query/states.py | 6 ++++-- backend/onyx/agent_search/expanded_retrieval/edges.py | 2 +- .../agent_search/expanded_retrieval/graph_builder.py | 10 +++++++--- 
.../expanded_retrieval/nodes/doc_retrieval.py | 6 ++++++ .../expanded_retrieval/nodes/format_results.py | 10 ++-------- .../onyx/agent_search/expanded_retrieval/states.py | 11 ++++++----- backend/onyx/agent_search/main/edges.py | 2 +- 12 files changed, 37 insertions(+), 29 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_query/edges.py index 15f60f2bdf3..c538ef8958b 100644 --- a/backend/onyx/agent_search/answer_query/edges.py +++ b/backend/onyx/agent_search/answer_query/edges.py @@ -8,9 +8,9 @@ def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: return Send( - "expanded_retrieval", + "decomped_expanded_retrieval", ExpandedRetrievalInput( **state, - starting_query=state["starting_query"], + starting_query=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py index 53f647eac87..27d89af0845 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -77,7 +77,7 @@ def answer_query_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="Who made Excel and what other products did they make?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = AnswerQueryInput( @@ -85,7 +85,7 @@ def answer_query_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question_to_answer="Who made Excel?", + question="what can you do with onyx?", ) output = compiled_graph.invoke( input=inputs, diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_query/nodes/answer_check.py index f06b2071f96..c035f309feb 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ 
b/backend/onyx/agent_search/answer_query/nodes/answer_check.py @@ -10,7 +10,7 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( - question=state["question_to_answer"], + question=state["question"], base_answer=state["answer"], ) ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index 3de9c403d2b..d35d55673ac 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,8 +8,8 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["question_to_answer"] - docs = state["reranked_documents"] + query = state["question"] + docs = state["documents"] print(f"Number of verified retrieval docs: {len(docs)}") diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 061000701b9..5a7fffddaf1 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -7,10 +7,11 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( decomp_answer_results=[ SearchAnswerResults( - question=state["question_to_answer"], + question=state["question"], quality=state["answer_quality"], answer=state["answer"], - documents=state["reranked_documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], ) ], ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index 8c24623ee05..f0249b4fe72 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -4,7 +4,8 @@ from pydantic import BaseModel from onyx.agent_search.core_state import 
PrimaryState -from onyx.agent_search.expanded_retrieval.states import RetrievalResult +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -13,7 +14,7 @@ class SearchAnswerResults(BaseModel): question: str answer: str quality: str - retrieval_results: list[RetrievalResult] + expanded_retrieval_results: list[ExpandedRetrievalResult] documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -29,6 +30,7 @@ class AnswerQueryState( PrimaryState, QACheckOutput, QAGenerationOutput, + ExpandedRetrievalOutput, total=True, ): question: str diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 085479a4e43..19a321bd727 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -14,7 +14,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab print(f"parallel_retrieval_edge state: {state.keys()}") # This should be better... 
- question = state.get("question_to_answer") or state["search_request"].query + question = state.get("question") or state["search_request"].query llm: LLM = state["fast_llm"] msg = [ diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index 7f94c1ef78a..c2bfd1e346c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -83,7 +83,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="Who made Excel and what other products did they make?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = ExpandedRetrievalInput( @@ -91,7 +91,11 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question_to_answer="Who made Excel?", + question="what can you do with onyx?", ) - for thing in compiled_graph.stream(inputs, debug=True): + for thing in compiled_graph.stream( + input=inputs, + # debug=True, + subgraphs=True, + ): print(thing) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index af38e5f4909..118aaa776c6 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,4 +1,5 @@ from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest @@ -38,6 +39,11 @@ def 
doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: ).reranked_sections print(f"retrieved documents: {len(documents)}") + expanded_retrieval_result = ExpandedRetrievalResult( + expanded_query=query_to_retrieve, + expanded_retrieval_documents=documents[:4], + ) return DocRetrievalOutput( + expanded_retrieval_results=[expanded_retrieval_result], retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 36883eb6bd7..2a9620a0a9a 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -1,15 +1,9 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -from onyx.agent_search.expanded_retrieval.states import RetrievalResult def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: return ExpandedRetrievalOutput( - retrieval_results=[ - RetrievalResult( - starting_query=state["starting_query"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["reranked_documents"], - ) - ], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["reranked_documents"], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 697639cc4ae..e6c3e8945e0 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -16,10 +16,10 @@ class ExpandedRetrievalResult(BaseModel): ] -class RetrievalResult(BaseModel): - starting_query: str - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] +# class RetrievalResult(BaseModel): +# starting_query: str +# 
expanded_retrieval_results: list[ExpandedRetrievalResult] +# documents: Annotated[list[InferenceSection], dedup_inference_sections] class DocRetrievalOutput(TypedDict, total=False): @@ -36,7 +36,8 @@ class DocRerankingOutput(TypedDict, total=False): class ExpandedRetrievalOutput(TypedDict): - retrieval_results: Annotated[list[RetrievalResult], add] + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] class ExpandedRetrievalState( diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 0ec4c0f4a67..caaf9fe412e 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - question_to_answer=query, + question=query, ), ) for query in state["initial_decomp_queries"] From 442c94727e92dfd39d6d0e1d90bc2d624c671f83 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 17 Dec 2024 15:16:36 -0800 Subject: [PATCH 07/19] got answer subgraph working --- .../agent_search/answer_query/graph_builder.py | 15 +++++---------- .../answer_query/nodes/format_answer.py | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_query/graph_builder.py index 27d89af0845..1036f425b76 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_query/graph_builder.py @@ -87,16 +87,11 @@ def answer_query_graph_builder() -> StateGraph: db_session=db_session, question="what can you do with onyx?", ) - output = compiled_graph.invoke( + for thing in compiled_graph.stream( input=inputs, # debug=True, # subgraphs=True, - ) - print(output) - # for namespace, chunk in compiled_graph.stream( - # input=inputs, - # # debug=True, - # subgraphs=True, - # ): 
- # print(namespace) - # print(chunk) + ): + print(thing) + # output = compiled_graph.invoke(inputs) + # print(output) diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 5a7fffddaf1..4220c2cc1ee 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,7 +5,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - decomp_answer_results=[ + answer_results=[ SearchAnswerResults( question=state["question"], quality=state["answer_quality"], From d66180fe13ad48995cb32ec472b0da69353743ba Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 07:33:40 -0800 Subject: [PATCH 08/19] Cleanup --- .../answer_query/nodes/answer_generation.py | 4 ++-- .../answer_query/nodes/format_answer.py | 16 +++++++--------- backend/onyx/agent_search/answer_query/states.py | 10 ++++------ .../agent_search/expanded_retrieval/states.py | 6 ------ backend/onyx/agent_search/main/edges.py | 4 ++-- .../onyx/agent_search/main/nodes/base_decomp.py | 2 +- backend/onyx/agent_search/main/states.py | 2 +- 7 files changed, 17 insertions(+), 27 deletions(-) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py index d35d55673ac..f036267c3b7 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_query/nodes/answer_generation.py @@ -8,14 +8,14 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: - query = state["question"] + question = state["question"] docs = state["documents"] print(f"Number of verified retrieval docs: {len(docs)}") msg = [ HumanMessage( - content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs)) + content=BASE_RAG_PROMPT.format(question=question, 
context=format_docs(docs)) ) ] diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 4220c2cc1ee..2bf618c571f 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,13 +5,11 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - answer_results=[ - SearchAnswerResults( - question=state["question"], - quality=state["answer_quality"], - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - ) - ], + answer_result=SearchAnswerResults( + question=state["question"], + quality=state["answer_quality"], + answer=state["answer"], + expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + ), ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index f0249b4fe72..7e6cf160a99 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -1,4 +1,3 @@ -from typing import Annotated from typing import TypedDict from pydantic import BaseModel @@ -6,7 +5,6 @@ from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult -from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,7 +13,7 @@ class SearchAnswerResults(BaseModel): answer: str quality: str expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] + documents: list[InferenceSection] class QACheckOutput(TypedDict, total=False): @@ -28,9 +26,9 @@ class 
QAGenerationOutput(TypedDict, total=False): class AnswerQueryState( PrimaryState, - QACheckOutput, - QAGenerationOutput, ExpandedRetrievalOutput, + QAGenerationOutput, + QACheckOutput, total=True, ): question: str @@ -41,4 +39,4 @@ class AnswerQueryInput(PrimaryState, total=True): class AnswerQueryOutput(TypedDict): - answer_results: list[SearchAnswerResults] + answer_result: SearchAnswerResults diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index e6c3e8945e0..238b294a79f 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -16,12 +16,6 @@ class ExpandedRetrievalResult(BaseModel): ] -# class RetrievalResult(BaseModel): -# starting_query: str -# expanded_retrieval_results: list[ExpandedRetrievalResult] -# documents: Annotated[list[InferenceSection], dedup_inference_sections] - - class DocRetrievalOutput(TypedDict, total=False): expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index caaf9fe412e..7d58dbdacdc 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -12,10 +12,10 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha "answer_query", AnswerQueryInput( **state, - question=query, + question=question, ), ) - for query in state["initial_decomp_queries"] + for question in state["initial_decomp_questions"] ] diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py index 28e93c6cbcc..e8af64a9fba 100644 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -27,5 +27,5 @@ def main_decomp_base(state: MainState) -> 
BaseDecompOutput: ] return BaseDecompOutput( - initial_decomp_queries=decomp_list, + initial_decomp_questions=decomp_list, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 3b753ff8476..1c3fcc41d17 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -9,7 +9,7 @@ class BaseDecompOutput(TypedDict, total=False): - initial_decomp_queries: list[str] + initial_decomp_questions: list[str] class InitialAnswerOutput(TypedDict, total=False): From e76cbec53cd78fac6c5719e1a82690de5516eaed Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 08:43:54 -0800 Subject: [PATCH 09/19] main graph works --- .../answer_query/nodes/format_answer.py | 16 +++++++++------- backend/onyx/agent_search/answer_query/states.py | 10 +++++++++- backend/onyx/agent_search/main/graph_builder.py | 11 ++++++++++- .../agent_search/main/nodes/ingest_answers.py | 15 +++++++++++++++ backend/onyx/agent_search/main/states.py | 16 ++++++++++------ 5 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 backend/onyx/agent_search/main/nodes/ingest_answers.py diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_query/nodes/format_answer.py index 2bf618c571f..4220c2cc1ee 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_query/nodes/format_answer.py @@ -5,11 +5,13 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: return AnswerQueryOutput( - answer_result=SearchAnswerResults( - question=state["question"], - quality=state["answer_quality"], - answer=state["answer"], - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["documents"], - ), + answer_results=[ + SearchAnswerResults( + question=state["question"], + quality=state["answer_quality"], + answer=state["answer"], + 
expanded_retrieval_results=state["expanded_retrieval_results"], + documents=state["documents"], + ) + ], ) diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_query/states.py index 7e6cf160a99..f622db82150 100644 --- a/backend/onyx/agent_search/answer_query/states.py +++ b/backend/onyx/agent_search/answer_query/states.py @@ -1,3 +1,5 @@ +from operator import add +from typing import Annotated from typing import TypedDict from pydantic import BaseModel @@ -39,4 +41,10 @@ class AnswerQueryInput(PrimaryState, total=True): class AnswerQueryOutput(TypedDict): - answer_result: SearchAnswerResults + """ + This is a list of results even though each call of this subgraph only returns one result. + This is because if we parallelize the answer query subgraph, there will be multiple + results in a list so the add operator is used to add them together. + """ + + answer_results: Annotated[list[SearchAnswerResults], add] diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 930d7a745f7..dbe02194a12 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -11,6 +11,7 @@ from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, ) +from onyx.agent_search.main.nodes.ingest_answers import ingest_answers from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -41,6 +42,10 @@ def main_graph_builder() -> StateGraph: node="generate_initial_answer", action=generate_initial_answer, ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) ### Add edges ### graph.add_edge( @@ -59,6 +64,10 @@ def main_graph_builder() -> StateGraph: ) graph.add_edge( start_key=["answer_query", "expanded_retrieval"], + end_key="ingest_answers", + ) + graph.add_edge( + start_key="ingest_answers", end_key="generate_initial_answer", ) graph.add_edge( 
@@ -78,7 +87,7 @@ def main_graph_builder() -> StateGraph: compiled_graph = graph.compile() primary_llm, fast_llm = get_default_llms() search_request = SearchRequest( - query="If i am familiar with the function that I need, how can I type it into a cell?", + query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: inputs = MainInput( diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py new file mode 100644 index 00000000000..f761a85b1ac --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -0,0 +1,15 @@ +from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.main.states import DecompAnswersOutput +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections + + +def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersOutput: + documents = [] + for answer_result in state["answer_results"]: + documents.extend(answer_result.documents) + return DecompAnswersOutput( + # Deduping is done by the documents operator for the main graph + # so we might not need to dedup here + documents=dedup_inference_sections(documents, []), + decomp_answer_results=state["answer_results"].answer_results, + ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 1c3fcc41d17..c28220a9677 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -8,25 +8,29 @@ from onyx.context.search.models import InferenceSection -class BaseDecompOutput(TypedDict, total=False): +class BaseDecompOutput(TypedDict): initial_decomp_questions: list[str] -class InitialAnswerOutput(TypedDict, total=False): +class InitialAnswerOutput(TypedDict): initial_answer: str +class DecompAnswersOutput(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + decomp_answer_results: 
Annotated[list[SearchAnswerResults], add] + + class MainState( PrimaryState, BaseDecompOutput, InitialAnswerOutput, - total=True, + DecompAnswersOutput, ): - documents: Annotated[list[InferenceSection], dedup_inference_sections] - decomp_answer_results: Annotated[list[SearchAnswerResults], add] + pass -class MainInput(PrimaryState, total=True): +class MainInput(PrimaryState): pass From fd694bea8faab8b2693908f483cbdec2a8eb546f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 08:47:43 -0800 Subject: [PATCH 10/19] query->question --- .../{answer_query => answer_question}/edges.py | 2 +- .../graph_builder.py | 14 +++++++------- .../nodes/answer_check.py | 4 ++-- .../nodes/answer_generation.py | 4 ++-- .../nodes/format_answer.py | 6 +++--- .../{answer_query => answer_question}/states.py | 0 backend/onyx/agent_search/main/edges.py | 2 +- backend/onyx/agent_search/main/graph_builder.py | 2 +- .../onyx/agent_search/main/nodes/ingest_answers.py | 2 +- backend/onyx/agent_search/main/states.py | 2 +- 10 files changed, 19 insertions(+), 19 deletions(-) rename backend/onyx/agent_search/{answer_query => answer_question}/edges.py (85%) rename backend/onyx/agent_search/{answer_query => answer_question}/graph_builder.py (81%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/answer_check.py (83%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/answer_generation.py (85%) rename backend/onyx/agent_search/{answer_query => answer_question}/nodes/format_answer.py (67%) rename backend/onyx/agent_search/{answer_query => answer_question}/states.py (100%) diff --git a/backend/onyx/agent_search/answer_query/edges.py b/backend/onyx/agent_search/answer_question/edges.py similarity index 85% rename from backend/onyx/agent_search/answer_query/edges.py rename to backend/onyx/agent_search/answer_question/edges.py index c538ef8958b..45c24137d3d 100644 --- a/backend/onyx/agent_search/answer_query/edges.py +++ 
b/backend/onyx/agent_search/answer_question/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput diff --git a/backend/onyx/agent_search/answer_query/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py similarity index 81% rename from backend/onyx/agent_search/answer_query/graph_builder.py rename to backend/onyx/agent_search/answer_question/graph_builder.py index 1036f425b76..75563e6abf7 100644 --- a/backend/onyx/agent_search/answer_query/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -2,13 +2,13 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_query.edges import send_to_expanded_retrieval -from onyx.agent_search.answer_query.nodes.answer_check import answer_check -from onyx.agent_search.answer_query.nodes.answer_generation import answer_generation -from onyx.agent_search.answer_query.nodes.format_answer import format_answer -from onyx.agent_search.answer_query.states import AnswerQueryInput -from onyx.agent_search.answer_query.states import AnswerQueryOutput -from onyx.agent_search.answer_query.states import AnswerQueryState +from onyx.agent_search.answer_question.edges import send_to_expanded_retrieval +from onyx.agent_search.answer_question.nodes.answer_check import answer_check +from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation +from onyx.agent_search.answer_question.nodes.format_answer import format_answer +from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryState from onyx.agent_search.expanded_retrieval.graph_builder import ( 
expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py similarity index 83% rename from backend/onyx/agent_search/answer_query/nodes/answer_check.py rename to backend/onyx/agent_search/answer_question/nodes/answer_check.py index c035f309feb..008001b6201 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -1,8 +1,8 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import QACheckOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import QACheckOutput from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT diff --git a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py similarity index 85% rename from backend/onyx/agent_search/answer_query/nodes/answer_generation.py rename to backend/onyx/agent_search/answer_question/nodes/answer_generation.py index f036267c3b7..f01f5baeac7 100644 --- a/backend/onyx/agent_search/answer_query/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,8 +1,8 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import QAGenerationOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import QAGenerationOutput from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from 
onyx.agent_search.shared_graph_utils.utils import format_docs diff --git a/backend/onyx/agent_search/answer_query/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py similarity index 67% rename from backend/onyx/agent_search/answer_query/nodes/format_answer.py rename to backend/onyx/agent_search/answer_question/nodes/format_answer.py index 4220c2cc1ee..fff4f940dba 100644 --- a/backend/onyx/agent_search/answer_query/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,6 +1,6 @@ -from onyx.agent_search.answer_query.states import AnswerQueryOutput -from onyx.agent_search.answer_query.states import AnswerQueryState -from onyx.agent_search.answer_query.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import SearchAnswerResults def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: diff --git a/backend/onyx/agent_search/answer_query/states.py b/backend/onyx/agent_search/answer_question/states.py similarity index 100% rename from backend/onyx/agent_search/answer_query/states.py rename to backend/onyx/agent_search/answer_question/states.py diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 7d58dbdacdc..4f2468b8c4a 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_query.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQueryInput from onyx.agent_search.main.states import MainState diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index dbe02194a12..d91aecc5988 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ 
b/backend/onyx/agent_search/main/graph_builder.py @@ -2,7 +2,7 @@ from langgraph.graph import START from langgraph.graph import StateGraph -from onyx.agent_search.answer_query.graph_builder import answer_query_graph_builder +from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index f761a85b1ac..8a59afdbafd 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,4 +1,4 @@ -from onyx.agent_search.answer_query.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQueryOutput from onyx.agent_search.main.states import DecompAnswersOutput from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index c28220a9677..679230cd326 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,7 +2,7 @@ from typing import Annotated from typing import TypedDict -from onyx.agent_search.answer_query.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import SearchAnswerResults from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection From 8399d2ee0aab2bd65a7fd40ab47cda3ea18717bf Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 09:27:47 -0800 Subject: [PATCH 11/19] mypy fixed --- .../agent_search/answer_question/edges.py | 3 +- backend/onyx/agent_search/core_state.py | 12 ++++++++ .../expanded_retrieval/graph_builder.py | 2 +- backend/onyx/agent_search/main/edges.py 
| 3 +- .../onyx/agent_search/main/graph_builder.py | 28 +++++++++++++------ .../agent_search/main/nodes/base_decomp.py | 6 ++-- .../main/nodes/generate_initial_answer.py | 10 ++++--- .../agent_search/main/nodes/ingest_answers.py | 8 +++--- .../main/nodes/ingest_initial_retrieval.py | 9 ++++++ backend/onyx/agent_search/main/states.py | 21 ++++++++++---- 10 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 45c24137d3d..05de5899002 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -3,6 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput @@ -10,7 +11,7 @@ def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( - **state, + **extract_primary_fields(state), starting_query=state["question"], ), ) diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index fcd8bddf3ec..ee490e0a337 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -1,4 +1,5 @@ from typing import TypedDict +from typing import TypeVar from sqlalchemy.orm import Session @@ -13,3 +14,14 @@ class PrimaryState(TypedDict, total=False): # a single session for the entire agent search # is fine if we are only reading db_session: Session + + +# This ensures that the state passed in extends the PrimaryState +T = TypeVar("T", bound=PrimaryState) + + +def extract_primary_fields(state: T) -> PrimaryState: + filtered_dict = { + k: v for k, v in state.items() if k in 
PrimaryState.__annotations__ + } + return PrimaryState(**dict(filtered_dict)) # type: ignore diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index c2bfd1e346c..a1606781968 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -91,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - question="what can you do with onyx?", + starting_query="what can you do with onyx?", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 4f2468b8c4a..c0730c1537b 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,6 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.main.states import MainState @@ -11,7 +12,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha Send( "answer_query", AnswerQueryInput( - **state, + **extract_primary_fields(state), question=question, ), ) diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index d91aecc5988..971398f9c90 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -12,6 +12,9 @@ generate_initial_answer, ) from onyx.agent_search.main.nodes.ingest_answers import ingest_answers +from onyx.agent_search.main.nodes.ingest_initial_retrieval import ( + ingest_initial_retrieval, +) from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -35,22 +38,30 @@ def main_graph_builder() -> StateGraph: ) 
expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile() graph.add_node( - node="expanded_retrieval", + node="initial_retrieval", action=expanded_retrieval_subgraph, ) - graph.add_node( - node="generate_initial_answer", - action=generate_initial_answer, - ) graph.add_node( node="ingest_answers", action=ingest_answers, ) + graph.add_node( + node="ingest_initial_retrieval", + action=ingest_initial_retrieval, + ) + graph.add_node( + node="generate_initial_answer", + action=generate_initial_answer, + ) ### Add edges ### graph.add_edge( start_key=START, - end_key="expanded_retrieval", + end_key="initial_retrieval", + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", ) graph.add_edge( @@ -63,11 +74,12 @@ def main_graph_builder() -> StateGraph: path_map=["answer_query"], ) graph.add_edge( - start_key=["answer_query", "expanded_retrieval"], + start_key="answer_query", end_key="ingest_answers", ) + graph.add_edge( - start_key="ingest_answers", + start_key=["ingest_answers", "ingest_initial_retrieval"], end_key="generate_initial_answer", ) graph.add_edge( diff --git a/backend/onyx/agent_search/main/nodes/base_decomp.py b/backend/onyx/agent_search/main/nodes/base_decomp.py index e8af64a9fba..05b095794be 100644 --- a/backend/onyx/agent_search/main/nodes/base_decomp.py +++ b/backend/onyx/agent_search/main/nodes/base_decomp.py @@ -1,12 +1,12 @@ from langchain_core.messages import HumanMessage -from onyx.agent_search.main.states import BaseDecompOutput +from onyx.agent_search.main.states import BaseDecompUpdate from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string -def main_decomp_base(state: MainState) -> BaseDecompOutput: +def main_decomp_base(state: MainState) -> BaseDecompUpdate: question = state["search_request"].query msg = [ @@ -26,6 +26,6 @@ def 
main_decomp_base(state: MainState) -> BaseDecompOutput: sub_question["sub_question"].strip() for sub_question in list_of_subquestions ] - return BaseDecompOutput( + return BaseDecompUpdate( initial_decomp_questions=decomp_list, ) diff --git a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py index a6476477ae4..828472d6eae 100644 --- a/backend/onyx/agent_search/main/nodes/generate_initial_answer.py +++ b/backend/onyx/agent_search/main/nodes/generate_initial_answer.py @@ -1,16 +1,18 @@ from langchain_core.messages import HumanMessage -from onyx.agent_search.main.states import InitialAnswerOutput +from onyx.agent_search.main.states import InitialAnswerUpdate from onyx.agent_search.main.states import MainState from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs -def generate_initial_answer(state: MainState) -> InitialAnswerOutput: +def generate_initial_answer(state: MainState) -> InitialAnswerUpdate: print("---GENERATE INITIAL---") question = state["search_request"].query docs = state["documents"] + all_original_question_documents = state["all_original_question_documents"] + combined_docs = docs + all_original_question_documents decomp_answer_results = state["decomp_answer_results"] @@ -38,7 +40,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: HumanMessage( content=INITIAL_RAG_PROMPT.format( question=question, - context=format_docs(docs), + context=format_docs(combined_docs), answered_sub_questions=sub_question_answer_str, ) ) @@ -50,4 +52,4 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput: answer = response.pretty_repr() print(answer) - return InitialAnswerOutput(initial_answer=answer) + return InitialAnswerUpdate(initial_answer=answer) diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py 
b/backend/onyx/agent_search/main/nodes/ingest_answers.py index 8a59afdbafd..2662951cce9 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,15 +1,15 @@ from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.main.states import DecompAnswersOutput +from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersOutput: +def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersUpdate: documents = [] for answer_result in state["answer_results"]: documents.extend(answer_result.documents) - return DecompAnswersOutput( + return DecompAnswersUpdate( # Deduping is done by the documents operator for the main graph # so we might not need to dedup here documents=dedup_inference_sections(documents, []), - decomp_answer_results=state["answer_results"].answer_results, + decomp_answer_results=state["answer_results"], ) diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py new file mode 100644 index 00000000000..e3a96e0b8e4 --- /dev/null +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -0,0 +1,9 @@ +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.main.states import ExpandedRetrievalUpdate + + +def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: + return ExpandedRetrievalUpdate( + all_original_question_documents=state["documents"], + original_question_retrieval_results=state["expanded_retrieval_results"], + ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 679230cd326..7e2c14d2bcd 100644 --- a/backend/onyx/agent_search/main/states.py +++ 
b/backend/onyx/agent_search/main/states.py @@ -4,28 +4,37 @@ from onyx.agent_search.answer_question.states import SearchAnswerResults from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection -class BaseDecompOutput(TypedDict): +class BaseDecompUpdate(TypedDict): initial_decomp_questions: list[str] -class InitialAnswerOutput(TypedDict): +class InitialAnswerUpdate(TypedDict): initial_answer: str -class DecompAnswersOutput(TypedDict): +class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] decomp_answer_results: Annotated[list[SearchAnswerResults], add] +class ExpandedRetrievalUpdate(TypedDict): + all_original_question_documents: Annotated[ + list[InferenceSection], dedup_inference_sections + ] + original_question_retrieval_results: list[ExpandedRetrievalResult] + + class MainState( PrimaryState, - BaseDecompOutput, - InitialAnswerOutput, - DecompAnswersOutput, + BaseDecompUpdate, + InitialAnswerUpdate, + DecompAnswersUpdate, + ExpandedRetrievalUpdate, ): pass From 50a216f554459d9c67c267b908d92018e1e66742 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 09:56:34 -0800 Subject: [PATCH 12/19] naming and comments --- .../agent_search/answer_question/edges.py | 4 +- .../answer_question/graph_builder.py | 14 +++--- .../answer_question/nodes/answer_check.py | 8 +-- .../nodes/answer_generation.py | 8 +-- .../answer_question/nodes/format_answer.py | 8 +-- .../agent_search/answer_question/states.py | 31 +++++++++--- .../agent_search/expanded_retrieval/edges.py | 8 ++- .../expanded_retrieval/nodes/doc_reranking.py | 6 +-- .../expanded_retrieval/nodes/doc_retrieval.py | 23 ++++----- .../nodes/doc_verification.py | 6 +-- .../agent_search/expanded_retrieval/states.py | 49 +++++++++++++------ 
backend/onyx/agent_search/main/edges.py | 4 +- .../agent_search/main/nodes/ingest_answers.py | 4 +- backend/onyx/agent_search/main/states.py | 13 +++++ 14 files changed, 116 insertions(+), 70 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 05de5899002..ec32f1c8523 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -2,12 +2,12 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -def send_to_expanded_retrieval(state: AnswerQueryInput) -> Send | Hashable: +def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 75563e6abf7..291d9deb6aa 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -6,9 +6,9 @@ from onyx.agent_search.answer_question.nodes.answer_check import answer_check from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation from onyx.agent_search.answer_question.nodes.format_answer import format_answer -from onyx.agent_search.answer_question.states import AnswerQueryInput -from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import AnswerQuestionInput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from 
onyx.agent_search.answer_question.states import AnswerQuestionState from onyx.agent_search.expanded_retrieval.graph_builder import ( expanded_retrieval_graph_builder, ) @@ -16,9 +16,9 @@ def answer_query_graph_builder() -> StateGraph: graph = StateGraph( - state_schema=AnswerQueryState, - input=AnswerQueryInput, - output=AnswerQueryOutput, + state_schema=AnswerQuestionState, + input=AnswerQuestionInput, + output=AnswerQuestionOutput, ) ### Add nodes ### @@ -80,7 +80,7 @@ def answer_query_graph_builder() -> StateGraph: query="what can you do with onyx or danswer?", ) with get_session_context_manager() as db_session: - inputs = AnswerQueryInput( + inputs = AnswerQuestionInput( search_request=search_request, primary_llm=primary_llm, fast_llm=fast_llm, diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py index 008001b6201..b04953dd0b9 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -1,12 +1,12 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQueryState -from onyx.agent_search.answer_question.states import QACheckOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QACheckUpdate from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT -def answer_check(state: AnswerQueryState) -> QACheckOutput: +def answer_check(state: AnswerQuestionState) -> QACheckUpdate: msg = [ HumanMessage( content=SUB_CHECK_PROMPT.format( @@ -25,6 +25,6 @@ def answer_check(state: AnswerQueryState) -> QACheckOutput: response_str = merge_message_runs(response, chunk_separator="")[0].content - return QACheckOutput( + return QACheckUpdate( answer_quality=response_str, ) diff --git 
a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py index f01f5baeac7..d47d1aaf77b 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_generation.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_generation.py @@ -1,13 +1,13 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.answer_question.states import AnswerQueryState -from onyx.agent_search.answer_question.states import QAGenerationOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState +from onyx.agent_search.answer_question.states import QAGenerationUpdate from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT from onyx.agent_search.shared_graph_utils.utils import format_docs -def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: +def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate: question = state["question"] docs = state["documents"] @@ -27,6 +27,6 @@ def answer_generation(state: AnswerQueryState) -> QAGenerationOutput: ) answer_str = merge_message_runs(response, chunk_separator="")[0].content - return QAGenerationOutput( + return QAGenerationUpdate( answer=answer_str, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index fff4f940dba..216100a94cc 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,10 +1,10 @@ -from onyx.agent_search.answer_question.states import AnswerQueryOutput -from onyx.agent_search.answer_question.states import AnswerQueryState +from onyx.agent_search.answer_question.states import AnswerQuestionOutput +from onyx.agent_search.answer_question.states import AnswerQuestionState from 
onyx.agent_search.answer_question.states import SearchAnswerResults -def format_answer(state: AnswerQueryState) -> AnswerQueryOutput: - return AnswerQueryOutput( +def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: + return AnswerQuestionOutput( answer_results=[ SearchAnswerResults( question=state["question"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index f622db82150..06cbe3ba839 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -10,6 +10,9 @@ from onyx.context.search.models import InferenceSection +### Models ### + + class SearchAnswerResults(BaseModel): question: str answer: str @@ -18,29 +21,43 @@ class SearchAnswerResults(BaseModel): documents: list[InferenceSection] -class QACheckOutput(TypedDict, total=False): +### States ### + +## Update States + + +class QACheckUpdate(TypedDict): answer_quality: str -class QAGenerationOutput(TypedDict, total=False): +class QAGenerationUpdate(TypedDict): answer: str -class AnswerQueryState( +## Graph State + + +class AnswerQuestionState( PrimaryState, ExpandedRetrievalOutput, - QAGenerationOutput, - QACheckOutput, + QAGenerationUpdate, + QACheckUpdate, total=True, ): question: str -class AnswerQueryInput(PrimaryState, total=True): +## Input State + + +class AnswerQuestionInput(PrimaryState): question: str -class AnswerQueryOutput(TypedDict): +## Graph Output State + + +class AnswerQuestionOutput(TypedDict): """ This is a list of results even though each call of this subgraph only returns one result. 
This is because if we parallelize the answer query subgraph, there will be multiple diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 19a321bd727..1c62ba7dd91 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -4,7 +4,8 @@ from langchain_core.messages import merge_message_runs from langgraph.types import Send -from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput +from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL from onyx.llm.interfaces import LLM @@ -38,7 +39,10 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab return [ Send( "doc_retrieval", - RetrieveInput(query_to_retrieve=query, **state), + RetrievalInput( + query_to_retrieve=query, + **extract_primary_fields(state), + ), ) for query in rewritten_queries ] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 1ac36203518..925b7c7f444 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -1,11 +1,11 @@ -from onyx.agent_search.expanded_retrieval.states import DocRerankingOutput +from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput: +def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: print(f"doc_reranking state: {state.keys()}") 
verified_documents = state["verified_documents"] reranked_documents = verified_documents - return DocRerankingOutput(reranked_documents=reranked_documents) + return DocRerankingUpdate(reranked_documents=reranked_documents) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index 118aaa776c6..a141bfcaac3 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,34 +1,29 @@ -from onyx.agent_search.expanded_retrieval.states import DocRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState +from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest from onyx.context.search.pipeline import SearchPipeline -class RetrieveInput(ExpandedRetrievalState): - query_to_retrieve: str - - -def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: +def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: """ Retrieve documents Args: - state (dict): The current graph state + state (RetrievalInput): Primary state + the query to retrieve - Returns: - state (dict): New key added to state, documents, that contains retrieved documents + Updates: + expanded_retrieval_results: list[ExpandedRetrievalResult] + retrieved_documents: list[InferenceSection] """ - print(f"doc_retrieval state: {state.keys()}") - documents: list[InferenceSection] = [] llm = state["primary_llm"] fast_llm = state["fast_llm"] query_to_retrieve = state["query_to_retrieve"] - documents = SearchPipeline( + documents: 
list[InferenceSection] = SearchPipeline( search_request=SearchRequest( query=query_to_retrieve, ), @@ -43,7 +38,7 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput: expanded_query=query_to_retrieve, expanded_retrieval_documents=documents[:4], ) - return DocRetrievalOutput( + return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], retrieved_documents=documents[:4], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index f3f993e87b7..741c445e2c8 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,7 +1,7 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs -from onyx.agent_search.expanded_retrieval.states import DocVerificationOutput +from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.models import BinaryDecision from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT @@ -12,7 +12,7 @@ class DocVerificationInput(ExpandedRetrievalState, total=True): doc_to_verify: InferenceSection -def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: +def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: """ Check whether the document is relevant for the original user question @@ -55,6 +55,6 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput: if formatted_response.decision == "yes": verified_documents.append(doc_to_verify) - return DocVerificationOutput( + return DocVerificationUpdate( verified_documents=verified_documents, ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py 
b/backend/onyx/agent_search/expanded_retrieval/states.py index 238b294a79f..68c0b5889ba 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -9,40 +9,57 @@ from onyx.context.search.models import InferenceSection +### Models ### + + class ExpandedRetrievalResult(BaseModel): expanded_query: str - expanded_retrieval_documents: Annotated[ - list[InferenceSection], dedup_inference_sections - ] + expanded_retrieval_documents: list[InferenceSection] -class DocRetrievalOutput(TypedDict, total=False): - expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] - retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] +### States ### +## Update States -class DocVerificationOutput(TypedDict, total=False): +class DocVerificationUpdate(TypedDict): verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] -class DocRerankingOutput(TypedDict, total=False): +class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] -class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], dedup_inference_sections] +class DocRetrievalUpdate(TypedDict): + expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +## Graph State class ExpandedRetrievalState( PrimaryState, - DocRetrievalOutput, - DocVerificationOutput, - DocRerankingOutput, - total=True, + DocRetrievalUpdate, + DocVerificationUpdate, + DocRerankingUpdate, ): starting_query: str -class ExpandedRetrievalInput(PrimaryState, total=True): +## Graph Output State + + +class ExpandedRetrievalOutput(TypedDict): + expanded_retrieval_results: list[ExpandedRetrievalResult] + documents: Annotated[list[InferenceSection], dedup_inference_sections] + + +## Input 
States + + +class ExpandedRetrievalInput(PrimaryState): starting_query: str + + +class RetrievalInput(PrimaryState): + query_to_retrieve: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index c0730c1537b..484c0c354a7 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.answer_question.states import AnswerQueryInput +from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.main.states import MainState @@ -11,7 +11,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha return [ Send( "answer_query", - AnswerQueryInput( + AnswerQuestionInput( **extract_primary_fields(state), question=question, ), diff --git a/backend/onyx/agent_search/main/nodes/ingest_answers.py b/backend/onyx/agent_search/main/nodes/ingest_answers.py index 2662951cce9..c86f3f3104e 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_answers.py +++ b/backend/onyx/agent_search/main/nodes/ingest_answers.py @@ -1,9 +1,9 @@ -from onyx.agent_search.answer_question.states import AnswerQueryOutput +from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.main.states import DecompAnswersUpdate from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections -def ingest_answers(state: AnswerQueryOutput) -> DecompAnswersUpdate: +def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate: documents = [] for answer_result in state["answer_results"]: documents.extend(answer_result.documents) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 7e2c14d2bcd..afed369e593 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -8,6 +8,10 @@ from 
onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection +### States ### + +## Update States + class BaseDecompUpdate(TypedDict): initial_decomp_questions: list[str] @@ -29,6 +33,9 @@ class ExpandedRetrievalUpdate(TypedDict): original_question_retrieval_results: list[ExpandedRetrievalResult] +## Graph State + + class MainState( PrimaryState, BaseDecompUpdate, @@ -39,10 +46,16 @@ class MainState( pass +## Input States + + class MainInput(PrimaryState): pass +## Graph Output State + + class MainOutput(TypedDict): """ This is not used because defining the output only matters for filtering the output of From 9d3220fcfc3355531c51a7b92477357995c951f3 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 10:17:07 -0800 Subject: [PATCH 13/19] explicitly ingest state from retrieval --- .../onyx/agent_search/answer_question/graph_builder.py | 9 +++++++++ .../answer_question/nodes/ingest_retrieval.py | 9 +++++++++ backend/onyx/agent_search/answer_question/states.py | 9 +++++++-- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py diff --git a/backend/onyx/agent_search/answer_question/graph_builder.py b/backend/onyx/agent_search/answer_question/graph_builder.py index 291d9deb6aa..0aebb045de0 100644 --- a/backend/onyx/agent_search/answer_question/graph_builder.py +++ b/backend/onyx/agent_search/answer_question/graph_builder.py @@ -6,6 +6,7 @@ from onyx.agent_search.answer_question.nodes.answer_check import answer_check from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation from onyx.agent_search.answer_question.nodes.format_answer import format_answer +from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.answer_question.states import 
AnswerQuestionOutput from onyx.agent_search.answer_question.states import AnswerQuestionState @@ -40,6 +41,10 @@ def answer_query_graph_builder() -> StateGraph: node="format_answer", action=format_answer, ) + graph.add_node( + node="ingest_retrieval", + action=ingest_retrieval, + ) ### Add edges ### @@ -50,6 +55,10 @@ def answer_query_graph_builder() -> StateGraph: ) graph.add_edge( start_key="decomped_expanded_retrieval", + end_key="ingest_retrieval", + ) + graph.add_edge( + start_key="ingest_retrieval", end_key="answer_generation", ) graph.add_edge( diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py new file mode 100644 index 00000000000..7ee1ae75efb --- /dev/null +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -0,0 +1,9 @@ +from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput + + +def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: + return RetrievalIngestionUpdate( + documents=state["documents"], + expanded_retrieval_results=state["expanded_retrieval_results"], + ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 06cbe3ba839..a0a4295da00 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -34,14 +34,19 @@ class QAGenerationUpdate(TypedDict): answer: str 
+class RetrievalIngestionUpdate(TypedDict): + documents: Annotated[list[InferenceSection], dedup_inference_sections] + expanded_retrieval_results: list[ExpandedRetrievalResult] + + ## Graph State class AnswerQuestionState( PrimaryState, - ExpandedRetrievalOutput, QAGenerationUpdate, QACheckUpdate, + RetrievalIngestionUpdate, total=True, ): question: str From 0c75ca05799a77c179cf7398744c5610c97dcb7e Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 11:08:43 -0800 Subject: [PATCH 14/19] renames --- .../onyx/agent_search/answer_question/nodes/answer_check.py | 4 ++-- .../onyx/agent_search/answer_question/nodes/format_answer.py | 4 ++-- backend/onyx/agent_search/answer_question/states.py | 4 ++-- backend/onyx/agent_search/expanded_retrieval/states.py | 4 ++-- backend/onyx/agent_search/main/states.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/nodes/answer_check.py b/backend/onyx/agent_search/answer_question/nodes/answer_check.py index b04953dd0b9..83cc46280f7 100644 --- a/backend/onyx/agent_search/answer_question/nodes/answer_check.py +++ b/backend/onyx/agent_search/answer_question/nodes/answer_check.py @@ -23,8 +23,8 @@ def answer_check(state: AnswerQuestionState) -> QACheckUpdate: ) ) - response_str = merge_message_runs(response, chunk_separator="")[0].content + quality_str = merge_message_runs(response, chunk_separator="")[0].content return QACheckUpdate( - answer_quality=response_str, + answer_quality=quality_str, ) diff --git a/backend/onyx/agent_search/answer_question/nodes/format_answer.py b/backend/onyx/agent_search/answer_question/nodes/format_answer.py index 216100a94cc..c7897294726 100644 --- a/backend/onyx/agent_search/answer_question/nodes/format_answer.py +++ b/backend/onyx/agent_search/answer_question/nodes/format_answer.py @@ -1,12 +1,12 @@ from onyx.agent_search.answer_question.states import AnswerQuestionOutput from onyx.agent_search.answer_question.states import 
AnswerQuestionState -from onyx.agent_search.answer_question.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import QuestionAnswerResults def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: return AnswerQuestionOutput( answer_results=[ - SearchAnswerResults( + QuestionAnswerResults( question=state["question"], quality=state["answer_quality"], answer=state["answer"], diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index a0a4295da00..e216ac20499 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -13,7 +13,7 @@ ### Models ### -class SearchAnswerResults(BaseModel): +class QuestionAnswerResults(BaseModel): question: str answer: str quality: str @@ -69,4 +69,4 @@ class AnswerQuestionOutput(TypedDict): results in a list so the add operator is used to add them together. """ - answer_results: Annotated[list[SearchAnswerResults], add] + answer_results: Annotated[list[QuestionAnswerResults], add] diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 68c0b5889ba..81b96d95f10 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -43,7 +43,7 @@ class ExpandedRetrievalState( DocVerificationUpdate, DocRerankingUpdate, ): - starting_query: str + question: str ## Graph Output State @@ -58,7 +58,7 @@ class ExpandedRetrievalOutput(TypedDict): class ExpandedRetrievalInput(PrimaryState): - starting_query: str + question: str class RetrievalInput(PrimaryState): diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index afed369e593..a6c296109f2 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -2,7 +2,7 @@ from typing import Annotated from typing 
import TypedDict -from onyx.agent_search.answer_question.states import SearchAnswerResults +from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import PrimaryState from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections @@ -23,7 +23,7 @@ class InitialAnswerUpdate(TypedDict): class DecompAnswersUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] - decomp_answer_results: Annotated[list[SearchAnswerResults], add] + decomp_answer_results: Annotated[list[QuestionAnswerResults], add] class ExpandedRetrievalUpdate(TypedDict): From bca02ebec6a0c24d97d39a509c530865a0998433 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 12:44:28 -0800 Subject: [PATCH 15/19] figured it out --- .../onyx/agent_search/answer_question/edges.py | 2 +- .../answer_question/nodes/ingest_retrieval.py | 6 ++++-- .../agent_search/answer_question/states.py | 7 +++---- .../expanded_retrieval/graph_builder.py | 2 +- .../expanded_retrieval/nodes/doc_reranking.py | 2 -- .../expanded_retrieval/nodes/doc_retrieval.py | 10 ++++------ .../nodes/doc_verification.py | 17 ++++------------- .../expanded_retrieval/nodes/format_results.py | 7 +++++-- .../nodes/verification_kickoff.py | 6 +++++- .../agent_search/expanded_retrieval/states.py | 18 +++++++++++++----- .../onyx/agent_search/main/graph_builder.py | 8 ++++---- .../main/nodes/ingest_initial_retrieval.py | 8 ++++++-- backend/onyx/agent_search/main/states.py | 4 ++-- 13 files changed, 52 insertions(+), 45 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index ec32f1c8523..261821f8cc0 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -12,6 +12,6 @@ def send_to_expanded_retrieval(state: 
AnswerQuestionInput) -> Send | Hashable: "decomped_expanded_retrieval", ExpandedRetrievalInput( **extract_primary_fields(state), - starting_query=state["question"], + question=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py index 7ee1ae75efb..f20ec7d86d5 100644 --- a/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py +++ b/backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py @@ -4,6 +4,8 @@ def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate: return RetrievalIngestionUpdate( - documents=state["documents"], - expanded_retrieval_results=state["expanded_retrieval_results"], + expanded_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + documents=state["expanded_retrieval_result"].all_documents, ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index e216ac20499..898a035b7b3 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -17,7 +17,7 @@ class QuestionAnswerResults(BaseModel): question: str answer: str quality: str - expanded_retrieval_results: list[ExpandedRetrievalResult] + expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] @@ -35,8 +35,8 @@ class QAGenerationUpdate(TypedDict): class RetrievalIngestionUpdate(TypedDict): + expanded_retrieval_results: list[QueryResult] documents: 
Annotated[list[InferenceSection], dedup_inference_sections] - expanded_retrieval_results: list[ExpandedRetrievalResult] ## Graph State @@ -47,7 +47,6 @@ class AnswerQuestionState( QAGenerationUpdate, QACheckUpdate, RetrievalIngestionUpdate, - total=True, ): question: str diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index a1606781968..c2bfd1e346c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -91,7 +91,7 @@ def expanded_retrieval_graph_builder() -> StateGraph: primary_llm=primary_llm, fast_llm=fast_llm, db_session=db_session, - starting_query="what can you do with onyx?", + question="what can you do with onyx?", ) for thing in compiled_graph.stream( input=inputs, diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py index 925b7c7f444..6f8e3df0634 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py @@ -3,8 +3,6 @@ def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate: - print(f"doc_reranking state: {state.keys()}") - verified_documents = state["verified_documents"] reranked_documents = verified_documents diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index a141bfcaac3..c0b60ef38d3 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -1,5 +1,5 @@ from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import 
QueryResult from onyx.agent_search.expanded_retrieval.states import RetrievalInput from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchRequest @@ -7,7 +7,6 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: - # def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]: """ Retrieve documents @@ -33,10 +32,9 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: db_session=state["db_session"], ).reranked_sections - print(f"retrieved documents: {len(documents)}") - expanded_retrieval_result = ExpandedRetrievalResult( - expanded_query=query_to_retrieve, - expanded_retrieval_documents=documents[:4], + expanded_retrieval_result = QueryResult( + query=query_to_retrieve, + documents_for_query=documents[:4], ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py index 741c445e2c8..3abebfcf2e1 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py @@ -1,15 +1,10 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import merge_message_runs +from onyx.agent_search.expanded_retrieval.states import DocVerificationInput from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState from onyx.agent_search.shared_graph_utils.models import BinaryDecision from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT -from onyx.context.search.models import InferenceSection - - -class DocVerificationInput(ExpandedRetrievalState, total=True): - doc_to_verify: InferenceSection def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: @@ -17,14 +12,12 
@@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: Check whether the document is relevant for the original user question Args: - state (VerifierState): The current state + state (DocVerificationInput): The current state - Returns: - dict: ict: The updated state with the final decision + Updates: + verified_documents: list[InferenceSection] """ - print(f"doc_verification state: {state.keys()}") - original_query = state["search_request"].query doc_to_verify = state["doc_to_verify"] document_content = doc_to_verify.combined_content @@ -49,8 +42,6 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate: decision_dict = {"decision": response_string.lower()} formatted_response = BinaryDecision.model_validate(decision_dict) - print(f"Verdict: {formatted_response.decision}") - verified_documents = [] if formatted_response.decision == "yes": verified_documents.append(doc_to_verify) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py index 2a9620a0a9a..50da6e9a640 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py @@ -1,9 +1,12 @@ from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput: return ExpandedRetrievalOutput( - expanded_retrieval_results=state["expanded_retrieval_results"], - documents=state["reranked_documents"], + expanded_retrieval_result=ExpandedRetrievalResult( + expanded_queries_results=state["expanded_retrieval_results"], + all_documents=state["reranked_documents"], + ), ) diff --git 
a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index d40bf6f0dae..08940889952 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -3,6 +3,7 @@ from langgraph.types import Command from langgraph.types import Send +from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( DocVerificationInput, ) @@ -20,7 +21,10 @@ def verification_kickoff( goto=[ Send( node="doc_verification", - arg=DocVerificationInput(doc_to_verify=doc, **state), + arg=DocVerificationInput( + doc_to_verify=doc, + **extract_primary_fields(state), + ), ) for doc in documents ], diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 81b96d95f10..5408e75e893 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -12,9 +12,14 @@ ### Models ### +class QueryResult(BaseModel): + query: str + documents_for_query: list[InferenceSection] + + class ExpandedRetrievalResult(BaseModel): - expanded_query: str - expanded_retrieval_documents: list[InferenceSection] + expanded_queries_results: list[QueryResult] + all_documents: list[InferenceSection] ### States ### @@ -30,7 +35,7 @@ class DocRerankingUpdate(TypedDict): class DocRetrievalUpdate(TypedDict): - expanded_retrieval_results: Annotated[list[ExpandedRetrievalResult], add] + expanded_retrieval_results: Annotated[list[QueryResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -50,8 +55,7 @@ class ExpandedRetrievalState( class ExpandedRetrievalOutput(TypedDict): - expanded_retrieval_results: list[ExpandedRetrievalResult] - documents: Annotated[list[InferenceSection], 
dedup_inference_sections] + expanded_retrieval_result: ExpandedRetrievalResult ## Input States @@ -61,5 +65,9 @@ class ExpandedRetrievalInput(PrimaryState): question: str +class DocVerificationInput(PrimaryState): + doc_to_verify: InferenceSection + + class RetrievalInput(PrimaryState): query_to_retrieve: str diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 971398f9c90..f628ebf78c1 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -41,14 +41,14 @@ def main_graph_builder() -> StateGraph: node="initial_retrieval", action=expanded_retrieval_subgraph, ) - graph.add_node( - node="ingest_answers", - action=ingest_answers, - ) graph.add_node( node="ingest_initial_retrieval", action=ingest_initial_retrieval, ) + graph.add_node( + node="ingest_answers", + action=ingest_answers, + ) graph.add_node( node="generate_initial_answer", action=generate_initial_answer, diff --git a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py index e3a96e0b8e4..3cd75860cec 100644 --- a/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py +++ b/backend/onyx/agent_search/main/nodes/ingest_initial_retrieval.py @@ -4,6 +4,10 @@ def ingest_initial_retrieval(state: ExpandedRetrievalOutput) -> ExpandedRetrievalUpdate: return ExpandedRetrievalUpdate( - all_original_question_documents=state["documents"], - original_question_retrieval_results=state["expanded_retrieval_results"], + original_question_retrieval_results=state[ + "expanded_retrieval_result" + ].expanded_queries_results, + all_original_question_documents=state[ + "expanded_retrieval_result" + ].all_documents, ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index a6c296109f2..081440344bc 100644 --- a/backend/onyx/agent_search/main/states.py +++ 
b/backend/onyx/agent_search/main/states.py @@ -4,7 +4,7 @@ from onyx.agent_search.answer_question.states import QuestionAnswerResults from onyx.agent_search.core_state import PrimaryState -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalResult +from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -30,7 +30,7 @@ class ExpandedRetrievalUpdate(TypedDict): all_original_question_documents: Annotated[ list[InferenceSection], dedup_inference_sections ] - original_question_retrieval_results: list[ExpandedRetrievalResult] + original_question_retrieval_results: list[QueryResult] ## Graph State From 2d6f7462592f0a4b0fea89578965a9a102edde4d Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 13:03:28 -0800 Subject: [PATCH 16/19] made query expansion explicit --- .../agent_search/expanded_retrieval/edges.py | 34 ++----------------- .../expanded_retrieval/graph_builder.py | 12 ++++++- .../nodes/expand_queries.py | 30 ++++++++++++++++ .../agent_search/expanded_retrieval/states.py | 5 +++ backend/onyx/agent_search/main/edges.py | 18 ++++++++-- .../onyx/agent_search/main/graph_builder.py | 14 +++----- 6 files changed, 70 insertions(+), 43 deletions(-) create mode 100644 backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index 1c62ba7dd91..d426ed36031 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -1,41 +1,13 @@ from collections.abc import Hashable -from langchain_core.messages import HumanMessage -from langchain_core.messages import merge_message_runs from langgraph.types import Send from onyx.agent_search.core_state import extract_primary_fields from 
onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput -from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput -from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL -from onyx.llm.interfaces import LLM +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState -def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]: - print(f"parallel_retrieval_edge state: {state.keys()}") - - # This should be better... - question = state.get("question") or state["search_request"].query - llm: LLM = state["fast_llm"] - - msg = [ - HumanMessage( - content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), - ) - ] - llm_response_list = list( - llm.stream( - prompt=msg, - ) - ) - llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content - - print(f"llm_response: {llm_response}") - - rewritten_queries = llm_response.split("--") - - print(f"rewritten_queries: {rewritten_queries}") - +def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]: return [ Send( "doc_retrieval", @@ -44,5 +16,5 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab **extract_primary_fields(state), ), ) - for query in rewritten_queries + for query in state["expanded_queries"] ] diff --git a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py index c2bfd1e346c..8da14eea434 100644 --- a/backend/onyx/agent_search/expanded_retrieval/graph_builder.py +++ b/backend/onyx/agent_search/expanded_retrieval/graph_builder.py @@ -8,6 +8,7 @@ from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( doc_verification, ) +from onyx.agent_search.expanded_retrieval.nodes.expand_queries import expand_queries from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results from 
onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import ( verification_kickoff, @@ -26,6 +27,11 @@ def expanded_retrieval_graph_builder() -> StateGraph: ### Add nodes ### + graph.add_node( + node="expand_queries", + action=expand_queries, + ) + graph.add_node( node="doc_retrieval", action=doc_retrieval, @@ -48,9 +54,13 @@ def expanded_retrieval_graph_builder() -> StateGraph: ) ### Add edges ### + graph.add_edge( + start_key=START, + end_key="expand_queries", + ) graph.add_conditional_edges( - source=START, + source="expand_queries", path=parallel_retrieval_edge, path_map=["doc_retrieval"], ) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py new file mode 100644 index 00000000000..193d9b648fb --- /dev/null +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py @@ -0,0 +1,30 @@ +from langchain_core.messages import HumanMessage +from langchain_core.messages import merge_message_runs + +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate +from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL +from onyx.llm.interfaces import LLM + + +def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate: + question = state.get("question") + llm: LLM = state["fast_llm"] + + msg = [ + HumanMessage( + content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question), + ) + ] + llm_response_list = list( + llm.stream( + prompt=msg, + ) + ) + llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content + + rewritten_queries = llm_response.split("--") + + return QueryExpansionUpdate( + expanded_queries=rewritten_queries, + ) diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py b/backend/onyx/agent_search/expanded_retrieval/states.py index 5408e75e893..25160073e99 
100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -34,6 +34,10 @@ class DocRerankingUpdate(TypedDict): reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] +class QueryExpansionUpdate(TypedDict): + expanded_queries: list[str] + + class DocRetrievalUpdate(TypedDict): expanded_retrieval_results: Annotated[list[QueryResult], add] retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] @@ -47,6 +51,7 @@ class ExpandedRetrievalState( DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, + QueryExpansionUpdate, ): question: str diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 484c0c354a7..454b245ee8a 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,14 +1,18 @@ from collections.abc import Hashable +from collections.abc import Sequence from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainState -def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: - return [ +def parallelize_decompozed_answer_queries( + state: MainState, +) -> Sequence[Send | Hashable]: + answer_query_edges = [ Send( "answer_query", AnswerQuestionInput( @@ -18,6 +22,16 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha ) for question in state["initial_decomp_questions"] ] + initial_retrieval_edges = [ + Send( + "initial_retrieval", + ExpandedRetrievalInput( + **extract_primary_fields(state), + question=state["search_request"].query, + ), + ) + ] + return answer_query_edges + initial_retrieval_edges # def continue_to_answer_sub_questions(state: QAState) -> 
Union[Hashable, list[Hashable]]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index f628ebf78c1..0c85bac7c18 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -55,14 +55,6 @@ def main_graph_builder() -> StateGraph: ) ### Add edges ### - graph.add_edge( - start_key=START, - end_key="initial_retrieval", - ) - graph.add_edge( - start_key="initial_retrieval", - end_key="ingest_initial_retrieval", - ) graph.add_edge( start_key=START, @@ -71,7 +63,11 @@ def main_graph_builder() -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query"], + path_map=["answer_query", "initial_retrieval"], + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", ) graph.add_edge( start_key="answer_query", From ffc81f6e45b15c7bd2bfbdd476247a70831c20d9 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Wed, 18 Dec 2024 13:10:46 -0800 Subject: [PATCH 17/19] seperate edge for initial retrieval --- backend/onyx/agent_search/main/edges.py | 14 +++++++------- backend/onyx/agent_search/main/graph_builder.py | 17 ++++++++++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 454b245ee8a..0836882ce6f 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -1,18 +1,16 @@ from collections.abc import Hashable -from collections.abc import Sequence from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput from onyx.agent_search.core_state import extract_primary_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput +from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState -def 
parallelize_decompozed_answer_queries( - state: MainState, -) -> Sequence[Send | Hashable]: - answer_query_edges = [ +def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]: + return [ Send( "answer_query", AnswerQuestionInput( @@ -22,7 +20,10 @@ def parallelize_decompozed_answer_queries( ) for question in state["initial_decomp_questions"] ] - initial_retrieval_edges = [ + + +def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: + return [ Send( "initial_retrieval", ExpandedRetrievalInput( @@ -31,7 +32,6 @@ def parallelize_decompozed_answer_queries( ), ) ] - return answer_query_edges + initial_retrieval_edges # def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]: diff --git a/backend/onyx/agent_search/main/graph_builder.py b/backend/onyx/agent_search/main/graph_builder.py index 0c85bac7c18..dc09000435c 100644 --- a/backend/onyx/agent_search/main/graph_builder.py +++ b/backend/onyx/agent_search/main/graph_builder.py @@ -7,6 +7,7 @@ expanded_retrieval_graph_builder, ) from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries +from onyx.agent_search.main.edges import send_to_initial_retrieval from onyx.agent_search.main.nodes.base_decomp import main_decomp_base from onyx.agent_search.main.nodes.generate_initial_answer import ( generate_initial_answer, @@ -56,6 +57,16 @@ def main_graph_builder() -> StateGraph: ### Add edges ### + graph.add_conditional_edges( + source=START, + path=send_to_initial_retrieval, + path_map=["initial_retrieval"], + ) + graph.add_edge( + start_key="initial_retrieval", + end_key="ingest_initial_retrieval", + ) + graph.add_edge( start_key=START, end_key="base_decomp", @@ -63,11 +74,7 @@ def main_graph_builder() -> StateGraph: graph.add_conditional_edges( source="base_decomp", path=parallelize_decompozed_answer_queries, - path_map=["answer_query", "initial_retrieval"], - ) - graph.add_edge( - start_key="initial_retrieval", - 
end_key="ingest_initial_retrieval", + path_map=["answer_query"], ) graph.add_edge( start_key="answer_query", From cebe237705121c66d490f3431a434b8c565f7f03 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Dec 2024 08:47:39 -0800 Subject: [PATCH 18/19] renamed PrimaryState to CoreState --- .../agent_search/answer_question/edges.py | 4 ++-- .../agent_search/answer_question/states.py | 20 +++++++++---------- backend/onyx/agent_search/core_state.py | 18 +++++++++-------- .../agent_search/expanded_retrieval/edges.py | 4 ++-- .../nodes/verification_kickoff.py | 4 ++-- backend/onyx/agent_search/main/edges.py | 6 +++--- backend/onyx/agent_search/main/states.py | 19 +++++++++--------- 7 files changed, 39 insertions(+), 36 deletions(-) diff --git a/backend/onyx/agent_search/answer_question/edges.py b/backend/onyx/agent_search/answer_question/edges.py index 261821f8cc0..bdd9864e6e0 100644 --- a/backend/onyx/agent_search/answer_question/edges.py +++ b/backend/onyx/agent_search/answer_question/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput @@ -11,7 +11,7 @@ def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable: return Send( "decomped_expanded_retrieval", ExpandedRetrievalInput( - **extract_primary_fields(state), + **extract_core_fields(state), question=state["question"], ), ) diff --git a/backend/onyx/agent_search/answer_question/states.py b/backend/onyx/agent_search/answer_question/states.py index 898a035b7b3..2964df0ab54 100644 --- a/backend/onyx/agent_search/answer_question/states.py +++ b/backend/onyx/agent_search/answer_question/states.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from onyx.agent_search.core_state import PrimaryState +from 
onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -39,23 +39,23 @@ class RetrievalIngestionUpdate(TypedDict): documents: Annotated[list[InferenceSection], dedup_inference_sections] +## Graph Input State + + +class AnswerQuestionInput(CoreState): + question: str + + ## Graph State class AnswerQuestionState( - PrimaryState, + AnswerQuestionInput, QAGenerationUpdate, QACheckUpdate, RetrievalIngestionUpdate, ): - question: str - - -## Input State - - -class AnswerQuestionInput(PrimaryState): - question: str + pass ## Graph Output State diff --git a/backend/onyx/agent_search/core_state.py b/backend/onyx/agent_search/core_state.py index ee490e0a337..cbc8f3d5c4d 100644 --- a/backend/onyx/agent_search/core_state.py +++ b/backend/onyx/agent_search/core_state.py @@ -7,7 +7,11 @@ from onyx.llm.interfaces import LLM -class PrimaryState(TypedDict, total=False): +class CoreState(TypedDict, total=False): + """ + This is the core state that is shared across all subgraphs. 
+ """ + search_request: SearchRequest primary_llm: LLM fast_llm: LLM @@ -16,12 +20,10 @@ class PrimaryState(TypedDict, total=False): db_session: Session -# This ensures that the state passed in extends the PrimaryState -T = TypeVar("T", bound=PrimaryState) +# This ensures that the state passed in extends the CoreState +T = TypeVar("T", bound=CoreState) -def extract_primary_fields(state: T) -> PrimaryState: - filtered_dict = { - k: v for k, v in state.items() if k in PrimaryState.__annotations__ - } - return PrimaryState(**dict(filtered_dict)) # type: ignore +def extract_core_fields(state: T) -> CoreState: + filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__} + return CoreState(**dict(filtered_dict)) # type: ignore diff --git a/backend/onyx/agent_search/expanded_retrieval/edges.py b/backend/onyx/agent_search/expanded_retrieval/edges.py index d426ed36031..61d994a6871 100644 --- a/backend/onyx/agent_search/expanded_retrieval/edges.py +++ b/backend/onyx/agent_search/expanded_retrieval/edges.py @@ -2,7 +2,7 @@ from langgraph.types import Send -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState @@ -13,7 +13,7 @@ def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashab "doc_retrieval", RetrievalInput( query_to_retrieve=query, - **extract_primary_fields(state), + **extract_core_fields(state), ), ) for query in state["expanded_queries"] diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py index 08940889952..81fe7f9229c 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py +++ 
b/backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py @@ -3,7 +3,7 @@ from langgraph.types import Command from langgraph.types import Send -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.nodes.doc_verification import ( DocVerificationInput, ) @@ -23,7 +23,7 @@ def verification_kickoff( node="doc_verification", arg=DocVerificationInput( doc_to_verify=doc, - **extract_primary_fields(state), + **extract_core_fields(state), ), ) for doc in documents diff --git a/backend/onyx/agent_search/main/edges.py b/backend/onyx/agent_search/main/edges.py index 0836882ce6f..7791498d3e8 100644 --- a/backend/onyx/agent_search/main/edges.py +++ b/backend/onyx/agent_search/main/edges.py @@ -3,7 +3,7 @@ from langgraph.types import Send from onyx.agent_search.answer_question.states import AnswerQuestionInput -from onyx.agent_search.core_state import extract_primary_fields +from onyx.agent_search.core_state import extract_core_fields from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput from onyx.agent_search.main.states import MainInput from onyx.agent_search.main.states import MainState @@ -14,7 +14,7 @@ def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hasha Send( "answer_query", AnswerQuestionInput( - **extract_primary_fields(state), + **extract_core_fields(state), question=question, ), ) @@ -27,7 +27,7 @@ def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]: Send( "initial_retrieval", ExpandedRetrievalInput( - **extract_primary_fields(state), + **extract_core_fields(state), question=state["search_request"].query, ), ) diff --git a/backend/onyx/agent_search/main/states.py b/backend/onyx/agent_search/main/states.py index 081440344bc..6fb44da4903 100644 --- a/backend/onyx/agent_search/main/states.py +++ b/backend/onyx/agent_search/main/states.py @@ -3,7 +3,7 @@ from typing 
import TypedDict from onyx.agent_search.answer_question.states import QuestionAnswerResults -from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.core_state import CoreState from onyx.agent_search.expanded_retrieval.states import QueryResult from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -33,11 +33,19 @@ class ExpandedRetrievalUpdate(TypedDict): original_question_retrieval_results: list[QueryResult] +## Graph Input State + + +class MainInput(CoreState): + pass + + ## Graph State class MainState( - PrimaryState, + # This includes the core state + MainInput, BaseDecompUpdate, InitialAnswerUpdate, DecompAnswersUpdate, @@ -46,13 +54,6 @@ class MainState( pass -## Input States - - -class MainInput(PrimaryState): - pass - - ## Graph Output State From 34aa054c5d1c1e51e88b651edb9a6c3ee2ec22f7 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Dec 2024 08:48:05 -0800 Subject: [PATCH 19/19] added chunk_ids and stats to QueryResult --- .../expanded_retrieval/nodes/doc_retrieval.py | 2 ++ .../agent_search/expanded_retrieval/states.py | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py index c0b60ef38d3..54d54211198 100644 --- a/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py +++ b/backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py @@ -35,6 +35,8 @@ def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate: expanded_retrieval_result = QueryResult( query=query_to_retrieve, documents_for_query=documents[:4], + chunk_ids=[], + stats={}, ) return DocRetrievalUpdate( expanded_retrieval_results=[expanded_retrieval_result], diff --git a/backend/onyx/agent_search/expanded_retrieval/states.py 
b/backend/onyx/agent_search/expanded_retrieval/states.py index 25160073e99..71c845cd6fd 100644 --- a/backend/onyx/agent_search/expanded_retrieval/states.py +++ b/backend/onyx/agent_search/expanded_retrieval/states.py @@ -1,10 +1,11 @@ from operator import add from typing import Annotated +from typing import Any from typing import TypedDict from pydantic import BaseModel -from onyx.agent_search.core_state import PrimaryState +from onyx.agent_search.core_state import CoreState from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections from onyx.context.search.models import InferenceSection @@ -15,6 +16,8 @@ class QueryResult(BaseModel): query: str documents_for_query: list[InferenceSection] + chunk_ids: list[str] + stats: dict[str, Any] class ExpandedRetrievalResult(BaseModel): @@ -43,17 +46,25 @@ class DocRetrievalUpdate(TypedDict): retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections] +## Graph Input State + + +class ExpandedRetrievalInput(CoreState): + question: str + + ## Graph State class ExpandedRetrievalState( - PrimaryState, + # This includes the core state + ExpandedRetrievalInput, DocRetrievalUpdate, DocVerificationUpdate, DocRerankingUpdate, QueryExpansionUpdate, ): - question: str + pass ## Graph Output State @@ -63,16 +74,12 @@ class ExpandedRetrievalOutput(TypedDict): expanded_retrieval_result: ExpandedRetrievalResult -## Input States - - -class ExpandedRetrievalInput(PrimaryState): - question: str +## Conditional Input States -class DocVerificationInput(PrimaryState): +class DocVerificationInput(CoreState): doc_to_verify: InferenceSection -class RetrievalInput(PrimaryState): +class RetrievalInput(CoreState): query_to_retrieve: str