diff --git a/src/cpr_data_access/models/search.py b/src/cpr_data_access/models/search.py index c0f6bdb..ac151d3 100644 --- a/src/cpr_data_access/models/search.py +++ b/src/cpr_data_access/models/search.py @@ -1,8 +1,8 @@ from datetime import datetime import re -from typing import List, Mapping, Optional, Sequence, Union +from typing import List, Optional, Sequence -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, ConfigDict, field_validator from cpr_data_access.exceptions import QueryError @@ -24,6 +24,19 @@ ID_PATTERN = re.compile(rf"{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}") +class KeywordFilters(BaseModel): + """Filterable fields in a search request""" + + family_geography: Sequence[str] = [] + family_category: Sequence[str] = [] + document_languages: Sequence[str] = [] + family_source: Sequence[str] = [] + + model_config: ConfigDict = { + "extra": "forbid", + } + + class SearchParameters(BaseModel): """Parameters for a search request""" @@ -35,7 +48,7 @@ class SearchParameters(BaseModel): family_ids: Optional[Sequence[str]] = None document_ids: Optional[Sequence[str]] = None - keyword_filters: Optional[Mapping[str, Union[str, Sequence[str]]]] = None + keyword_filters: Optional[KeywordFilters] = None year_range: Optional[tuple[Optional[int], Optional[int]]] = None sort_by: Optional[str] = None @@ -100,31 +113,6 @@ def sort_order_must_be_valid(cls, sort_order): ) return sort_order - @field_validator("keyword_filters") - def keyword_filters_must_be_valid(cls, keyword_filters): - """Validate that the keyword filters are valid.""" - if keyword_filters is not None: - for field_key, values in keyword_filters.items(): - if field_key not in filter_fields.values(): - raise QueryError( - f"Invalid keyword filter: {field_key}. keyword_filters must be " - f"a subset of: {list(filter_fields.values())}" - ) - - # convert single values to lists to make things easier later on - if not isinstance(values, list): - keyword_filters[field_key] = [values] - - for value in keyword_filters[field_key]: - if not isinstance(value, str): - raise QueryError( - "Invalid keyword filter value: " - f"{{{field_key}: {value}}}. " - "Keyword filter values must be strings." - ) - - return keyword_filters - class Hit(BaseModel): """Common model for all search result hits.""" diff --git a/src/cpr_data_access/yql_builder.py b/src/cpr_data_access/yql_builder.py index 54f1fea..51916f5 100644 --- a/src/cpr_data_access/yql_builder.py +++ b/src/cpr_data_access/yql_builder.py @@ -1,7 +1,7 @@ from string import Template from typing import Optional -from cpr_data_access.models.search import SearchParameters +from cpr_data_access.models.search import KeywordFilters, SearchParameters def sanitize(user_input: str) -> str: @@ -123,17 +123,15 @@ def build_document_filter(self) -> Optional[str]: return f"(document_import_id in({documents}))" return None - def build_keyword_filter(self) -> Optional[str]: - """Create the part of the query that adds keyword filters""" - keyword_filters = self.params.keyword_filters - if keyword_filters: - filters = [] - for field_name, values in keyword_filters.items(): - for value in values: - filters.append(f'({field_name} contains "{sanitize(value)}")') - if filters: - return f"({' or '.join(filters)})" - return None + def _inclusive_keyword_filters( + self, keyword_filters: KeywordFilters, field_name: str + ): + values = getattr(keyword_filters, field_name) + filters = [] + for value in values: + filters.append(f'({field_name} contains "{sanitize(value)}")') + if filters: + return f"({' or '.join(filters)})" def build_year_start_filter(self) -> Optional[str]: """Create the part of the query that filters on a year range""" @@ -157,7 +155,11 @@ def build_where_clause(self) -> str: filters.append(self.build_search_term()) filters.append(self.build_family_filter()) filters.append(self.build_document_filter()) - filters.append(self.build_keyword_filter()) + if kf := self.params.keyword_filters: + filters.append(self._inclusive_keyword_filters(kf, "family_geography")) + filters.append(self._inclusive_keyword_filters(kf, "family_category")) + filters.append(self._inclusive_keyword_filters(kf, "document_languages")) + filters.append(self._inclusive_keyword_filters(kf, "family_source")) filters.append(self.build_year_start_filter()) filters.append(self.build_year_end_filter()) return " and ".join([f for f in filters if f]) # Remove empty @@ -196,7 +198,9 @@ def to_str(self) -> str: exact_match=False, limit=10, max_hits_per_family=10, - keyword_filters={"document_languages": "value", "family_source": "value"}, + keyword_filters=KeywordFilters( + **{"document_languages": "value", "family_source": "value"} + ), year_range=(2000, 2020), continuation_token=None, ) diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index 5f3aa20..88259a5 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -3,8 +3,9 @@ import pytest from vespa.exceptions import VespaError +from pydantic import ValidationError -from cpr_data_access.models.search import SearchParameters +from cpr_data_access.models.search import KeywordFilters, SearchParameters from cpr_data_access.vespa import build_vespa_request_body, VespaErrorDetails from cpr_data_access.yql_builder import YQLBuilder, sanitize from cpr_data_access.exceptions import QueryError @@ -137,16 +138,18 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror(): ["family_geography", "family_category", "document_languages", "family_source"], ) def test_whether_valid_filter_fields_are_accepted(field): - params = SearchParameters(query_string="test", keyword_filters={field: "value"}) + keyword_filters = KeywordFilters(**{field: ["value"]}) + params = SearchParameters(query_string="test", keyword_filters=keyword_filters) assert isinstance(params, SearchParameters) def test_whether_an_invalid_filter_fields_raises_a_valueerror(): - with pytest.raises(QueryError) as excinfo: + with pytest.raises(ValidationError) as excinfo: SearchParameters( - query_string="test", keyword_filters={"invalid_field": "value"} + query_string="test", + keyword_filters=KeywordFilters(**{"invalid_field": ["value"]}), ) - assert "keyword_filters must be a subset of" in str(excinfo.value) + assert "Extra inputs are not permitted" in str(excinfo.value) @pytest.mark.parametrize( @@ -169,18 +172,20 @@ def test_whether_malicious_query_strings_are_sanitized(input_string, expected): def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql(): + keyword_filters = { + "family_geography": ["SWE"], + "family_category": ["Executive"], + "document_languages": ["English", "Swedish"], + "family_source": ["CCLW"], + } params = SearchParameters( query_string="test", - keyword_filters={ - "family_geography": "SWE", - "family_category": "Executive", - "document_languages": ["English", "Swedish"], - "family_source": "CCLW", - }, + keyword_filters=KeywordFilters(**keyword_filters), ) yql = YQLBuilder(params).to_str() - assert isinstance(params.keyword_filters, dict) - for key, values in params.keyword_filters.items(): + assert isinstance(params.keyword_filters, KeywordFilters) + + for key, values in keyword_filters.items(): for value in values: assert key in yql assert value in yql @@ -266,7 +271,7 @@ def test_yql_builder_build_where_clause(): assert "climate" in where_clause params = SearchParameters( - query_string="climate", keyword_filters={"family_geography": "SWE"} + query_string="climate", keyword_filters={"family_geography": ["SWE"]} ) where_clause = YQLBuilder(params).build_where_clause() assert "SWE" in where_clause