Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Merge pull request #111 from climatepolicyradar/bugfix/minor-update-t…
Browse files Browse the repository at this point in the history
…o-validation

Updating to validate for different keys.
  • Loading branch information
THOR300 authored Dec 12, 2023
2 parents 8ac8b0f + 34785d6 commit b7e1585
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/cpr_data_access/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def keyword_filters_must_be_valid(cls, keyword_filters):
"""Validate that the keyword filters are valid."""
if keyword_filters is not None:
for field_key, values in keyword_filters.items():
if field_key not in filter_fields:
if field_key not in filter_fields.values():
raise QueryError(
f"Invalid keyword filter: {field_key}. keyword_filters must be "
f"a subset of: {list(filter_fields.keys())}"
f"a subset of: {list(filter_fields.values())}"
)

# convert single values to lists to make things easier later on
Expand Down
4 changes: 1 addition & 3 deletions src/cpr_data_access/vespa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Hit,
SearchParameters,
SearchResponse,
filter_fields,
sort_fields,
)
from cpr_data_access.exceptions import FetchError
Expand Down Expand Up @@ -130,8 +129,7 @@ def build_yql(request: SearchParameters) -> str:
rendered_filters = ""
if request.keyword_filters:
filters = []
for field_key, values in request.keyword_filters.items():
field_name = filter_fields[field_key]
for field_name, values in request.keyword_filters.items():
for value in values:
filters.append(f'({field_name} contains "{sanitize(value)}")')
rendered_filters = " and " + " and ".join(filters)
Expand Down
14 changes: 9 additions & 5 deletions tests/test_search_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror():
assert "sort_order must be one of" in str(excinfo.value)


@pytest.mark.parametrize("field", ["geography", "category", "language", "source"])
@pytest.mark.parametrize(
"field",
["family_geography", "family_category", "document_languages", "family_source"],
)
def test_whether_valid_filter_fields_are_accepted(field):
request = SearchParameters(query_string="test", keyword_filters={field: "value"})
assert isinstance(request, SearchParameters)
Expand Down Expand Up @@ -89,13 +92,14 @@ def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql()
request = SearchParameters(
query_string="test",
keyword_filters={
"geography": "SWE",
"category": "Executive",
"language": ["English", "Swedish"],
"source": "CCLW",
"family_geography": "SWE",
"family_category": "Executive",
"document_languages": ["English", "Swedish"],
"family_source": "CCLW",
},
)
yql = build_yql(request)
assert isinstance(request.keyword_filters, dict)
for key, values in request.keyword_filters.items():
for value in values:
assert key in yql
Expand Down

0 comments on commit b7e1585

Please sign in to comment.