Source Filter Extraction (#708)
yuhongsun96 authored Nov 7, 2023
1 parent 4f64444 commit 0125d8a

Showing 18 changed files with 400 additions and 80 deletions.
11 changes: 2 additions & 9 deletions backend/danswer/chat/personas.py

@@ -1,4 +1,3 @@
-from datetime import datetime
 from typing import Any

 import yaml
@@ -11,19 +10,13 @@
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Persona
 from danswer.db.models import ToolInfo
+from danswer.prompts.prompt_utils import get_current_llm_day_time


 def build_system_text_from_persona(persona: Persona) -> str | None:
     text = (persona.system_text or "").strip()
     if persona.datetime_aware:
-        current_datetime = datetime.now()
-        # Format looks like: "October 16, 2023 14:30"
-        formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
-
-        text += (
-            "\n\nAdditional Information:\n"
-            f"\t- The current date and time is {formatted_datetime}."
-        )
+        text += "\n\nAdditional Information:\n" f"\t- {get_current_llm_day_time()}."

     return text or None
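For a sense of the resulting persona text, here is a minimal sketch of the refactored path (the printed output is illustrative; it relies on the `get_current_llm_day_time` helper added in prompt_utils.py below):

```python
# Minimal sketch of the refactored date line; output is illustrative.
from danswer.prompts.prompt_utils import get_current_llm_day_time

text = "You are a helpful assistant."  # stand-in for persona.system_text
text += "\n\nAdditional Information:\n" f"\t- {get_current_llm_day_time()}."
print(text)
# You are a helpful assistant.
#
# Additional Information:
#     - The current day and time is Tuesday November 07, 2023 14:30.
```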
4 changes: 2 additions & 2 deletions backend/danswer/configs/app_configs.py

@@ -169,8 +169,8 @@
     os.environ.get("DOC_TIME_DECAY") or 0.5  # Hits limit at 2 years by default
 )
 FAVOR_RECENT_DECAY_MULTIPLIER = 2
-DISABLE_TIME_FILTER_EXTRACTION = (
-    os.environ.get("DISABLE_TIME_FILTER_EXTRACTION", "").lower() == "true"
+DISABLE_LLM_FILTER_EXTRACTION = (
+    os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
 )
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
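One parsing detail worth noting: only the literal string "true" (any casing) disables LLM filter extraction; values like "1" or "yes" are not recognized. A quick sketch:

```python
import os

# Any casing of "true" disables extraction...
os.environ["DISABLE_LLM_FILTER_EXTRACTION"] = "True"
assert os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"

# ...but other truthy-looking values do not.
os.environ["DISABLE_LLM_FILTER_EXTRACTION"] = "1"
assert os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() != "true"
```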
9 changes: 8 additions & 1 deletion backend/danswer/danswerbot/slack/blocks.py

@@ -182,15 +182,22 @@ def build_qa_response_blocks(
     query_event_id: int,
     answer: str | None,
     quotes: list[DanswerQuote] | None,
+    source_filters: list[DocumentSource] | None,
     time_cutoff: datetime | None,
     favor_recent: bool,
 ) -> list[Block]:
     quotes_blocks: list[Block] = []

     ai_answer_header = HeaderBlock(text="AI Answer")

     filter_block: Block | None = None
-    if time_cutoff or favor_recent:
+    if time_cutoff or favor_recent or source_filters:
         filter_text = "Filters: "
+        if source_filters:
+            sources_str = ", ".join([s.value for s in source_filters])
+            filter_text += f"`Sources in [{sources_str}]`"
+            if time_cutoff or favor_recent:
+                filter_text += " and "
         if time_cutoff is not None:
             time_str = time_cutoff.strftime("%b %d, %Y")
             filter_text += f"`Docs Updated >= {time_str}` "
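To make the string-building concrete, here is a sketch of the filter line the block produces when a query pins both sources and a time cutoff (the values are illustrative):

```python
from datetime import datetime

source_values = ["slack", "web"]  # .value strings from DocumentSource members
time_cutoff = datetime(2023, 10, 1)

filter_text = "Filters: "
filter_text += f"`Sources in [{', '.join(source_values)}]`"
filter_text += " and "  # only added because a time clause follows
filter_text += f"`Docs Updated >= {time_cutoff.strftime('%b %d, %Y')}` "
print(filter_text)
# Filters: `Sources in [slack, web]` and `Docs Updated >= Oct 01, 2023`
```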
@@ -260,6 +260,7 @@ def _get_answer(question: QuestionRequest) -> QAResponse:
         query_event_id=answer.query_event_id,
         answer=answer.answer,
         quotes=answer.quotes,
+        source_filters=answer.source_type,
         time_cutoff=answer.time_cutoff,
         favor_recent=answer.favor_recent,
     )
8 changes: 8 additions & 0 deletions backend/danswer/db/connector.py

@@ -202,3 +202,11 @@ def fetch_latest_index_attempts_by_status(
         ),
     )
     return cast(list[IndexAttempt], query.all())
+
+
+def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
+    distinct_sources = db_session.query(Connector.source).distinct().all()
+
+    sources = [source[0] for source in distinct_sources]
+
+    return sources
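A hypothetical usage sketch: the distinct connector sources can feed the list of valid values the source-filter prompt offers the LLM (the session setup here is an assumption, not part of this diff):

```python
from sqlalchemy.orm import Session

from danswer.db.connector import fetch_unique_document_sources
from danswer.db.engine import get_sqlalchemy_engine  # assumed session factory

with Session(get_sqlalchemy_engine()) as db_session:
    valid_sources = fetch_unique_document_sources(db_session)
    # e.g. ['slack', 'github'] for a deployment with those two connectors
    print([source.value for source in valid_sources])
```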
45 changes: 36 additions & 9 deletions backend/danswer/direct_qa/answer_question.py

@@ -22,12 +22,14 @@
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
-from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
+from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
+from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import RerankedRetrievalDocs
 from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
 from danswer.utils.timing import log_generator_function_time

@@ -52,9 +54,22 @@ def answer_qa_query(
     offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")

-    time_cutoff, favor_recent = extract_question_time_filters(question)
+    functions_to_run: dict[Callable, tuple] = {
+        extract_question_time_filters: (question,),
+        extract_question_source_filters: (question, db_session),
+        query_intent: (query,),
+    }
+
+    parallel_results = run_functions_in_parallel(functions_to_run)
+
+    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+    source_filters = parallel_results["extract_question_source_filters"]
+    predicted_search, predicted_flow = parallel_results["query_intent"]
+
+    # Modifies the question object but nothing upstream uses it
     question.filters.time_cutoff = time_cutoff
     question.favor_recent = favor_recent
+    question.filters.source_type = source_filters

     ranked_chunks, unranked_chunks, query_event_id = danswer_search(
         question=question,
@@ -65,9 +80,6 @@
         rerank_metrics_callback=rerank_metrics_callback,
     )

-    # TODO retire this
-    predicted_search, predicted_flow = query_intent(query)
-
     if not ranked_chunks:
         return QAResponse(
             answer=None,
@@ -77,6 +89,7 @@
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
         )
@@ -96,6 +109,7 @@
             predicted_flow=QueryFlow.SEARCH,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
         )
@@ -113,6 +127,7 @@
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
             error_msg=str(e),
@@ -159,6 +174,7 @@ def answer_qa_query(
             predicted_search=predicted_search,
             eval_res_valid=True if valid else False,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
             error_msg=error_msg,
@@ -172,6 +188,7 @@
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             query_event_id=query_event_id,
+            source_type=source_filters,
             time_cutoff=time_cutoff,
             favor_recent=favor_recent,
             error_msg=error_msg,
@@ -194,9 +211,22 @@ def answer_qa_query_stream(
     query = question.query
     offset_count = question.offset if question.offset is not None else 0

-    time_cutoff, favor_recent = extract_question_time_filters(question)
+    functions_to_run: dict[Callable, tuple] = {
+        extract_question_time_filters: (question,),
+        extract_question_source_filters: (question, db_session),
+        query_intent: (query,),
+    }
+
+    parallel_results = run_functions_in_parallel(functions_to_run)
+
+    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+    source_filters = parallel_results["extract_question_source_filters"]
+    predicted_search, predicted_flow = parallel_results["query_intent"]
+
+    # Modifies the question object but nothing upstream uses it
     question.filters.time_cutoff = time_cutoff
     question.favor_recent = favor_recent
+    question.filters.source_type = source_filters

     ranked_chunks, unranked_chunks, query_event_id = danswer_search(
         question=question,
@@ -205,9 +235,6 @@
         document_index=get_default_document_index(),
     )

-    # TODO retire this
-    predicted_search, predicted_flow = query_intent(query)
-
     top_docs = chunks_to_search_docs(ranked_chunks)
     unranked_top_docs = chunks_to_search_docs(unranked_chunks)
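`run_functions_in_parallel` itself is not shown in this diff; a minimal sketch consistent with how it is called here — a `dict[Callable, tuple]` in, results keyed by each function's `__name__` out — might look like the following. Threads (rather than processes) fit this use because the three calls are I/O-bound LLM and DB requests:

```python
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any


def run_functions_in_parallel(
    functions_with_args: dict[Callable, tuple]
) -> dict[str, Any]:
    """Run each function with its args in a thread pool; key results by name."""
    results: dict[str, Any] = {}
    with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
        futures = {
            func.__name__: executor.submit(func, *args)
            for func, args in functions_with_args.items()
        }
        for name, future in futures.items():
            results[name] = future.result()  # re-raises any worker exception
    return results
```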
5 changes: 4 additions & 1 deletion backend/danswer/document_index/vespa/index.py

@@ -335,7 +335,10 @@ def _build_time_filter(
     # CAREFUL touching this one, currently there is no second ACL double-check post retrieval
     filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)

-    filter_str += _build_or_filters(SOURCE_TYPE, filters.source_type)
+    source_strs = (
+        [s.value for s in filters.source_type] if filters.source_type else None
+    )
+    filter_str += _build_or_filters(SOURCE_TYPE, source_strs)

     filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
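The conversion is needed because `BaseFilters.source_type` now holds `DocumentSource` enum members (see models.py below), while the Vespa filter string wants raw strings. A self-contained illustration with a simplified stand-in enum:

```python
from enum import Enum


class DocumentSource(str, Enum):  # simplified stand-in for danswer's enum
    SLACK = "slack"
    GITHUB = "github"


source_type = [DocumentSource.SLACK, DocumentSource.GITHUB]
source_strs = [s.value for s in source_type] if source_type else None
print(source_strs)  # ['slack', 'github']
```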
1 change: 1 addition & 0 deletions backend/danswer/prompts/constants.py

@@ -9,3 +9,4 @@
 QUOTE_PAT = "Quote:"
 QUOTES_PAT_PLURAL = "Quotes:"
 INVALID_PAT = "Invalid:"
+SOURCES_KEY = "sources"
9 changes: 9 additions & 0 deletions backend/danswer/prompts/prompt_utils.py

@@ -0,0 +1,9 @@
+from datetime import datetime
+
+
+def get_current_llm_day_time() -> str:
+    current_datetime = datetime.now()
+    # Format looks like: "October 16, 2023 14:30"
+    formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
+    day_of_week = current_datetime.strftime("%A")
+    return f"The current day and time is {day_of_week} {formatted_datetime}"
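Example output (the exact timestamp depends on when it runs):

```python
from danswer.prompts.prompt_utils import get_current_llm_day_time

print(get_current_llm_day_time())
# The current day and time is Tuesday November 07, 2023 14:30
```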
71 changes: 56 additions & 15 deletions backend/danswer/prompts/secondary_llm_flows.py

@@ -2,6 +2,7 @@
 from danswer.prompts.constants import ANSWERABLE_PAT
 from danswer.prompts.constants import GENERAL_SEP_PAT
 from danswer.prompts.constants import QUESTION_PAT
+from danswer.prompts.constants import SOURCES_KEY
 from danswer.prompts.constants import THOUGHT_PAT

@@ -31,21 +32,6 @@
 """.strip()


-TIME_FILTER_PROMPT = """
-You are a tool to identify time filters to apply to a user query for a downstream search \
-application. The downstream application is able to use a recency bias or apply a hard cutoff to \
-remove all documents before the cutoff. Identify the correct filters to apply for the user query.
-Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \
-"value_multiple" and "date".
-The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive".
-The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year".
-The valid values for "value_multiple" is any number.
-The valid values for "date" is a date in format MM/DD/YYYY.
-""".strip()
-
-
 ANSWERABLE_PROMPT = f"""
 You are a helper tool to determine if a query is answerable using retrieval augmented generation.
 The main system will try to answer the user query based on ONLY the top 5 most relevant \
@@ -91,6 +77,61 @@
 """.strip()


+# Smaller followup prompts in time_filter.py
+TIME_FILTER_PROMPT = """
+You are a tool to identify time filters to apply to a user query for a downstream search \
+application. The downstream application is able to use a recency bias or apply a hard cutoff to \
+remove all documents before the cutoff. Identify the correct filters to apply for the user query.
+The current day and time is {current_day_time_str}.
+Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \
+"value_multiple" and "date".
+The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive".
+The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year".
+The valid value for "value_multiple" is any number.
+The valid value for "date" is a date in MM/DD/YYYY format; ALWAYS follow this format.
+""".strip()
+
+
+# Smaller followup prompts in source_filter.py
+# Known issue: LLMs like GPT-3.5 try to generalize. If the valid sources contain "web" but not
+# "confluence" and the user asks for Confluence-related things, the LLM will select "web" since
+# Confluence is accessed as a website. This cannot be fixed without also reducing the ability
+# to match things like repository->github, website->web, etc.
+# This is generally not a big issue though: if the company has Confluence, hopefully they add
+# a connector for it, or the user is aware that Confluence has not been added.
+SOURCE_FILTER_PROMPT = f"""
+Given a user query, extract relevant source filters for use in a downstream search tool.
+Respond with a json containing the source filters or null if no specific sources are referenced.
+ONLY extract sources when the user is explicitly limiting the scope of where information is \
+coming from.
+The user may provide invalid source filters; ignore those.
+The valid sources are:
+{{valid_sources}}
+{{web_source_warning}}
+{{file_source_warning}}
+ALWAYS answer with ONLY a json with the key "{SOURCES_KEY}". \
+The value for "{SOURCES_KEY}" must be null or a list of valid sources.
+Sample Response:
+{{sample_response}}
+""".strip()
+
+WEB_SOURCE_WARNING = """
+Note: The "web" source only applies when the user specifies "website" in the query. \
+It does not apply to tools such as Confluence, GitHub, etc. which have a website.
+""".strip()
+
+FILE_SOURCE_WARNING = """
+Note: The "file" source only applies when the user refers to uploaded files in the query.
+""".strip()


 # Use the following for easy viewing of prompts
 if __name__ == "__main__":
     print(ANSWERABLE_PROMPT)
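A sketch of how the caller (source_filter.py, per the comment above) might render SOURCE_FILTER_PROMPT; the fill values and wiring here are assumptions, not the committed code:

```python
from danswer.prompts.secondary_llm_flows import (
    FILE_SOURCE_WARNING,
    SOURCE_FILTER_PROMPT,
    WEB_SOURCE_WARNING,
)

valid_sources = ["slack", "github", "web"]  # e.g. from fetch_unique_document_sources
prompt = SOURCE_FILTER_PROMPT.format(
    valid_sources="\n".join(valid_sources),
    web_source_warning=WEB_SOURCE_WARNING if "web" in valid_sources else "",
    file_source_warning="",  # no "file" connector in this example
    sample_response='{"sources": ["slack"]}',
)
# An LLM answer of {"sources": null} would mean: apply no source filter.
```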
3 changes: 2 additions & 1 deletion backend/danswer/search/models.py

@@ -5,6 +5,7 @@

 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
+from danswer.configs.constants import DocumentSource
 from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -31,7 +32,7 @@ def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:


 class BaseFilters(BaseModel):
-    source_type: list[str] | None = None
+    source_type: list[DocumentSource] | None = None
     document_set: list[str] | None = None
     time_cutoff: datetime | None = None
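With the typed field, callers construct filters from enum members rather than raw strings; a brief sketch (the enum member name is illustrative):

```python
from datetime import datetime

from danswer.configs.constants import DocumentSource
from danswer.search.models import BaseFilters

filters = BaseFilters(
    source_type=[DocumentSource.SLACK],  # validated as DocumentSource, not str
    time_cutoff=datetime(2023, 10, 1),
)
```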