diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 51836c228d1..14b9c227ac4 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import ( DummyAnswerResponseHandler, ) -from onyx.chat.stream_processing.utils import map_document_id_order +from onyx.chat.stream_processing.utils import ( + map_document_id_order, +) from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler from onyx.file_store.utils import InMemoryChatFile from onyx.llm.interfaces import LLM @@ -206,9 +208,9 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: # + figure out what the next LLM call should be tool_call_handler = ToolResponseHandler(current_llm_call.tools) - search_result, displayed_search_results_map = SearchTool.get_search_result( + final_search_results, displayed_search_results = SearchTool.get_search_result( current_llm_call - ) or ([], {}) + ) or ([], []) # Quotes are no longer supported # answer_handler: AnswerResponseHandler @@ -224,9 +226,9 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream: # else: # raise ValueError("No answer style config provided") answer_handler = CitationResponseHandler( - context_docs=search_result, - doc_id_to_rank_map=map_document_id_order(search_result), - display_doc_order_dict=displayed_search_results_map, + context_docs=final_search_results, + final_doc_id_to_rank_map=map_document_id_order(final_search_results), + display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), ) response_handler_manager = LLMResponseHandlerManager( diff --git a/backend/onyx/chat/stream_processing/answer_response_handler.py b/backend/onyx/chat/stream_processing/answer_response_handler.py index 6d1031e95d0..87098c3f177 100644 --- a/backend/onyx/chat/stream_processing/answer_response_handler.py +++ b/backend/onyx/chat/stream_processing/answer_response_handler.py @@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler): def __init__( self, context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - display_doc_order_dict: dict[str, int], + final_doc_id_to_rank_map: DocumentIdOrderMapping, + display_doc_id_to_rank_map: DocumentIdOrderMapping, ): self.context_docs = context_docs - self.doc_id_to_rank_map = doc_id_to_rank_map - self.display_doc_order_dict = display_doc_order_dict + self.final_doc_id_to_rank_map = final_doc_id_to_rank_map + self.display_doc_id_to_rank_map = display_doc_id_to_rank_map self.citation_processor = CitationProcessor( context_docs=self.context_docs, - doc_id_to_rank_map=self.doc_id_to_rank_map, - display_doc_order_dict=self.display_doc_order_dict, + final_doc_id_to_rank_map=self.final_doc_id_to_rank_map, + display_doc_id_to_rank_map=self.display_doc_id_to_rank_map, ) self.processed_text = "" self.citations: list[CitationInfo] = [] # TODO remove this after citation issue is resolved - logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}") + logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}") def handle_response_part( self, diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py index cd159dd1542..071b28c3457 100644 --- a/backend/onyx/chat/stream_processing/citation_processing.py +++ b/backend/onyx/chat/stream_processing/citation_processing.py @@ -21,20 +21,19 @@ class CitationProcessor: def __init__( self, context_docs: list[LlmDoc], - doc_id_to_rank_map: DocumentIdOrderMapping, - display_doc_order_dict: dict[str, int], + final_doc_id_to_rank_map: DocumentIdOrderMapping, + display_doc_id_to_rank_map: DocumentIdOrderMapping, stop_stream: str | None = STOP_STREAM_PAT, ): self.context_docs = context_docs - self.doc_id_to_rank_map = doc_id_to_rank_map + self.final_doc_id_to_rank_map = final_doc_id_to_rank_map + self.display_doc_id_to_rank_map = display_doc_id_to_rank_map self.stop_stream = stop_stream - self.order_mapping = doc_id_to_rank_map.order_mapping - self.display_doc_order_dict = ( - display_doc_order_dict # original order of docs to displayed to user - ) + self.final_order_mapping = final_doc_id_to_rank_map.order_mapping + self.display_order_mapping = display_doc_id_to_rank_map.order_mapping self.llm_out = "" self.max_citation_num = len(context_docs) - self.citation_order: list[int] = [] + self.citation_order: list[int] = [] # order of citations in the LLM output self.curr_segment = "" self.cited_inds: set[int] = set() self.hold = "" @@ -93,29 +92,31 @@ def process_token( if 1 <= numerical_value <= self.max_citation_num: context_llm_doc = self.context_docs[numerical_value - 1] - real_citation_num = self.order_mapping[context_llm_doc.document_id] + final_citation_num = self.final_order_mapping[ + context_llm_doc.document_id + ] - if real_citation_num not in self.citation_order: - self.citation_order.append(real_citation_num) + if final_citation_num not in self.citation_order: + self.citation_order.append(final_citation_num) - target_citation_num = ( - self.citation_order.index(real_citation_num) + 1 + citation_order_idx = ( + self.citation_order.index(final_citation_num) + 1 ) # get the value that was displayed to user, should always # be in the display_doc_order_dict. But check anyways - if context_llm_doc.document_id in self.display_doc_order_dict: - displayed_citation_num = self.display_doc_order_dict[ + if context_llm_doc.document_id in self.display_order_mapping: + displayed_citation_num = self.display_order_mapping[ context_llm_doc.document_id ] else: - displayed_citation_num = real_citation_num + displayed_citation_num = final_citation_num logger.warning( f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead." ) # Skip consecutive citations of the same work - if target_citation_num in self.current_citations: + if final_citation_num in self.current_citations: start, end = citation.span() real_start = length_to_add + start diff = end - start @@ -134,8 +135,8 @@ def process_token( doc_id = int(match.group(1)) context_llm_doc = self.context_docs[doc_id - 1] yield CitationInfo( - # stay with the original for now (order of LLM cites) - citation_num=target_citation_num, + # citation_num is now the number post initial ranking, i.e. as displayed to user + citation_num=displayed_citation_num, document_id=context_llm_doc.document_id, ) except Exception as e: @@ -151,13 +152,13 @@ def process_token( link = context_llm_doc.link self.past_cite_count = len(self.llm_out) - self.current_citations.append(target_citation_num) + self.current_citations.append(final_citation_num) - if target_citation_num not in self.cited_inds: - self.cited_inds.add(target_citation_num) + if citation_order_idx not in self.cited_inds: + self.cited_inds.add(citation_order_idx) yield CitationInfo( - # stay with the original for now (order of LLM cites) - citation_num=target_citation_num, + # citation number is now the one that was displayed to user + citation_num=displayed_citation_num, document_id=context_llm_doc.document_id, ) @@ -167,7 +168,6 @@ def process_token( self.curr_segment = ( self.curr_segment[: start + length_to_add] + f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user - # + f"[[{target_citation_num}]]({link})" + self.curr_segment[end + length_to_add :] ) length_to_add += len(self.curr_segment) - prev_length @@ -176,7 +176,6 @@ def process_token( self.curr_segment = ( self.curr_segment[: start + length_to_add] + f"[[{displayed_citation_num}]]()" # use the value that was displayed to user - # + f"[[{target_citation_num}]]()" + self.curr_segment[end + length_to_add :] ) length_to_add += len(self.curr_segment) - prev_length diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py index 368111ca46f..da8bc6a1a1c 100644 --- a/backend/onyx/tools/tool_implementations/search/search_tool.py +++ b/backend/onyx/tools/tool_implementations/search/search_tool.py @@ -396,7 +396,7 @@ def build_next_prompt( @classmethod def get_search_result( cls, llm_call: LLMCall - ) -> tuple[list[LlmDoc], dict[str, int]] | None: + ) -> tuple[list[LlmDoc], list[LlmDoc]] | None: """ Returns the final search results and a map of docs to their original search rank (which is what is displayed to user) """ @@ -404,7 +404,7 @@ def get_search_result( return None final_search_results = [] - doc_id_to_original_search_rank_map = {} + initial_search_results = [] for yield_item in llm_call.tool_call_info: if ( @@ -417,12 +417,11 @@ def get_search_result( and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID ): search_contexts = yield_item.response.contexts - original_doc_search_rank = 1 - for idx, doc in enumerate(search_contexts): - if doc.document_id not in doc_id_to_original_search_rank_map: - doc_id_to_original_search_rank_map[ - doc.document_id - ] = original_doc_search_rank - original_doc_search_rank += 1 - - return final_search_results, doc_id_to_original_search_rank_map + # original_doc_search_rank = 1 + for doc in search_contexts: + if doc.document_id not in initial_search_results: + initial_search_results.append(doc) + + initial_search_results = cast(list[LlmDoc], initial_search_results) + + return final_search_results, initial_search_results diff --git a/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py b/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py index dcc960790db..85da97be299 100644 --- a/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py @@ -68,11 +68,12 @@ def process_text( tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] ) -> tuple[str, list[CitationInfo]]: mock_docs, mock_doc_id_to_rank_map = mock_data - mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + display_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) processor = CitationProcessor( context_docs=mock_docs, - doc_id_to_rank_map=mapping, - display_doc_order_dict=mock_doc_id_to_rank_map, + final_doc_id_to_rank_map=final_mapping, + display_doc_id_to_rank_map=display_mapping, stop_stream=None, ) diff --git a/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py b/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py index 78ecad1a3ec..3e14d54b097 100644 --- a/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py +++ b/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py @@ -71,19 +71,22 @@ @pytest.fixture -def mock_data() -> tuple[list[LlmDoc], dict[str, int]]: - return mock_docs, mock_doc_mapping +def mock_data() -> tuple[list[LlmDoc], dict[str, int], dict[str, int]]: + return mock_docs, mock_doc_mapping, mock_doc_mapping_rerank def process_text( - tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] + tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]] ) -> tuple[str, list[CitationInfo]]: - mock_docs, mock_doc_id_to_rank_map = mock_data - mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + mock_docs, mock_doc_id_to_rank_map, mock_doc_id_to_rank_map_rerank = mock_data + final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) + display_mapping = DocumentIdOrderMapping( + order_mapping=mock_doc_id_to_rank_map_rerank + ) processor = CitationProcessor( context_docs=mock_docs, - doc_id_to_rank_map=mapping, - display_doc_order_dict=mock_doc_mapping_rerank, + final_doc_id_to_rank_map=final_mapping, + display_doc_id_to_rank_map=display_mapping, stop_stream=None, ) @@ -115,7 +118,7 @@ def process_text( ], ) def test_citation_substitution( - mock_data: tuple[list[LlmDoc], dict[str, int]], + mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]], test_name: str, input_tokens: list[str], expected_text: str,