Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Make filters exclusive to each other
Browse files Browse the repository at this point in the history
Previously, keyword filters were joined by an OR. This made sense for
searching within each keyword filter, which might have multiple geographies
or categories to include, but was not what we wanted when combining with
other filters. The result was that searching for a country and a category
would return a full list of everything from both, rather than just the
intersection between the two.
  • Loading branch information
olaughter committed Mar 13, 2024
1 parent ae0a4a2 commit 165748e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 56 deletions.
44 changes: 16 additions & 28 deletions src/cpr_data_access/models/search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
import re
from typing import List, Mapping, Optional, Sequence, Union
from typing import List, Optional, Sequence

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, ConfigDict, field_validator

from cpr_data_access.exceptions import QueryError

Expand All @@ -24,6 +24,19 @@
ID_PATTERN = re.compile(rf"{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}")


class KeywordFilters(BaseModel):
"""Filterable fields in a search request"""

family_geography: Sequence[str] = []
family_category: Sequence[str] = []
document_languages: Sequence[str] = []
family_source: Sequence[str] = []

model_config: ConfigDict = {
"extra": "forbid",
}


class SearchParameters(BaseModel):
"""Parameters for a search request"""

Expand All @@ -35,7 +48,7 @@ class SearchParameters(BaseModel):
family_ids: Optional[Sequence[str]] = None
document_ids: Optional[Sequence[str]] = None

keyword_filters: Optional[Mapping[str, Union[str, Sequence[str]]]] = None
keyword_filters: Optional[KeywordFilters] = None
year_range: Optional[tuple[Optional[int], Optional[int]]] = None

sort_by: Optional[str] = None
Expand Down Expand Up @@ -100,31 +113,6 @@ def sort_order_must_be_valid(cls, sort_order):
)
return sort_order

@field_validator("keyword_filters")
def keyword_filters_must_be_valid(cls, keyword_filters):
"""Validate that the keyword filters are valid."""
if keyword_filters is not None:
for field_key, values in keyword_filters.items():
if field_key not in filter_fields.values():
raise QueryError(
f"Invalid keyword filter: {field_key}. keyword_filters must be "
f"a subset of: {list(filter_fields.values())}"
)

# convert single values to lists to make things easier later on
if not isinstance(values, list):
keyword_filters[field_key] = [values]

for value in keyword_filters[field_key]:
if not isinstance(value, str):
raise QueryError(
"Invalid keyword filter value: "
f"{{{field_key}: {value}}}. "
"Keyword filter values must be strings."
)

return keyword_filters


class Hit(BaseModel):
"""Common model for all search result hits."""
Expand Down
32 changes: 18 additions & 14 deletions src/cpr_data_access/yql_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from string import Template
from typing import Optional

from cpr_data_access.models.search import SearchParameters
from cpr_data_access.models.search import KeywordFilters, SearchParameters


def sanitize(user_input: str) -> str:
Expand Down Expand Up @@ -123,17 +123,15 @@ def build_document_filter(self) -> Optional[str]:
return f"(document_import_id in({documents}))"
return None

def build_keyword_filter(self) -> Optional[str]:
"""Create the part of the query that adds keyword filters"""
keyword_filters = self.params.keyword_filters
if keyword_filters:
filters = []
for field_name, values in keyword_filters.items():
for value in values:
filters.append(f'({field_name} contains "{sanitize(value)}")')
if filters:
return f"({' or '.join(filters)})"
return None
def _inclusive_keyword_filters(
self, keyword_filters: KeywordFilters, field_name: str
):
values = getattr(keyword_filters, field_name)
filters = []
for value in values:
filters.append(f'({field_name} contains "{sanitize(value)}")')
if filters:
return f"({' or '.join(filters)})"

def build_year_start_filter(self) -> Optional[str]:
"""Create the part of the query that filters on a year range"""
Expand All @@ -157,7 +155,11 @@ def build_where_clause(self) -> str:
filters.append(self.build_search_term())
filters.append(self.build_family_filter())
filters.append(self.build_document_filter())
filters.append(self.build_keyword_filter())
if kf := self.params.keyword_filters:
filters.append(self._inclusive_keyword_filters(kf, "family_geography"))
filters.append(self._inclusive_keyword_filters(kf, "family_category"))
filters.append(self._inclusive_keyword_filters(kf, "document_languages"))
filters.append(self._inclusive_keyword_filters(kf, "family_source"))
filters.append(self.build_year_start_filter())
filters.append(self.build_year_end_filter())
return " and ".join([f for f in filters if f]) # Remove empty
Expand Down Expand Up @@ -196,7 +198,9 @@ def to_str(self) -> str:
exact_match=False,
limit=10,
max_hits_per_family=10,
keyword_filters={"document_languages": "value", "family_source": "value"},
keyword_filters=KeywordFilters(
**{"document_languages": "value", "family_source": "value"}
),
year_range=(2000, 2020),
continuation_token=None,
)
Expand Down
33 changes: 19 additions & 14 deletions tests/test_search_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import pytest

from vespa.exceptions import VespaError
from pydantic import ValidationError

from cpr_data_access.models.search import SearchParameters
from cpr_data_access.models.search import KeywordFilters, SearchParameters
from cpr_data_access.vespa import build_vespa_request_body, VespaErrorDetails
from cpr_data_access.yql_builder import YQLBuilder, sanitize
from cpr_data_access.exceptions import QueryError
Expand Down Expand Up @@ -137,16 +138,18 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror():
["family_geography", "family_category", "document_languages", "family_source"],
)
def test_whether_valid_filter_fields_are_accepted(field):
params = SearchParameters(query_string="test", keyword_filters={field: "value"})
keyword_filters = KeywordFilters(**{field: ["value"]})
params = SearchParameters(query_string="test", keyword_filters=keyword_filters)
assert isinstance(params, SearchParameters)


def test_whether_an_invalid_filter_fields_raises_a_valueerror():
with pytest.raises(QueryError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
SearchParameters(
query_string="test", keyword_filters={"invalid_field": "value"}
query_string="test",
keyword_filters=KeywordFilters(**{"invalid_field": ["value"]}),
)
assert "keyword_filters must be a subset of" in str(excinfo.value)
assert "Extra inputs are not permitted" in str(excinfo.value)


@pytest.mark.parametrize(
Expand All @@ -169,18 +172,20 @@ def test_whether_malicious_query_strings_are_sanitized(input_string, expected):


def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql():
keyword_filters = {
"family_geography": ["SWE"],
"family_category": ["Executive"],
"document_languages": ["English", "Swedish"],
"family_source": ["CCLW"],
}
params = SearchParameters(
query_string="test",
keyword_filters={
"family_geography": "SWE",
"family_category": "Executive",
"document_languages": ["English", "Swedish"],
"family_source": "CCLW",
},
keyword_filters=KeywordFilters(**keyword_filters),
)
yql = YQLBuilder(params).to_str()
assert isinstance(params.keyword_filters, dict)
for key, values in params.keyword_filters.items():
assert isinstance(params.keyword_filters, KeywordFilters)

for key, values in keyword_filters.items():
for value in values:
assert key in yql
assert value in yql
Expand Down Expand Up @@ -266,7 +271,7 @@ def test_yql_builder_build_where_clause():
assert "climate" in where_clause

params = SearchParameters(
query_string="climate", keyword_filters={"family_geography": "SWE"}
query_string="climate", keyword_filters={"family_geography": ["SWE"]}
)
where_clause = YQLBuilder(params).build_where_clause()
assert "SWE" in where_clause
Expand Down

0 comments on commit 165748e

Please sign in to comment.