From 165748ef612345e497447b6db3216f9bff730737 Mon Sep 17 00:00:00 2001 From: olaughter Date: Wed, 13 Mar 2024 08:20:47 +0000 Subject: [PATCH] Make filters exclusive to each other Previously, keyword filters were joined by an OR. This made sense for searching within each keyword filter, which might have multiple geographies or categories to include, but was not what we wanted when combining with other filters. The result was that searching for a country and a category would return a full list of everything from both, rather than just the intersection between the two. --- src/cpr_data_access/models/search.py | 44 ++++++++++------------------ src/cpr_data_access/yql_builder.py | 32 +++++++++++--------- tests/test_search_requests.py | 33 ++++++++++++--------- 3 files changed, 53 insertions(+), 56 deletions(-) diff --git a/src/cpr_data_access/models/search.py b/src/cpr_data_access/models/search.py index c0f6bdb1..ac151d32 100644 --- a/src/cpr_data_access/models/search.py +++ b/src/cpr_data_access/models/search.py @@ -1,8 +1,8 @@ from datetime import datetime import re -from typing import List, Mapping, Optional, Sequence, Union +from typing import List, Optional, Sequence -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, ConfigDict, field_validator from cpr_data_access.exceptions import QueryError @@ -24,6 +24,19 @@ ID_PATTERN = re.compile(rf"{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}") +class KeywordFilters(BaseModel): + """Filterable fields in a search request""" + + family_geography: Sequence[str] = [] + family_category: Sequence[str] = [] + document_languages: Sequence[str] = [] + family_source: Sequence[str] = [] + + model_config: ConfigDict = { + "extra": "forbid", + } + + class SearchParameters(BaseModel): """Parameters for a search request""" @@ -35,7 +48,7 @@ class SearchParameters(BaseModel): family_ids: Optional[Sequence[str]] = None document_ids: Optional[Sequence[str]] = None - keyword_filters: 
Optional[Mapping[str, Union[str, Sequence[str]]]] = None + keyword_filters: Optional[KeywordFilters] = None year_range: Optional[tuple[Optional[int], Optional[int]]] = None sort_by: Optional[str] = None @@ -100,31 +113,6 @@ def sort_order_must_be_valid(cls, sort_order): ) return sort_order - @field_validator("keyword_filters") - def keyword_filters_must_be_valid(cls, keyword_filters): - """Validate that the keyword filters are valid.""" - if keyword_filters is not None: - for field_key, values in keyword_filters.items(): - if field_key not in filter_fields.values(): - raise QueryError( - f"Invalid keyword filter: {field_key}. keyword_filters must be " - f"a subset of: {list(filter_fields.values())}" - ) - - # convert single values to lists to make things easier later on - if not isinstance(values, list): - keyword_filters[field_key] = [values] - - for value in keyword_filters[field_key]: - if not isinstance(value, str): - raise QueryError( - "Invalid keyword filter value: " - f"{{{field_key}: {value}}}. " - "Keyword filter values must be strings." 
- ) - - return keyword_filters - class Hit(BaseModel): """Common model for all search result hits.""" diff --git a/src/cpr_data_access/yql_builder.py b/src/cpr_data_access/yql_builder.py index 54f1fea7..51916f51 100644 --- a/src/cpr_data_access/yql_builder.py +++ b/src/cpr_data_access/yql_builder.py @@ -1,7 +1,7 @@ from string import Template from typing import Optional -from cpr_data_access.models.search import SearchParameters +from cpr_data_access.models.search import KeywordFilters, SearchParameters def sanitize(user_input: str) -> str: @@ -123,17 +123,15 @@ def build_document_filter(self) -> Optional[str]: return f"(document_import_id in({documents}))" return None - def build_keyword_filter(self) -> Optional[str]: - """Create the part of the query that adds keyword filters""" - keyword_filters = self.params.keyword_filters - if keyword_filters: - filters = [] - for field_name, values in keyword_filters.items(): - for value in values: - filters.append(f'({field_name} contains "{sanitize(value)}")') - if filters: - return f"({' or '.join(filters)})" - return None + def _inclusive_keyword_filters( + self, keyword_filters: KeywordFilters, field_name: str + ): + values = getattr(keyword_filters, field_name) + filters = [] + for value in values: + filters.append(f'({field_name} contains "{sanitize(value)}")') + if filters: + return f"({' or '.join(filters)})" def build_year_start_filter(self) -> Optional[str]: """Create the part of the query that filters on a year range""" @@ -157,7 +155,11 @@ def build_where_clause(self) -> str: filters.append(self.build_search_term()) filters.append(self.build_family_filter()) filters.append(self.build_document_filter()) - filters.append(self.build_keyword_filter()) + if kf := self.params.keyword_filters: + filters.append(self._inclusive_keyword_filters(kf, "family_geography")) + filters.append(self._inclusive_keyword_filters(kf, "family_category")) + filters.append(self._inclusive_keyword_filters(kf, "document_languages")) + 
filters.append(self._inclusive_keyword_filters(kf, "family_source")) filters.append(self.build_year_start_filter()) filters.append(self.build_year_end_filter()) return " and ".join([f for f in filters if f]) # Remove empty @@ -196,7 +198,9 @@ def to_str(self) -> str: exact_match=False, limit=10, max_hits_per_family=10, - keyword_filters={"document_languages": "value", "family_source": "value"}, + keyword_filters=KeywordFilters( + **{"document_languages": "value", "family_source": "value"} + ), year_range=(2000, 2020), continuation_token=None, ) diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index 5f3aa20b..88259a56 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -3,8 +3,9 @@ import pytest from vespa.exceptions import VespaError +from pydantic import ValidationError -from cpr_data_access.models.search import SearchParameters +from cpr_data_access.models.search import KeywordFilters, SearchParameters from cpr_data_access.vespa import build_vespa_request_body, VespaErrorDetails from cpr_data_access.yql_builder import YQLBuilder, sanitize from cpr_data_access.exceptions import QueryError @@ -137,16 +138,18 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror(): ["family_geography", "family_category", "document_languages", "family_source"], ) def test_whether_valid_filter_fields_are_accepted(field): - params = SearchParameters(query_string="test", keyword_filters={field: "value"}) + keyword_filters = KeywordFilters(**{field: ["value"]}) + params = SearchParameters(query_string="test", keyword_filters=keyword_filters) assert isinstance(params, SearchParameters) def test_whether_an_invalid_filter_fields_raises_a_valueerror(): - with pytest.raises(QueryError) as excinfo: + with pytest.raises(ValidationError) as excinfo: SearchParameters( - query_string="test", keyword_filters={"invalid_field": "value"} + query_string="test", + keyword_filters=KeywordFilters(**{"invalid_field": ["value"]}), ) - assert 
"keyword_filters must be a subset of" in str(excinfo.value) + assert "Extra inputs are not permitted" in str(excinfo.value) @pytest.mark.parametrize( @@ -169,18 +172,20 @@ def test_whether_malicious_query_strings_are_sanitized(input_string, expected): def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql(): + keyword_filters = { + "family_geography": ["SWE"], + "family_category": ["Executive"], + "document_languages": ["English", "Swedish"], + "family_source": ["CCLW"], + } params = SearchParameters( query_string="test", - keyword_filters={ - "family_geography": "SWE", - "family_category": "Executive", - "document_languages": ["English", "Swedish"], - "family_source": "CCLW", - }, + keyword_filters=KeywordFilters(**keyword_filters), ) yql = YQLBuilder(params).to_str() - assert isinstance(params.keyword_filters, dict) - for key, values in params.keyword_filters.items(): + assert isinstance(params.keyword_filters, KeywordFilters) + + for key, values in keyword_filters.items(): for value in values: assert key in yql assert value in yql @@ -266,7 +271,7 @@ def test_yql_builder_build_where_clause(): assert "climate" in where_clause params = SearchParameters( - query_string="climate", keyword_filters={"family_geography": "SWE"} + query_string="climate", keyword_filters={"family_geography": ["SWE"]} ) where_clause = YQLBuilder(params).build_where_clause() assert "SWE" in where_clause