diff --git a/README.md b/README.md index 7ccf55a..22f5668 100644 --- a/README.md +++ b/README.md @@ -116,6 +116,44 @@ request = SearchParameters( ) ``` +### Search within families or documents + +A search can be restricted to a subset of families or documents using their ids: +```python +request = SearchParameters( + query_string="forest fires", + family_ids=["CCLW.family.10121.0", "CCLW.family.4980.0"], +) +``` + +```python +request = SearchParameters( + query_string="forest fires", + document_ids=["CCLW.executive.10121.4637", "CCLW.legislative.4980.1745"], +) +``` + +### Types of query +The default search approach uses a nearest neighbour search ranking. + +It's also possible to search for exact matches instead: + +```python +request = SearchParameters( + query_string="forest fires", + exact_match=True, +) +``` + +Or to ignore the query string and search the whole database instead: +```python +request = SearchParameters( + year_range=(2020, 2024), + sort_by="date", + sort_order="descending", +) +``` + ### Continuing results The response objects include continuation tokens, which can be used to get more results. 
diff --git a/src/cpr_data_access/models/search.py b/src/cpr_data_access/models/search.py index 744a6db..e20db87 100644 --- a/src/cpr_data_access/models/search.py +++ b/src/cpr_data_access/models/search.py @@ -2,11 +2,17 @@ import re from typing import List, Optional, Sequence -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import ( + BaseModel, + computed_field, + ConfigDict, + field_validator, + model_validator, +) from cpr_data_access.exceptions import QueryError -sort_orders = ["ascending", "descending"] +sort_orders = {"ascending": "+", "descending": "-"} sort_fields = { "date": "family_publication_ts", @@ -53,8 +59,9 @@ def sanitise_filter_inputs(cls, field): class SearchParameters(BaseModel): """Parameters for a search request""" - query_string: str + query_string: Optional[str] = None exact_match: bool = False + all_results: bool = False limit: int = 100 max_hits_per_family: int = 10 @@ -69,6 +76,13 @@ class SearchParameters(BaseModel): continuation_tokens: Optional[Sequence[str]] = None + @model_validator(mode="after") + def validate(self): + """Validate against mutually exclusive fields""" + if self.exact_match and self.all_results: + raise QueryError("`exact_match` and `all_results` are mutually exclusive") + return self + @field_validator("continuation_tokens") def continuation_tokens_must_be_upper_strings(cls, continuation_tokens): """Validate continuation_tokens match the expected format""" @@ -86,12 +100,12 @@ def continuation_tokens_must_be_upper_strings(cls, continuation_tokens): ) return continuation_tokens - @field_validator("query_string") - def query_string_must_not_be_empty(cls, query_string): + @model_validator(mode="after") + def query_string_must_not_be_empty(self): """Validate that the query string is not empty.""" - if query_string == "": - raise QueryError("query_string must not be empty") - return query_string + if not self.query_string: + self.all_results = True + return self @field_validator("family_ids", 
"document_ids") def ids_must_fit_pattern(cls, ids): @@ -136,13 +150,26 @@ def sort_by_must_be_valid(cls, sort_by): @field_validator("sort_order") def sort_order_must_be_valid(cls, sort_order): """Validate that the sort order is valid.""" - if sort_order not in ["ascending", "descending"]: + if sort_order not in sort_orders: raise QueryError( f"Invalid sort order: {sort_order}. sort_order must be one of: " f"{sort_orders}" ) return sort_order + @computed_field + def vespa_sort_by(self) -> Optional[str]: + """Translates sort by into the format acceptable by vespa""" + if self.sort_by: + return sort_fields.get(self.sort_by) + else: + return None + + @computed_field + def vespa_sort_order(self) -> Optional[str]: + """Translates sort order into the format acceptable by vespa""" + return sort_orders.get(self.sort_order) + class Hit(BaseModel): """Common model for all search result hits.""" @@ -277,6 +304,7 @@ class Family(BaseModel): hits: Sequence[Hit] total_passage_hits: int = 0 continuation_token: Optional[str] = None + prev_continuation_token: Optional[str] = None class SearchResponse(BaseModel): @@ -289,3 +317,4 @@ class SearchResponse(BaseModel): families: Sequence[Family] continuation_token: Optional[str] = None this_continuation_token: Optional[str] = None + prev_continuation_token: Optional[str] = None diff --git a/src/cpr_data_access/vespa.py b/src/cpr_data_access/vespa.py index 99d5557..fa3b2e0 100644 --- a/src/cpr_data_access/vespa.py +++ b/src/cpr_data_access/vespa.py @@ -10,7 +10,6 @@ Hit, SearchParameters, SearchResponse, - sort_fields, ) from cpr_data_access.embedding import Embedder from cpr_data_access.exceptions import FetchError @@ -83,7 +82,9 @@ def build_vespa_request_body( "query_string": parameters.query_string, } - if parameters.exact_match: + if parameters.all_results: + pass + elif parameters.exact_match: vespa_request_body["ranking.profile"] = "exact" elif sensitive: vespa_request_body["ranking.profile"] = "hybrid_no_closeness" @@ -122,6 
+123,7 @@ def parse_vespa_response( total_passage_hits = dig(family, "fields", "count()") family_hits: List[Hit] = [] passages_continuation = dig(family, "children", 0, "continuation", "next") + prev_passages_continuation = dig(family, "children", 0, "continuation", "prev") for hit in dig(family, "children", 0, "children", default=[]): family_hits.append(Hit.from_vespa_response(response_hit=hit)) families.append( @@ -130,21 +132,16 @@ def parse_vespa_response( hits=family_hits, total_passage_hits=total_passage_hits, continuation_token=passages_continuation, + prev_continuation_token=prev_passages_continuation, ) ) - # For now, we can't sort our results natively in vespa because sort orders are - # applied _before_ grouping. We're sorting here instead. - if request.sort_by is not None: - sort_field = sort_fields[request.sort_by] - families.sort( - key=lambda f: getattr(f.hits[0], sort_field), - reverse=request.sort_order == "descending", - ) - next_family_continuation = dig( root, "children", 0, "children", 0, "continuation", "next" ) + prev_family_continuation = dig( + root, "children", 0, "children", 0, "continuation", "prev" + ) this_family_continuation = dig(root, "children", 0, "continuation", "this") total_hits = dig(root, "fields", "totalCount", default=0) total_family_hits = dig(root, "children", 0, "fields", "count()", default=0) @@ -154,6 +151,7 @@ def parse_vespa_response( families=families, continuation_token=next_family_continuation, this_continuation_token=this_family_continuation, + prev_continuation_token=prev_family_continuation, query_time_ms=None, total_time_ms=None, ) diff --git a/src/cpr_data_access/yql_builder.py b/src/cpr_data_access/yql_builder.py index 7f83797..71ca3f8 100644 --- a/src/cpr_data_access/yql_builder.py +++ b/src/cpr_data_access/yql_builder.py @@ -18,6 +18,7 @@ class YQLBuilder: group(family_import_id) output(count()) max($LIMIT) + $SORT each( output(count()) max($MAX_HITS_PER_FAMILY) @@ -37,6 +38,8 @@ def __init__(self, params: 
SearchParameters, sensitive: bool = False) -> None: def build_search_term(self) -> str: """Create the part of the query that matches a users search text""" + if self.params.all_results: + return "( true )" if self.params.exact_match: return """ ( @@ -141,6 +144,15 @@ def build_limit(self) -> int: """Create the part of the query limiting the number of families returned""" return self.params.limit + def build_sort(self) -> str: + """Creates the part of the query used for sorting by different fields""" + sort_by = self.params.vespa_sort_by + sort_order = self.params.vespa_sort_order + + if not sort_by or not sort_order: + return "" + return f"order({sort_order}max({sort_by}))" + def build_max_hits_per_family(self) -> int: """Create the part of the query limiting passages within a family returned""" return self.params.max_hits_per_family @@ -151,6 +163,7 @@ def to_str(self) -> str: WHERE_CLAUSE=self.build_where_clause(), CONTINUATION=self.build_continuation(), LIMIT=self.build_limit(), + SORT=self.build_sort(), MAX_HITS_PER_FAMILY=self.build_max_hits_per_family(), ) return " ".join(yql.split()) diff --git a/tests/test_data/search_responses/search_response.json b/tests/test_data/search_responses/search_response.json index 3f6f0d0..08b57ad 100644 --- a/tests/test_data/search_responses/search_response.json +++ b/tests/test_data/search_responses/search_response.json @@ -29,7 +29,9 @@ "relevance": 1.0, "label": "family_import_id", "continuation": { - "next": "BGAAABECBEBC" + "next": "BGAAABECBEBC", + "prev": "BGAAAAAABEBC" + }, "children": [ { @@ -45,7 +47,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAAAABGCBEBC" + "next": "BKAAAAABGCBEBC", + "prev": "BKAAAAAAAABEBC" }, "children": [ { @@ -680,7 +683,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAABEABGCBEBC" + "next": "BKAAABEABGCBEBC", + "prev": "BKAAABEAAACBEBC" }, "children": [ { @@ -1499,7 +1503,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": 
"BKAAABIABGCBEBC" + "next": "BKAAABIABGCBEBC", + "prev": "BKAAABIAAAABEBC" }, "children": [ { @@ -1979,7 +1984,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAABKABGCBEBC" + "next": "BKAAABKABGCBEBC", + "prev": "BKAAABKAAACBEBC" }, "children": [ { @@ -2459,7 +2465,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAABMABGCBEBC" + "next": "BKAAABMABGCBEBC", + "prev": "BKAAABMABGCAAAA" }, "children": [ { @@ -2939,7 +2946,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAABOABGCBEBC" + "next": "BKAAABOABGCBEBC", + "prev": "BKAAAAAABGCBEBC" }, "children": [ { @@ -3419,7 +3427,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAACBAABGCBEBC" + "next": "BKAAACBAABGCBEBC", + "prev": "BKAAACBAABGAAEBC" }, "children": [ { @@ -3899,7 +3908,8 @@ "relevance": 1.0, "label": "hits", "continuation": { - "next": "BKAAACBCABGCBEBC" + "next": "BKAAACBCABGCBEBC", + "prev": "BKAAACBCABGCBEAA" }, "children": [ { diff --git a/tests/test_search_adaptors.py b/tests/test_search_adaptors.py index d36227b..70c4fcb 100644 --- a/tests/test_search_adaptors.py +++ b/tests/test_search_adaptors.py @@ -3,7 +3,11 @@ import pytest from cpr_data_access.search_adaptors import VespaSearchAdapter -from cpr_data_access.models.search import SearchParameters, SearchResponse +from cpr_data_access.models.search import ( + SearchParameters, + SearchResponse, + sort_fields, +) from conftest import VESPA_TEST_SEARCH_URL @@ -96,6 +100,22 @@ def test_vespa_search_adaptor__hybrid(fake_vespa_credentials): assert family_name in got_family_names +@pytest.mark.vespa +def test_vespa_search_adaptor__all(fake_vespa_credentials): + request = SearchParameters(query_string="", all_results=True) + response = vespa_search(fake_vespa_credentials, request) + assert len(response.families) == response.total_family_hits + + # Filtering should still work + family_id = "CCLW.family.i00000003.n0000" + request = SearchParameters( + 
query_string="", all_results=True, family_ids=[family_id] + ) + response = vespa_search(fake_vespa_credentials, request) + assert len(response.families) == 1 + assert response.families[0].id == family_id + + @pytest.mark.vespa def test_vespa_search_adaptor__exact(fake_vespa_credentials): query_string = "Environmental Strategy for 2014-2023" @@ -180,6 +200,7 @@ def test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credenti continuation_tokens=[family_continuation], ) response = vespa_search(fake_vespa_credentials, request) + prev_family_continuation = response.prev_continuation_token assert len(response.families) == 1 assert response.total_family_hits == 3 @@ -189,6 +210,17 @@ def test_vespa_search_adaptor__continuation_tokens__families(fake_vespa_credenti # As this is the end of the results we also expect no more tokens assert response.continuation_token is None + # Using prev_continuation_token give initial results + request = SearchParameters( + query_string=query_string, + limit=limit, + max_hits_per_family=max_hits_per_family, + continuation_tokens=[prev_family_continuation], + ) + response = vespa_search(fake_vespa_credentials, request) + prev_family_ids = [f.id for f in response.families] + assert prev_family_ids == first_family_ids + @pytest.mark.vespa def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credentials): @@ -219,6 +251,7 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti continuation_tokens=[this_continuation, passage_continuation], ) response = vespa_search(fake_vespa_credentials, request) + prev_passage_continuation = response.families[0].prev_continuation_token # Family should not have changed assert response.families[0].id == initial_family_id @@ -227,6 +260,19 @@ def test_vespa_search_adaptor__continuation_tokens__passages(fake_vespa_credenti new_passages = sorted([h.text_block_id for h in response.families[0].hits]) assert sorted(new_passages) != 
sorted(initial_passages) + # Previous passage continuation gives initial results + request = SearchParameters( + query_string=query_string, + limit=limit, + max_hits_per_family=max_hits_per_family, + continuation_tokens=[this_continuation, prev_passage_continuation], + ) + response = vespa_search(fake_vespa_credentials, request) + assert response.families[0].id == initial_family_id + prev_passages = sorted([h.text_block_id for h in response.families[0].hits]) + assert sorted(prev_passages) != sorted(new_passages) + assert sorted(prev_passages) == sorted(initial_passages) + @pytest.mark.vespa def test_vespa_search_adaptor__continuation_tokens__families_and_passages( @@ -284,3 +330,18 @@ def test_vespa_search_adaptor__continuation_tokens__families_and_passages( != sorted([h.text_block_id for h in response_three.families[0].hits]) != sorted([h.text_block_id for h in response_four.families[0].hits]) ) + + +@pytest.mark.parametrize("sort_by", sort_fields.keys()) +@pytest.mark.vespa +def test_vespa_search_adapter_sorting(fake_vespa_credentials, sort_by): + ascend = vespa_search( + fake_vespa_credentials, + SearchParameters(query_string="the", sort_by=sort_by, sort_order="ascending"), + ) + descend = vespa_search( + fake_vespa_credentials, + SearchParameters(query_string="the", sort_by=sort_by, sort_order="descending"), + ) + + assert ascend != descend diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index f8e3645..a9d1a26 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -2,12 +2,15 @@ import pytest -from vespa.exceptions import VespaError from pydantic import ValidationError -from cpr_data_access.models.search import KeywordFilters, SearchParameters -from cpr_data_access.vespa import build_vespa_request_body, VespaErrorDetails -from cpr_data_access.yql_builder import YQLBuilder +from cpr_data_access.models.search import ( + KeywordFilters, + SearchParameters, + sort_orders, + sort_fields, +) +from 
cpr_data_access.vespa import build_vespa_request_body from cpr_data_access.exceptions import QueryError from cpr_data_access.embedding import Embedder @@ -31,10 +34,37 @@ def test_build_vespa_request_body(query_type, params): ), f"Query type: {query_type} has an empty value for {key}: {value}" -def test_whether_an_empty_query_string_raises_a_queryerror(): +def test_build_vespa_request_body__all(): + params = SearchParameters(query_string="", all_results=True) + embedder = Embedder() + body = build_vespa_request_body(parameters=params, embedder=embedder) + + assert not body.get("ranking.profile") + + +def test_whether_an_empty_query_string_does_all_result_search(): + params = SearchParameters(query_string="") + assert params.all_results + + # This rule does not apply to `all_result` requests: + try: + SearchParameters(query_string="", all_results=True) + except Exception as e: + pytest.fail(f"{e.__class__.__name__}: {e}") + + +def test_wether_combining_all_results_and_exact_match_raises_error(): + q = "Search" with pytest.raises(QueryError) as excinfo: - SearchParameters(query_string="") - assert "query_string must not be empty" in str(excinfo.value) + SearchParameters(query_string=q, exact_match=True, all_results=True) + assert "" in str(excinfo.value) + + # They should be fine independently: + try: + SearchParameters(query_string=q, all_results=True) + SearchParameters(query_string=q, exact_match=True) + except Exception as e: + pytest.fail(f"{e.__class__.__name__}: {e}") @pytest.mark.parametrize("year_range", [(2000, 2020), (2000, None), (None, 2020)]) @@ -133,6 +163,15 @@ def test_whether_an_invalid_sort_order_raises_a_queryerror(): assert "sort_order must be one of" in str(excinfo.value) +@pytest.mark.parametrize("sort_by", sort_fields.keys()) +@pytest.mark.parametrize("sort_order", sort_orders.keys()) +def test_computed_vespa_sort_fields(sort_by, sort_order): + params = SearchParameters( + query_string="test", sort_by=sort_by, sort_order=sort_order + ) + 
assert params.vespa_sort_by and params.vespa_sort_order + + @pytest.mark.parametrize( "field", ["family_geography", "family_category", "document_languages", "family_source"], @@ -208,139 +247,3 @@ def test_continuation_tokens__good(tokens): SearchParameters(query_string="test", continuation_tokens=tokens) except Exception as e: pytest.fail(f"{e.__class__.__name__}: {e}") - - -def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql(): - keyword_filters = { - "family_geography": ["SWE"], - "family_category": ["Executive"], - "document_languages": ["English", "Swedish"], - "family_source": ["CCLW"], - } - params = SearchParameters( - query_string="test", - keyword_filters=KeywordFilters(**keyword_filters), - ) - yql = YQLBuilder(params).to_str() - assert isinstance(params.keyword_filters, KeywordFilters) - - for key, values in keyword_filters.items(): - for value in values: - assert key in yql - assert value in yql - - -@pytest.mark.parametrize( - "year_range, expected_include, expected_exclude", - [ - ((2000, 2020), [">= 2000", "<= 2020"], []), - ((2000, None), [">= 2000"], ["<="]), - ((None, 2020), ["<= 2020"], [">="]), - ], -) -def test_whether_year_ranges_appear_in_yql( - year_range, expected_include, expected_exclude -): - params = SearchParameters(query_string="test", year_range=year_range) - yql = YQLBuilder(params).to_str() - for include in expected_include: - assert include in yql - for exclude in expected_exclude: - assert exclude not in yql - - -def test_vespa_error_details(): - # With invalid query parameter code - err_object = [ - { - "code": 4, - "summary": "test_summary", - "message": "test_message", - "stackTrace": None, - } - ] - err = VespaError(err_object) - details = VespaErrorDetails(err) - - assert details.code == err_object[0]["code"] - assert details.summary == err_object[0]["summary"] - assert details.message == err_object[0]["message"] - assert details.is_invalid_query_parameter - - # With other code - err_object = 
[{"code": 1}] - err = VespaError(err_object) - details = VespaErrorDetails(err) - assert not details.is_invalid_query_parameter - - -def test_filter_profiles_return_different_queries(): - exact_yql = YQLBuilder( - params=SearchParameters( - query_string="test", year_range=(2000, 2023), exact_match=True - ), - sensitive=False, - ).to_str() - assert "stem: false" in exact_yql - assert "nearestNeighbor" not in exact_yql - - hybrid_yql = YQLBuilder( - params=SearchParameters( - query_string="test", year_range=(2000, 2023), exact_match=False - ), - sensitive=False, - ).to_str() - assert "nearestNeighbor" in hybrid_yql - - sensitive_yql = YQLBuilder( - params=SearchParameters( - query_string="test", year_range=(2000, 2023), exact_match=False - ), - sensitive=True, - ).to_str() - assert "nearestNeighbor" not in sensitive_yql - - queries = [exact_yql, hybrid_yql, sensitive_yql] - assert len(queries) == len(set(queries)) - - -def test_yql_builder_build_where_clause(): - query_string = "climate" - params = SearchParameters(query_string=query_string) - where_clause = YQLBuilder(params).build_where_clause() - # raw user input should NOT be in the where clause - # We send this in the body so its cleaned by vespa - assert query_string not in where_clause - - params = SearchParameters( - query_string="climate", keyword_filters={"family_geography": ["SWE"]} - ) - where_clause = YQLBuilder(params).build_where_clause() - assert "SWE" in where_clause - assert "family_geography" in where_clause - - params = SearchParameters( - query_string="test", - family_ids=("CCLW.family.i00000003.n0000", "CCLW.family.10014.0"), - ) - where_clause = YQLBuilder(params).build_where_clause() - assert "CCLW.family.i00000003.n0000" in where_clause - assert "CCLW.family.10014.0" in where_clause - - params = SearchParameters( - query_string="test", - document_ids=("CCLW.document.i00000004.n0000", "CCLW.executive.10014.4470"), - ) - where_clause = YQLBuilder(params).build_where_clause() - assert 
"CCLW.document.i00000004.n0000" in where_clause - assert "CCLW.executive.10014.4470" in where_clause - - params = SearchParameters(query_string="climate", year_range=(2000, None)) - where_clause = YQLBuilder(params).build_where_clause() - assert "2000" in where_clause - assert "family_publication_year" in where_clause - - params = SearchParameters(query_string="climate", year_range=(None, 2020)) - where_clause = YQLBuilder(params).build_where_clause() - assert "2020" in where_clause - assert "family_publication_year" in where_clause diff --git a/tests/test_search_responses.py b/tests/test_search_responses.py index 0d88d83..cc15cd4 100644 --- a/tests/test_search_responses.py +++ b/tests/test_search_responses.py @@ -65,58 +65,6 @@ def test_whether_an_invalid_vespa_response_raises_a_valueerror( assert "Received status code 500" in str(excinfo.value) -def test_whether_sorting_by_ascending_date_works(valid_vespa_search_response): - request = SearchParameters( - query_string="test", sort_by="date", sort_order="ascending" - ) - response = parse_vespa_response( - request=request, vespa_response=valid_vespa_search_response - ) - for family_i, family_j in zip(response.families[:-1], response.families[1:]): - date_i = family_i.hits[0].family_publication_ts - date_j = family_j.hits[0].family_publication_ts - assert date_i <= date_j - - -def test_whether_sorting_by_descending_date_works(valid_vespa_search_response): - request = SearchParameters( - query_string="test", sort_by="date", sort_order="descending" - ) - response = parse_vespa_response( - request=request, vespa_response=valid_vespa_search_response - ) - for family_i, family_j in zip(response.families[:-1], response.families[1:]): - date_i = family_i.hits[0].family_publication_ts - date_j = family_j.hits[0].family_publication_ts - assert date_i >= date_j - - -def test_whether_sorting_by_ascending_name_works(valid_vespa_search_response): - request = SearchParameters( - query_string="test", sort_by="name", 
sort_order="ascending" - ) - response = parse_vespa_response( - request=request, vespa_response=valid_vespa_search_response - ) - for family_i, family_j in zip(response.families[:-1], response.families[1:]): - name_i = family_i.hits[0].family_name - name_j = family_j.hits[0].family_name - assert name_i <= name_j - - -def test_whether_sorting_by_descending_name_works(valid_vespa_search_response): - request = SearchParameters( - query_string="test", sort_by="name", sort_order="descending" - ) - response = parse_vespa_response( - request=request, vespa_response=valid_vespa_search_response - ) - for family_i, family_j in zip(response.families[:-1], response.families[1:]): - name_i = family_i.hits[0].family_name - name_j = family_j.hits[0].family_name - assert name_i >= name_j - - def test_whether_continuation_token_is_returned_when_present( valid_vespa_search_response, ): @@ -125,6 +73,7 @@ def test_whether_continuation_token_is_returned_when_present( request=request, vespa_response=valid_vespa_search_response ) assert response.continuation_token + assert response.prev_continuation_token def test_whether_valid_get_document_response_is_parsed(valid_get_document_response): diff --git a/tests/test_yql_builder.py b/tests/test_yql_builder.py new file mode 100644 index 0000000..08d8500 --- /dev/null +++ b/tests/test_yql_builder.py @@ -0,0 +1,172 @@ +import pytest +from vespa.exceptions import VespaError + +from cpr_data_access.models.search import ( + KeywordFilters, + SearchParameters, + sort_fields, + sort_orders, +) +from cpr_data_access.vespa import VespaErrorDetails +from cpr_data_access.yql_builder import YQLBuilder + + +def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql(): + keyword_filters = { + "family_geography": ["SWE"], + "family_category": ["Executive"], + "document_languages": ["English", "Swedish"], + "family_source": ["CCLW"], + } + params = SearchParameters( + query_string="test", + 
keyword_filters=KeywordFilters(**keyword_filters), + ) + yql = YQLBuilder(params).to_str() + assert isinstance(params.keyword_filters, KeywordFilters) + + for key, values in keyword_filters.items(): + for value in values: + assert key in yql + assert value in yql + + +@pytest.mark.parametrize( + "year_range, expected_include, expected_exclude", + [ + ((2000, 2020), [">= 2000", "<= 2020"], []), + ((2000, None), [">= 2000"], ["<="]), + ((None, 2020), ["<= 2020"], [">="]), + ], +) +def test_whether_year_ranges_appear_in_yql( + year_range, expected_include, expected_exclude +): + params = SearchParameters(query_string="test", year_range=year_range) + yql = YQLBuilder(params).to_str() + for include in expected_include: + assert include in yql + for exclude in expected_exclude: + assert exclude not in yql + + +@pytest.mark.parametrize("sort_by", sort_fields.keys()) +@pytest.mark.parametrize("sort_order", sort_orders.keys()) +def test_sorting_appears_in_yql(sort_by, sort_order): + params = SearchParameters( + query_string="test", sort_by=sort_by, sort_order=sort_order + ) + assert "order" in YQLBuilder(params).to_str() + + +def test_sorting_does_not_appear_in_yql(): + params = SearchParameters(query_string="test", sort_order="ascending") + assert "order" not in YQLBuilder(params).to_str() + params = SearchParameters(query_string="test") + assert "order" not in YQLBuilder(params).to_str() + + +def test_vespa_error_details(): + # With invalid query parameter code + err_object = [ + { + "code": 4, + "summary": "test_summary", + "message": "test_message", + "stackTrace": None, + } + ] + err = VespaError(err_object) + details = VespaErrorDetails(err) + + assert details.code == err_object[0]["code"] + assert details.summary == err_object[0]["summary"] + assert details.message == err_object[0]["message"] + assert details.is_invalid_query_parameter + + # With other code + err_object = [{"code": 1}] + err = VespaError(err_object) + details = VespaErrorDetails(err) + assert not 
details.is_invalid_query_parameter + + +def test_filter_profiles_return_different_queries(): + exact_yql = YQLBuilder( + params=SearchParameters( + query_string="test", year_range=(2000, 2023), exact_match=True + ), + sensitive=False, + ).to_str() + assert "stem: false" in exact_yql + assert "nearestNeighbor" not in exact_yql + + hybrid_yql = YQLBuilder( + params=SearchParameters( + query_string="test", year_range=(2000, 2023), exact_match=False + ), + sensitive=False, + ).to_str() + assert "nearestNeighbor" in hybrid_yql + + sensitive_yql = YQLBuilder( + params=SearchParameters( + query_string="test", year_range=(2000, 2023), exact_match=False + ), + sensitive=True, + ).to_str() + assert "nearestNeighbor" not in sensitive_yql + + all_yql = YQLBuilder( + params=SearchParameters( + query_string="test query string", year_range=(2000, 2024), all_results=True + ) + ).to_str() + assert "true" in all_yql + assert "2024" in all_yql + assert "test query string" not in all_yql + + queries = [exact_yql, hybrid_yql, sensitive_yql, all_yql] + assert len(queries) == len(set(queries)) + + +def test_yql_builder_build_where_clause(): + query_string = "climate" + params = SearchParameters(query_string=query_string) + where_clause = YQLBuilder(params).build_where_clause() + # raw user input should NOT be in the where clause + # We send this in the body so its cleaned by vespa + assert query_string not in where_clause + + params = SearchParameters( + query_string="climate", keyword_filters={"family_geography": ["SWE"]} + ) + where_clause = YQLBuilder(params).build_where_clause() + assert "SWE" in where_clause + assert "family_geography" in where_clause + + params = SearchParameters( + query_string="test", + family_ids=("CCLW.family.i00000003.n0000", "CCLW.family.10014.0"), + ) + where_clause = YQLBuilder(params).build_where_clause() + assert "CCLW.family.i00000003.n0000" in where_clause + assert "CCLW.family.10014.0" in where_clause + + params = SearchParameters( + 
query_string="test", + document_ids=("CCLW.document.i00000004.n0000", "CCLW.executive.10014.4470"), + ) + where_clause = YQLBuilder(params).build_where_clause() + assert "CCLW.document.i00000004.n0000" in where_clause + assert "CCLW.executive.10014.4470" in where_clause + + params = SearchParameters(query_string="climate", year_range=(2000, None)) + where_clause = YQLBuilder(params).build_where_clause() + assert "2000" in where_clause + assert "family_publication_year" in where_clause + + params = SearchParameters(query_string="climate", year_range=(None, 2020)) + where_clause = YQLBuilder(params).build_where_clause() + assert "2020" in where_clause + assert "family_publication_year" in where_clause