Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Citation consistency in Citation Processing (initial ranking vs post validation/re-ranking) #3508

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions backend/onyx/chat/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
Expand Down Expand Up @@ -206,9 +208,9 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)

search_result, displayed_search_results_map = SearchTool.get_search_result(
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], {})
) or ([], [])

# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
Expand All @@ -224,9 +226,9 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)

response_handler_manager = LLMResponseHandlerManager(
Expand Down
14 changes: 7 additions & 7 deletions backend/onyx/chat/stream_processing/answer_response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_id_to_rank_map: DocumentIdOrderMapping,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []

# TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")

def handle_response_part(
self,
Expand Down
51 changes: 25 additions & 26 deletions backend/onyx/chat/stream_processing/citation_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,19 @@ class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.citation_order: list[int] = [] # order of citations in the LLM output
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
Expand Down Expand Up @@ -93,29 +92,31 @@ def process_token(

if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]

if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)

target_citation_num = (
self.citation_order.index(real_citation_num) + 1
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)

# get the citation number that was displayed to the user; the doc should
# always be in display_order_mapping, but fall back defensively if not
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_order_mapping. Used final citation number instead."
)

# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
Expand All @@ -134,8 +135,8 @@ def process_token(
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation_num is now the number from the initial ranking (pre validation/re-ranking), i.e. as displayed to the user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
Expand All @@ -151,13 +152,13 @@ def process_token(
link = context_llm_doc.link

self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
self.current_citations.append(final_citation_num)

if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)

Expand All @@ -167,7 +168,6 @@ def process_token(
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
Expand All @@ -176,7 +176,6 @@ def process_token(
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
Expand Down
21 changes: 10 additions & 11 deletions backend/onyx/tools/tool_implementations/search/search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,15 @@ def build_next_prompt(
@classmethod
def get_search_result(
cls, llm_call: LLMCall
) -> tuple[list[LlmDoc], dict[str, int]] | None:
) -> tuple[list[LlmDoc], list[LlmDoc]] | None:
"""
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
"""
if not llm_call.tool_call_info:
return None

final_search_results = []
doc_id_to_original_search_rank_map = {}
initial_search_results = []

for yield_item in llm_call.tool_call_info:
if (
Expand All @@ -417,12 +417,11 @@ def get_search_result(
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
):
search_contexts = yield_item.response.contexts
original_doc_search_rank = 1
for idx, doc in enumerate(search_contexts):
if doc.document_id not in doc_id_to_original_search_rank_map:
doc_id_to_original_search_rank_map[
doc.document_id
] = original_doc_search_rank
original_doc_search_rank += 1

return final_search_results, doc_id_to_original_search_rank_map
seen_doc_ids: set[str] = set()
for doc in search_contexts:
    # membership must be tested against document ids, not the
    # list of doc objects, or deduplication never triggers
    if doc.document_id not in seen_doc_ids:
        seen_doc_ids.add(doc.document_id)
        initial_search_results.append(doc)

initial_search_results = cast(list[LlmDoc], initial_search_results)

return final_search_results, initial_search_results
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
display_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_id_to_rank_map,
final_doc_id_to_rank_map=final_mapping,
display_doc_id_to_rank_map=display_mapping,
stop_stream=None,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,22 @@


@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
return mock_docs, mock_doc_mapping
def mock_data() -> tuple[list[LlmDoc], dict[str, int], dict[str, int]]:
return mock_docs, mock_doc_mapping, mock_doc_mapping_rerank


def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
mock_docs, mock_doc_id_to_rank_map, mock_doc_id_to_rank_map_rerank = mock_data
final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
display_mapping = DocumentIdOrderMapping(
order_mapping=mock_doc_id_to_rank_map_rerank
)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_mapping_rerank,
final_doc_id_to_rank_map=final_mapping,
display_doc_id_to_rank_map=display_mapping,
stop_stream=None,
)

Expand Down Expand Up @@ -115,7 +118,7 @@ def process_text(
],
)
def test_citation_substitution(
mock_data: tuple[list[LlmDoc], dict[str, int]],
mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]],
test_name: str,
input_tokens: list[str],
expected_text: str,
Expand Down
Loading