diff --git a/src/cpr_data_access/models/search.py b/src/cpr_data_access/models/search.py index 4ed1fab..80bb7ba 100644 --- a/src/cpr_data_access/models/search.py +++ b/src/cpr_data_access/models/search.py @@ -81,10 +81,10 @@ def keyword_filters_must_be_valid(cls, keyword_filters): """Validate that the keyword filters are valid.""" if keyword_filters is not None: for field_key, values in keyword_filters.items(): - if field_key not in filter_fields: + if field_key not in filter_fields.values(): raise QueryError( f"Invalid keyword filter: {field_key}. keyword_filters must be " - f"a subset of: {list(filter_fields.keys())}" + f"a subset of: {list(filter_fields.values())}" ) # convert single values to lists to make things easier later on diff --git a/src/cpr_data_access/vespa.py b/src/cpr_data_access/vespa.py index 091d7b3..346fc5b 100644 --- a/src/cpr_data_access/vespa.py +++ b/src/cpr_data_access/vespa.py @@ -10,7 +10,6 @@ Hit, SearchParameters, SearchResponse, - filter_fields, sort_fields, ) from cpr_data_access.exceptions import FetchError @@ -130,8 +129,7 @@ def build_yql(request: SearchParameters) -> str: rendered_filters = "" if request.keyword_filters: filters = [] - for field_key, values in request.keyword_filters.items(): - field_name = filter_fields[field_key] + for field_name, values in request.keyword_filters.items(): for value in values: filters.append(f'({field_name} contains "{sanitize(value)}")') rendered_filters = " and " + " and ".join(filters) diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index b76629a..1a6a53d 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -52,7 +52,10 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror(): assert "sort_order must be one of" in str(excinfo.value) -@pytest.mark.parametrize("field", ["geography", "category", "language", "source"]) +@pytest.mark.parametrize( + "field", + ["family_geography", "family_category", "document_languages", "family_source"], +) def test_whether_valid_filter_fields_are_accepted(field): request = SearchParameters(query_string="test", keyword_filters={field: "value"}) assert isinstance(request, SearchParameters) @@ -89,13 +92,14 @@ def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql() request = SearchParameters( query_string="test", keyword_filters={ - "geography": "SWE", - "category": "Executive", - "language": ["English", "Swedish"], - "source": "CCLW", + "family_geography": "SWE", + "family_category": "Executive", + "document_languages": ["English", "Swedish"], + "family_source": "CCLW", }, ) yql = build_yql(request) + assert isinstance(request.keyword_filters, dict) for key, values in request.keyword_filters.items(): for value in values: assert key in yql