Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Compile sensitive terms up front for performance
Browse files Browse the repository at this point in the history
The sensitive query change led to performance benchmark failures in the
backend. Compiling the regex patterns up front reduces the impact.
  • Loading branch information
olaughter committed Apr 3, 2024
1 parent a78d0ae commit 8871380
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cpr-data-access"
version = "0.5.6"
version = "0.5.8"
description = ""
authors = ["CPR Tech <[email protected]>"]
readme = "README.md"
Expand Down
26 changes: 11 additions & 15 deletions src/cpr_data_access/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,15 @@ def is_sensitive_query(text: str, sensitive_terms: set) -> bool:
"""
sensitive_terms_in_query = [
term
for term in sensitive_terms
if re.findall(r"\b" + re.escape(term) + r"\b", text.lower())
term for term in sensitive_terms if re.findall(term, text.lower())
]

if sensitive_terms_in_query:
shortest_sensitive_term = min(sensitive_terms_in_query, key=len)
terms = [term.pattern.strip("\\b") for term in sensitive_terms_in_query]
shortest_sensitive_term = min(terms, key=len)
shortest_sensitive_word_count = len(shortest_sensitive_term.split(" "))

remaining_sensitive_word_count = sum(
[
len(term.split())
for term in sensitive_terms_in_query
if term != shortest_sensitive_term
]
[len(term.split()) for term in terms if term != shortest_sensitive_term]
)

query_word_count = len(text.split())
Expand All @@ -54,7 +48,7 @@ def is_sensitive_query(text: str, sensitive_terms: set) -> bool:
return False


def load_sensitive_query_terms() -> set[str]:
def load_sensitive_query_terms() -> set[re.Pattern]:
"""
Return sensitive query terms from the first column of a TSV file.
Expand All @@ -65,10 +59,12 @@ def load_sensitive_query_terms() -> set[str]:
tsv_path = Path(__file__).parent / "resources" / "sensitive_query_terms.tsv"
with open(tsv_path, "r") as tsv_file:
reader = csv.DictReader(tsv_file, delimiter="\t")

sensitive_terms = set([row["keyword"].lower().strip() for row in reader])

return sensitive_terms
sensitive_terms = []
for row in reader:
keyword = row["keyword"].lower().strip()
keyword_regex = re.compile(r"\b" + re.escape(keyword) + r"\b")
sensitive_terms.append(keyword_regex)
return set(sensitive_terms)


def dig(obj: Union[list, dict], *fields: Any, default: Any = None) -> Any:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_search_adaptors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from unittest.mock import patch
from timeit import timeit
from typing import Mapping

import pytest

Expand All @@ -25,6 +27,17 @@ def vespa_search(cert_directory: str, request: SearchParameters) -> SearchRespon
return response


def profile_search(
    fake_vespa_credentials, params: Mapping[str, str], n: int = 25
) -> float:
    """Run the same search *n* times and return the mean latency in milliseconds.

    Builds a fresh SearchParameters from *params* on every call so request
    construction cost is included in the measurement.
    """
    total_seconds = timeit(
        lambda: vespa_search(fake_vespa_credentials, SearchParameters(**params)),
        number=n,
    )
    # timeit reports the aggregate wall time in seconds; convert to per-call ms.
    return total_seconds / n * 1000


@pytest.mark.vespa
def test_vespa_search_adaptor__works(fake_vespa_credentials):
request = SearchParameters(query_string="the")
Expand All @@ -36,6 +49,25 @@ def test_vespa_search_adaptor__works(fake_vespa_credentials):
assert total_passage_count == response.total_hits


@pytest.mark.parametrize(
    "params",
    (
        {"query_string": "the"},
        {"query_string": "climate change"},
        {"query_string": "fuel", "exact_search": True},
        {"all_results": True, "documents_only": True},
        {"query_string": "fuel", "sort_by": "date", "sort_order": "asc"},
        {"query_string": "forest", "filter": {"family_category": "CCLW"}},
    ),
)
@pytest.mark.vespa
def test_vespa_search_adaptor__is_fast_enough(fake_vespa_credentials, params):
    """Each representative query shape must average under the latency budget."""
    MAX_SPEED_MS = 750

    # profile_search returns the mean wall time per request in milliseconds.
    assert profile_search(fake_vespa_credentials, params=params) <= MAX_SPEED_MS


@pytest.mark.vespa
@pytest.mark.parametrize(
"family_ids",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_search_requests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import patch
import re

import pytest

Expand All @@ -15,7 +16,10 @@
from cpr_data_access.embedding import Embedder


@patch("cpr_data_access.vespa.SENSITIVE_QUERY_TERMS", {"sensitive"})
@patch(
"cpr_data_access.vespa.SENSITIVE_QUERY_TERMS",
{re.compile(r"\b" + re.escape("sensitive") + r"\b")},
)
@pytest.mark.parametrize(
"query_type, params",
[
Expand Down
18 changes: 10 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch, mock_open

import pytest
from cpr_data_access.utils import (
dig,
Expand All @@ -7,11 +9,11 @@
unflatten_json,
)

TEST_SENSITIVE_QUERY_TERMS = (
"word",
"test term",
"another phrase example",
)
TEST_SENSITIVE_QUERY_TERMS = """group_name\tkeyword
type\tWord
type\tTest Term
type\tAnother Phrase Example
"""


@pytest.mark.parametrize(
Expand All @@ -33,9 +35,9 @@
),
)
def test_is_sensitive_query(expected, text):
assert (
is_sensitive_query(text, sensitive_terms=TEST_SENSITIVE_QUERY_TERMS) == expected
)
with patch("builtins.open", mock_open(read_data=TEST_SENSITIVE_QUERY_TERMS)):
sensitive_terms = load_sensitive_query_terms()
assert is_sensitive_query(text, sensitive_terms=sensitive_terms) == expected


def test_load_sensitive_query_terms():
Expand Down

0 comments on commit 8871380

Please sign in to comment.