From 0125d8a0f69b2535ab0d1307050bdd3d862c70c4 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 7 Nov 2023 14:21:04 -0800 Subject: [PATCH] Source Filter Extraction (#708) --- backend/danswer/chat/personas.py | 11 +- backend/danswer/configs/app_configs.py | 4 +- backend/danswer/danswerbot/slack/blocks.py | 9 +- .../slack/handlers/handle_message.py | 1 + backend/danswer/db/connector.py | 8 + backend/danswer/direct_qa/answer_question.py | 45 ++++- backend/danswer/document_index/vespa/index.py | 5 +- backend/danswer/prompts/constants.py | 1 + backend/danswer/prompts/prompt_utils.py | 9 + .../danswer/prompts/secondary_llm_flows.py | 71 +++++-- backend/danswer/search/models.py | 3 +- .../secondary_llm_flows/source_filter.py | 185 ++++++++++++++++++ .../{extract_filters.py => time_filter.py} | 9 +- backend/danswer/server/models.py | 1 + backend/danswer/server/search_backend.py | 20 +- .../danswer/utils/threadpool_concurrency.py | 38 ++++ backend/danswer/utils/timing.py | 58 +++--- .../docker_compose/docker-compose.dev.yml | 2 +- 18 files changed, 400 insertions(+), 80 deletions(-) create mode 100644 backend/danswer/prompts/prompt_utils.py create mode 100644 backend/danswer/secondary_llm_flows/source_filter.py rename backend/danswer/secondary_llm_flows/{extract_filters.py => time_filter.py} (95%) create mode 100644 backend/danswer/utils/threadpool_concurrency.py diff --git a/backend/danswer/chat/personas.py b/backend/danswer/chat/personas.py index bb8c6822a4a..9bc927cbb2e 100644 --- a/backend/danswer/chat/personas.py +++ b/backend/danswer/chat/personas.py @@ -1,4 +1,3 @@ -from datetime import datetime from typing import Any import yaml @@ -11,19 +10,13 @@ from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona from danswer.db.models import ToolInfo +from danswer.prompts.prompt_utils import get_current_llm_day_time def build_system_text_from_persona(persona: Persona) -> str | None: text = (persona.system_text or "").strip() if persona.datetime_aware: - current_datetime = datetime.now() - # Format looks like: "October 16, 2023 14:30" - formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M") - - text += ( - "\n\nAdditional Information:\n" - f"\t- The current date and time is {formatted_datetime}." - ) + text += "\n\nAdditional Information:\n" f"\t- {get_current_llm_day_time()}." 
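+        # e.g. "The current day and time is Monday October 16, 2023 14:30"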
return text or None diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 7ceeab46cb5..8890522bb26 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -169,8 +169,8 @@ os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default ) FAVOR_RECENT_DECAY_MULTIPLIER = 2 -DISABLE_TIME_FILTER_EXTRACTION = ( - os.environ.get("DISABLE_TIME_FILTER_EXTRACTION", "").lower() == "true" +DISABLE_LLM_FILTER_EXTRACTION = ( + os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true" ) # 1 edit per 2 characters, currently unused due to fuzzy match being too slow QUOTE_ALLOWED_ERROR_PERCENT = 0.05 diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index 435ec79c0c3..479a6e65f39 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -182,15 +182,22 @@ def build_qa_response_blocks( query_event_id: int, answer: str | None, quotes: list[DanswerQuote] | None, + source_filters: list[DocumentSource] | None, time_cutoff: datetime | None, favor_recent: bool, ) -> list[Block]: quotes_blocks: list[Block] = [] ai_answer_header = HeaderBlock(text="AI Answer") + filter_block: Block | None = None - if time_cutoff or favor_recent: + if time_cutoff or favor_recent or source_filters: filter_text = "Filters: " + if source_filters: + sources_str = ", ".join([s.value for s in source_filters]) + filter_text += f"`Sources in [{sources_str}]`" + if time_cutoff or favor_recent: + filter_text += " and " if time_cutoff is not None: time_str = time_cutoff.strftime("%b %d, %Y") filter_text += f"`Docs Updated >= {time_str}` " diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index b2b8ff59ef5..a8c0c6f0680 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -260,6 +260,7 @@ def _get_answer(question: QuestionRequest) -> QAResponse: query_event_id=answer.query_event_id, answer=answer.answer, quotes=answer.quotes, + source_filters=answer.source_type, time_cutoff=answer.time_cutoff, favor_recent=answer.favor_recent, ) diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py index ffdcc22c617..49ba370a2fc 100644 --- a/backend/danswer/db/connector.py +++ b/backend/danswer/db/connector.py @@ -202,3 +202,11 @@ def fetch_latest_index_attempts_by_status( ), ) return cast(list[IndexAttempt], query.all()) + + +def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]: + distinct_sources = db_session.query(Connector.source).distinct().all() + + sources = [source[0] for source in distinct_sources] + + return sources diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index b4e64c713f5..489cb2eb65b 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -22,12 +22,14 @@ from danswer.search.search_runner import chunks_to_search_docs from danswer.search.search_runner import danswer_search from danswer.secondary_llm_flows.answer_validation import get_answer_validity -from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters +from danswer.secondary_llm_flows.source_filter import extract_question_source_filters +from danswer.secondary_llm_flows.time_filter import extract_question_time_filters 
from danswer.server.models import QAResponse from danswer.server.models import QuestionRequest from danswer.server.models import RerankedRetrievalDocs from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time from danswer.utils.timing import log_generator_function_time @@ -52,9 +54,22 @@ def answer_qa_query( offset_count = question.offset if question.offset is not None else 0 logger.info(f"Received QA query: {query}") - time_cutoff, favor_recent = extract_question_time_filters(question) + functions_to_run: dict[Callable, tuple] = { + extract_question_time_filters: (question,), + extract_question_source_filters: (question, db_session), + query_intent: (query,), + } + + parallel_results = run_functions_in_parallel(functions_to_run) + + time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] + source_filters = parallel_results["extract_question_source_filters"] + predicted_search, predicted_flow = parallel_results["query_intent"] + + # Modifies the question object but nothing upstream uses it question.filters.time_cutoff = time_cutoff question.favor_recent = favor_recent + question.filters.source_type = source_filters ranked_chunks, unranked_chunks, query_event_id = danswer_search( question=question, @@ -65,9 +80,6 @@ def answer_qa_query( rerank_metrics_callback=rerank_metrics_callback, ) - # TODO retire this - predicted_search, predicted_flow = query_intent(query) - if not ranked_chunks: return QAResponse( answer=None, @@ -77,6 +89,7 @@ def answer_qa_query( predicted_flow=predicted_flow, predicted_search=predicted_search, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, ) @@ -96,6 +109,7 @@ def answer_qa_query( predicted_flow=QueryFlow.SEARCH, predicted_search=predicted_search, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, ) @@ -113,6 +127,7 @@ def answer_qa_query( predicted_flow=predicted_flow, predicted_search=predicted_search, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, error_msg=str(e), @@ -159,6 +174,7 @@ def answer_qa_query( predicted_search=predicted_search, eval_res_valid=True if valid else False, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, error_msg=error_msg, @@ -172,6 +188,7 @@ def answer_qa_query( predicted_flow=predicted_flow, predicted_search=predicted_search, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, error_msg=error_msg, @@ -194,9 +211,22 @@ def answer_qa_query_stream( query = question.query offset_count = question.offset if question.offset is not None else 0 - time_cutoff, favor_recent = extract_question_time_filters(question) + functions_to_run: dict[Callable, tuple] = { + extract_question_time_filters: (question,), + extract_question_source_filters: (question, db_session), + query_intent: (query,), + } + + parallel_results = run_functions_in_parallel(functions_to_run) + + time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] + source_filters = parallel_results["extract_question_source_filters"] + predicted_search, predicted_flow = parallel_results["query_intent"] + + # Modifies the question object but nothing upstream uses it 
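+    # source_filters is None (no restriction) when nothing was specified or detected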
question.filters.time_cutoff = time_cutoff question.favor_recent = favor_recent + question.filters.source_type = source_filters ranked_chunks, unranked_chunks, query_event_id = danswer_search( question=question, @@ -205,9 +235,6 @@ def answer_qa_query_stream( document_index=get_default_document_index(), ) - # TODO retire this - predicted_search, predicted_flow = query_intent(query) - top_docs = chunks_to_search_docs(ranked_chunks) unranked_top_docs = chunks_to_search_docs(unranked_chunks) diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 3a810b7c296..6c9cd5e6d0c 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -335,7 +335,10 @@ def _build_time_filter( # CAREFUL touching this one, currently there is no second ACL double-check post retrieval filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list) - filter_str += _build_or_filters(SOURCE_TYPE, filters.source_type) + source_strs = ( + [s.value for s in filters.source_type] if filters.source_type else None + ) + filter_str += _build_or_filters(SOURCE_TYPE, source_strs) filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set) diff --git a/backend/danswer/prompts/constants.py b/backend/danswer/prompts/constants.py index e1ba5c47b1e..688896ab851 100644 --- a/backend/danswer/prompts/constants.py +++ b/backend/danswer/prompts/constants.py @@ -9,3 +9,4 @@ QUOTE_PAT = "Quote:" QUOTES_PAT_PLURAL = "Quotes:" INVALID_PAT = "Invalid:" +SOURCES_KEY = "sources" diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py new file mode 100644 index 00000000000..4c0de783f85 --- /dev/null +++ b/backend/danswer/prompts/prompt_utils.py @@ -0,0 +1,9 @@ +from datetime import datetime + + +def get_current_llm_day_time() -> str: + current_datetime = datetime.now() + # Format looks like: "October 16, 2023 14:30" + formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M") + day_of_week = current_datetime.strftime("%A") + return f"The current day and time is {day_of_week} {formatted_datetime}" diff --git a/backend/danswer/prompts/secondary_llm_flows.py b/backend/danswer/prompts/secondary_llm_flows.py index d0abec9d0dd..5bce628b14a 100644 --- a/backend/danswer/prompts/secondary_llm_flows.py +++ b/backend/danswer/prompts/secondary_llm_flows.py @@ -2,6 +2,7 @@ from danswer.prompts.constants import ANSWERABLE_PAT from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.prompts.constants import QUESTION_PAT +from danswer.prompts.constants import SOURCES_KEY from danswer.prompts.constants import THOUGHT_PAT @@ -31,21 +32,6 @@ """.strip() -TIME_FILTER_PROMPT = """ -You are a tool to identify time filters to apply to a user query for a downstream search \ -application. The downstream application is able to use a recency bias or apply a hard cutoff to \ -remove all documents before the cutoff. Identify the correct filters to apply for the user query. - -Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \ -"value_multiple" and "date". - -The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive". -The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year". -The valid values for "value_multiple" is any number. -The valid values for "date" is a date in format MM/DD/YYYY. 
-""".strip() - - ANSWERABLE_PROMPT = f""" You are a helper tool to determine if a query is answerable using retrieval augmented generation. The main system will try to answer the user query based on ONLY the top 5 most relevant \ @@ -91,6 +77,61 @@ """.strip() +# Smaller followup prompts in time_filter.py +TIME_FILTER_PROMPT = """ +You are a tool to identify time filters to apply to a user query for a downstream search \ +application. The downstream application is able to use a recency bias or apply a hard cutoff to \ +remove all documents before the cutoff. Identify the correct filters to apply for the user query. + +The current day and time is {current_day_time_str}. + +Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \ +"value_multiple" and "date". + +The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive". +The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year". +The valid values for "value_multiple" is any number. +The valid values for "date" is a date in format MM/DD/YYYY, ALWAYS follow this format. +""".strip() + + +# Smaller followup prompts in source_filter.py +# Known issue: LLMs like GPT-3.5 try to generalize. If the valid sources contains "web" but not +# "confluence" and the user asks for confluence related things, the LLM will select "web" since +# confluence is accessed as a website. This cannot be fixed without also reducing the capability +# to match things like repository->github, website->web, etc. +# This is generally not a big issue though as if the company has confluence, hopefully they add +# a connector for it or the user is aware that confluence has not been added. +SOURCE_FILTER_PROMPT = f""" +Given a user query, extract relevant source filters for use in a downstream search tool. +Respond with a json containing the source filters or null if no specific sources are referenced. +ONLY extract sources when the user is explicitly limiting the scope of where information is \ +coming from. +The user may provide invalid source filters, ignore those. + +The valid sources are: +{{valid_sources}} +{{web_source_warning}} +{{file_source_warning}} + + +ALWAYS answer with ONLY a json with the key "{SOURCES_KEY}". \ +The value for "{SOURCES_KEY}" must be null or a list of valid sources. + +Sample Response: +{{sample_response}} +""".strip() + +WEB_SOURCE_WARNING = """ +Note: The "web" source only applies to when the user specifies "website" in the query. \ +It does not apply to tools such as Confluence, GitHub, etc. which have a website. +""".strip() + +FILE_SOURCE_WARNING = """ +Note: The "file" source only applies to when the user refers to uploaded files in the query. 
+""".strip() + + # User the following for easy viewing of prompts if __name__ == "__main__": print(ANSWERABLE_PROMPT) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index c53f50f171d..97c4db6f67e 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -5,6 +5,7 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RETURNED_HITS +from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import SKIP_RERANKING from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import IndexChunk @@ -31,7 +32,7 @@ def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: class BaseFilters(BaseModel): - source_type: list[str] | None = None + source_type: list[DocumentSource] | None = None document_set: list[str] | None = None time_cutoff: datetime | None = None diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py new file mode 100644 index 00000000000..ed4bdbdf6a0 --- /dev/null +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -0,0 +1,185 @@ +import json +import random + +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION +from danswer.configs.constants import DocumentSource +from danswer.db.connector import fetch_unique_document_sources +from danswer.db.engine import get_sqlalchemy_engine +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.constants import SOURCES_KEY +from danswer.prompts.secondary_llm_flows import FILE_SOURCE_WARNING +from danswer.prompts.secondary_llm_flows import SOURCE_FILTER_PROMPT +from danswer.prompts.secondary_llm_flows import WEB_SOURCE_WARNING +from danswer.server.models import QuestionRequest +from danswer.utils.logger import setup_logger +from danswer.utils.text_processing import extract_embedded_json +from danswer.utils.timing import log_function_time + +logger = setup_logger() + + +def strings_to_document_sources(source_strs: list[str]) -> list[DocumentSource]: + sources = [] + for s in source_strs: + try: + sources.append(DocumentSource(s)) + except ValueError: + logger.warning(f"Failed to translate {s} to a DocumentSource") + return sources + + +def _sample_document_sources( + valid_sources: list[DocumentSource], + num_sample: int, + allow_less: bool = True, +) -> list[DocumentSource]: + if len(valid_sources) < num_sample: + if not allow_less: + raise RuntimeError("Not enough sample Document Sources") + return random.sample(valid_sources, len(valid_sources)) + else: + return random.sample(valid_sources, num_sample) + + +@log_function_time() +def extract_source_filter( + query: str, db_session: Session +) -> list[DocumentSource] | None: + """Returns a list of valid sources for search or None if no specific sources were detected""" + + def _get_source_filter_messages( + query: str, + valid_sources: list[DocumentSource], + # Seems the LLM performs similarly without examples + show_samples: bool = False, + ) -> list[dict[str, str]]: + sample_json = { + SOURCES_KEY: [ + s.value + for s in _sample_document_sources( + valid_sources=valid_sources, num_sample=2 + ) + ] + } + + web_warning = WEB_SOURCE_WARNING if DocumentSource.WEB in valid_sources else "" + file_warning = ( + FILE_SOURCE_WARNING if DocumentSource.FILE in valid_sources else "" + ) + + msg_1_sources = 
_sample_document_sources( + valid_sources=valid_sources, num_sample=2 + ) + msg_1_source_str = " and ".join([s.capitalize() for s in msg_1_sources]) + + msg_2_sources = _sample_document_sources( + valid_sources=valid_sources, num_sample=2 + ) + + msg_2_real_source = msg_2_sources[0] + msg_2_fake_source_str = ( + msg_2_sources[1].value.capitalize() + if len(msg_2_sources) > 1 + else "Confluence" + ) + + messages = [ + { + "role": "system", + "content": SOURCE_FILTER_PROMPT.format( + valid_sources=[s.value for s in valid_sources], + web_source_warning=web_warning, + file_source_warning=file_warning, + sample_response=json.dumps(sample_json), + ), + }, + { + "role": "user", + "content": f"What documents in {msg_1_source_str} cover engineer onboarding", + }, + { + "role": "assistant", + "content": json.dumps({SOURCES_KEY: msg_1_sources}), + }, + {"role": "user", "content": "What's the latest on project Corgies?"}, + { + "role": "assistant", + "content": json.dumps({SOURCES_KEY: None}), + }, + { + "role": "user", + "content": f"What information from {msg_2_real_source.value.capitalize()} " + f"mentions {msg_2_fake_source_str}?", + }, + { + "role": "assistant", + "content": json.dumps({SOURCES_KEY: [msg_2_real_source]}), + }, + { + "role": "user", + "content": "What page from Danswer contains debugging instruction on segfault", + }, + { + "role": "assistant", + "content": json.dumps({SOURCES_KEY: None}), + }, + {"role": "user", "content": query}, + ] + + if show_samples: + return messages + + # Only system prompt and latest user query + return [messages[0], messages[-1]] + + def _extract_source_filters_from_llm_out( + model_out: str, + ) -> list[DocumentSource] | None: + try: + sources_dict = extract_embedded_json(model_out) + sources_list = sources_dict.get(SOURCES_KEY) + if not sources_list: + return None + + return strings_to_document_sources(sources_list) + except ValueError: + logger.warning("LLM failed to provide a valid Source Filter output") + return None + + valid_sources = fetch_unique_document_sources(db_session) + if not valid_sources: + return None + + messages = _get_source_filter_messages(query=query, valid_sources=valid_sources) + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) + model_output = get_default_llm().invoke(filled_llm_prompt) + logger.debug(model_output) + + return _extract_source_filters_from_llm_out(model_output) + + +def extract_question_source_filters( + question: QuestionRequest, + db_session: Session, + disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, +) -> list[DocumentSource] | None: + # If specified in the question, don't update + if question.filters.source_type: + return question.filters.source_type + + if not question.enable_auto_detect_filters or disable_llm_extraction: + return None + + return extract_source_filter(question.query, db_session) + + +if __name__ == "__main__": + # Just for testing purposes + with Session(get_sqlalchemy_engine()) as db_session: + while True: + user_input = input("Query to Extract Sources: ") + sources = extract_source_filter(user_input, db_session) + print(sources) diff --git a/backend/danswer/secondary_llm_flows/extract_filters.py b/backend/danswer/secondary_llm_flows/time_filter.py similarity index 95% rename from backend/danswer/secondary_llm_flows/extract_filters.py rename to backend/danswer/secondary_llm_flows/time_filter.py index 00cd024df02..be06d23cc64 100644 --- a/backend/danswer/secondary_llm_flows/extract_filters.py +++ b/backend/danswer/secondary_llm_flows/time_filter.py @@ -5,9 
+5,10 @@ from dateutil.parser import parse -from danswer.configs.app_configs import DISABLE_TIME_FILTER_EXTRACTION +from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.prompt_utils import get_current_llm_day_time from danswer.prompts.secondary_llm_flows import TIME_FILTER_PROMPT from danswer.server.models import QuestionRequest from danswer.utils.logger import setup_logger @@ -51,7 +52,9 @@ def _get_time_filter_messages(query: str) -> list[dict[str, str]]: messages = [ { "role": "system", - "content": TIME_FILTER_PROMPT, + "content": TIME_FILTER_PROMPT.format( + current_day_time_str=get_current_llm_day_time() + ), }, { "role": "user", @@ -152,7 +155,7 @@ def _extract_time_filter_from_llm_out( def extract_question_time_filters( question: QuestionRequest, - disable_llm_extraction: bool = DISABLE_TIME_FILTER_EXTRACTION, + disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, ) -> tuple[datetime | None, bool]: time_cutoff = question.filters.time_cutoff favor_recent = question.favor_recent diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 2b145873acd..d9b81c2507d 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -290,6 +290,7 @@ class SearchResponse(BaseModel): top_ranked_docs: list[SearchDoc] | None lower_ranked_docs: list[SearchDoc] | None query_event_id: int + source_type: list[DocumentSource] | None time_cutoff: datetime | None favor_recent: bool diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 38f15fe00f5..a49bf0eae32 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException @@ -20,9 +22,10 @@ from danswer.search.models import IndexFilters from danswer.search.search_runner import chunks_to_search_docs from danswer.search.search_runner import danswer_search -from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters from danswer.secondary_llm_flows.query_validation import get_query_answerability from danswer.secondary_llm_flows.query_validation import stream_query_answerability +from danswer.secondary_llm_flows.source_filter import extract_question_source_filters +from danswer.secondary_llm_flows.time_filter import extract_question_time_filters from danswer.server.models import HelperResponse from danswer.server.models import QAFeedbackRequest from danswer.server.models import QAResponse @@ -32,6 +35,7 @@ from danswer.server.models import SearchFeedbackRequest from danswer.server.models import SearchResponse from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_in_parallel logger = setup_logger() @@ -125,9 +129,19 @@ def handle_search_request( query = question.query logger.info(f"Received {question.search_type.value} " f"search query: {query}") - time_cutoff, favor_recent = extract_question_time_filters(question) + functions_to_run: dict[Callable, tuple] = { + extract_question_time_filters: (question,), + extract_question_source_filters: (question, db_session), + } + + parallel_results = run_functions_in_parallel(functions_to_run) + + time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] + source_filters = 
parallel_results["extract_question_source_filters"] + question.filters.time_cutoff = time_cutoff question.favor_recent = favor_recent + question.filters.source_type = source_filters ranked_chunks, unranked_chunks, query_event_id = danswer_search( question=question, @@ -141,6 +155,7 @@ def handle_search_request( top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, ) @@ -152,6 +167,7 @@ def handle_search_request( top_ranked_docs=top_docs, lower_ranked_docs=lower_top_docs or None, query_event_id=query_event_id, + source_type=source_filters, time_cutoff=time_cutoff, favor_recent=favor_recent, ) diff --git a/backend/danswer/utils/threadpool_concurrency.py b/backend/danswer/utils/threadpool_concurrency.py new file mode 100644 index 00000000000..6927148014f --- /dev/null +++ b/backend/danswer/utils/threadpool_concurrency.py @@ -0,0 +1,38 @@ +from collections.abc import Callable +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor +from typing import Any + +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def run_functions_in_parallel( + functions_with_args: dict[Callable, tuple] +) -> dict[str, Any]: + """ + Executes multiple functions in parallel and returns a dictionary with the results. + + Args: + functions_with_args (dict): A dictionary mapping functions to a tuple of arguments. + + Returns: + dict: A dictionary mapping function names to their results or error messages. + """ + results = {} + with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor: + future_to_function = { + executor.submit(func, *args): func.__name__ + for func, args in functions_with_args.items() + } + + for future in as_completed(future_to_function): + function_name = future_to_function[future] + try: + results[function_name] = future.result() + except Exception as e: + logger.exception(f"Function {function_name} failed due to {e}") + raise + + return results diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py index c92d91c0160..f01ec33c895 100644 --- a/backend/danswer/utils/timing.py +++ b/backend/danswer/utils/timing.py @@ -2,6 +2,7 @@ from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator +from functools import wraps from typing import Any from typing import cast from typing import TypeVar @@ -14,53 +15,38 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator]) -def log_function_time( - func_name: str | None = None, -) -> Callable[[F], F]: - """Build a timing wrapper for a function. Logs how long the function took to run. - Use like: - - @log_function_time() - def my_func(): - ... - """ - - def timing_wrapper(func: F) -> F: +def log_function_time(func_name: str | None = None) -> Callable[[F], F]: + def decorator(func: F) -> F: + @wraps(func) def wrapped_func(*args: Any, **kwargs: Any) -> Any: start_time = time.time() result = func(*args, **kwargs) - logger.info( - f"{func_name or func.__name__} took {time.time() - start_time} seconds" - ) + elapsed_time = time.time() - start_time + logger.info(f"{func_name or func.__name__} took {elapsed_time} seconds") return result return cast(F, wrapped_func) - return timing_wrapper - - -def log_generator_function_time( - func_name: str | None = None, -) -> Callable[[FG], FG]: - """Build a timing wrapper for a function which returns a generator. - Logs how long the function took to run. 
- Use like: + return decorator - @log_generator_function_time() - def my_func(): - ... - yield X - ... - """ - def timing_wrapper(func: FG) -> FG: +def log_generator_function_time(func_name: str | None = None) -> Callable[[FG], FG]: + def decorator(func: FG) -> FG: + @wraps(func) def wrapped_func(*args: Any, **kwargs: Any) -> Any: start_time = time.time() - yield from func(*args, **kwargs) - logger.info( - f"{func_name or func.__name__} took {time.time() - start_time} seconds" - ) + gen = func(*args, **kwargs) + try: + value = next(gen) + while True: + yield value + value = next(gen) + except StopIteration: + pass + finally: + elapsed_time = time.time() - start_time + logger.info(f"{func_name or func.__name__} took {elapsed_time} seconds") return cast(FG, wrapped_func) - return timing_wrapper + return decorator diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 3de57ec6133..9591e26b05b 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -31,7 +31,7 @@ services: - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - - DISABLE_TIME_FILTER_EXTRACTION=${DISABLE_TIME_FILTER_EXTRACTION:-} + - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} # Don't change the NLP model configs unless you know what you're doing - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-} - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
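
Usage note: a minimal sketch of how the new run_functions_in_parallel helper is meant
to be called, assuming the danswer package from this patch is importable; the two
callables below are hypothetical stand-ins, not part of the patch.

    from danswer.utils.threadpool_concurrency import run_functions_in_parallel

    def double(x: int) -> int:
        # Hypothetical example function, not part of the patch
        return x * 2

    def greet(name: str) -> str:
        # Hypothetical example function, not part of the patch
        return f"Hello {name}"

    # Results are keyed by each callable's __name__, so callables sharing a name
    # would overwrite each other's entry; exceptions are logged and re-raised.
    results = run_functions_in_parallel({double: (21,), greet: ("Danswer",)})
    assert results["double"] == 42
    assert results["greet"] == "Hello Danswer"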