From 4072d70a8e3167cab5a816cee0ca0e7a1067efb1 Mon Sep 17 00:00:00 2001 From: Katy Baulch <46493669+katybaulch@users.noreply.github.com> Date: Tue, 24 Sep 2024 13:03:39 +0100 Subject: [PATCH] Feature/pdct 1274 update the api search method to apply corpus filtering (#350) * Use host header for audience comparison * Update custom app documentation to use hostname * Make audience a string instead of URL * Use host for audience & add additional corpora ID validation * Bump to 1.16.4 * Two step validation of corpora IDs in search * Bump to 1.16.5 * Fix order of patches * Add app token corpora as default corpora for search * Mock existence of corpus IDs so search validation passes * Move corpus filtering tests into separate file * Update corpora IDs to match search fixture * Bump to 1.17.0 * Remove redundant check --- app/api/api_v1/routers/search.py | 45 ++++-- app/db/crud/helpers/validate.py | 19 +++ pyproject.toml | 2 +- tests/conftest.py | 19 +-- .../test_range_and_keyword_filters_search.py | 100 ++++++++++-- .../test_search_raises_on_token_errors.py | 74 --------- .../vespa/test_this_vespa_search_download.py | 22 ++- .../vespa/test_vespa_corpus_filtering.py | 152 ++++++++++++++++++ tests/search/vespa/test_vespa_ids_search.py | 30 +++- .../test_vespa_metadata_filters_search.py | 6 + .../search/vespa/test_vespa_query_ignores.py | 19 ++- tests/search/vespa/test_vespa_search.py | 94 +++++------ .../vespa/test_vespa_search_pagination.py | 17 +- .../vespa/test_vespa_search_result_order.py | 26 ++- .../vespa/test_whole_database_download.py | 5 +- 15 files changed, 454 insertions(+), 176 deletions(-) delete mode 100644 tests/search/vespa/test_search_raises_on_token_errors.py create mode 100644 tests/search/vespa/test_vespa_corpus_filtering.py diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index 4a6d4b3a..b1aff031 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -8,7 +8,7 @@ import logging from io import BytesIO -from typing import Annotated +from typing import Annotated, Sequence, cast from cpr_sdk.exceptions import QueryError from cpr_sdk.search_adaptors import VespaSearchAdapter @@ -37,7 +37,10 @@ process_result_into_csv, process_vespa_search_response, ) -from app.db.crud.helpers.validate import verify_any_corpora_ids_in_db +from app.db.crud.helpers.validate import ( + validate_corpora_ids, + verify_any_corpora_ids_in_db, +) from app.db.session import get_db _LOGGER = logging.getLogger(__name__) @@ -137,6 +140,17 @@ def search_documents( the search database. The continuation token can be used to get the next set of results from the search database. See the request schema for more details. """ + _LOGGER.info( + "Search request", + extra={ + "props": { + "search_request": search_body.model_dump(), + "host": str(host), + "app_token": str(app_token), + } + }, + ) + try: allowed_corpora_ids = decode_config_token(app_token, host) except PyJWTError as e: @@ -147,16 +161,6 @@ def search_documents( headers={"WWW-Authenticate": "Bearer"}, ) - _LOGGER.info( - "Search request", - extra={ - "props": { - "search_request": search_body.model_dump(), - "allowed_corpora_ids": str(allowed_corpora_ids), - } - }, - ) - # First corpora validation is app token against DB. At least one of the app token # corpora IDs must be present in the DB to continue the search request. if not verify_any_corpora_ids_in_db(db, allowed_corpora_ids): @@ -167,6 +171,23 @@ def search_documents( headers={"WWW-Authenticate": "Bearer"}, ) + # If the search request IDs are null, we want to search using the app token corpora. + if search_body.corpus_import_ids is None: + search_body.corpus_import_ids = cast(Sequence, allowed_corpora_ids) + + # For the second validation, search request corpora Ids are validated against the + # app token corpora IDs if the search request param 'corpus_import_ids' is not None. + # corpus_import_ids must be a subset of app token IDs. + if not validate_corpora_ids( + set(search_body.corpus_import_ids), cast(set, allowed_corpora_ids) + ): + msg = "Error validating corpora IDs." + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=msg, + headers={"WWW-Authenticate": "Bearer"}, + ) + _LOGGER.info( "Starting search...", extra={"props": {"search_request": search_body.model_dump()}}, diff --git a/app/db/crud/helpers/validate.py b/app/db/crud/helpers/validate.py index 2e9468c7..ee6f3961 100644 --- a/app/db/crud/helpers/validate.py +++ b/app/db/crud/helpers/validate.py @@ -30,3 +30,22 @@ def verify_any_corpora_ids_in_db(db: Session, corpora_ids: list[str]) -> bool: ) return validate_success + + +def validate_corpora_ids(corpora_ids: set[str], valid_corpora_ids: set[str]) -> bool: + """Validate all given corpus IDs against a list of allowed corpora. + + :param set[str] corpora_ids: The corpus import IDs we want to + validate. + :param set[str] valid_corpora_ids: The corpus import IDs + we want to validate against. + :return bool: Return whether or not all the corpora are valid. + """ + validate_success = corpora_ids.issubset(valid_corpora_ids) + if not validate_success: + invalid_corpora = set(corpora_ids).difference(valid_corpora_ids) + _LOGGER.warning( + f"Some corpora in search request params {invalid_corpora}" + "forbidden to search against." + ) + return validate_success diff --git a/pyproject.toml b/pyproject.toml index 94c3cd86..7fad6227 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "1.16.4" +version = "1.17.0" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/conftest.py b/tests/conftest.py index 975b38ab..a7a99215 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -113,24 +113,7 @@ def valid_token(): might be different (e.g., the one for staging). This fixture works around this. """ - corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000" - subject = "CCLW" - audience = "localhost:8888" - input_str = f"{corpora_ids};{subject};{audience}" - return create_configuration_token(input_str) - - -@pytest.fixture -def unfccc_token(): - """Generate valid config token using TOKEN_SECRET_KEY. - - Need to generate the config token using the token secret key from - your local env file. For tests in CI, this will be the secret key in - the .env.example file, but for local development this secret key - might be different (e.g., the one for staging). This fixture works - around this. - """ - corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000" + corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0" subject = "CCLW" audience = "localhost:8888" input_str = f"{corpora_ids};{subject};{audience}" diff --git a/tests/search/vespa/test_range_and_keyword_filters_search.py b/tests/search/vespa/test_range_and_keyword_filters_search.py index 6a24f54a..ad090a7e 100644 --- a/tests/search/vespa/test_range_and_keyword_filters_search.py +++ b/tests/search/vespa/test_range_and_keyword_filters_search.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from db_client.models.dfce import Geography from fastapi import status @@ -12,9 +14,17 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_country_filters__geography( - label, query, test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_client, + data_db, + monkeypatch, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -43,11 +53,21 @@ def test_keyword_country_filters__geography( assert len(filtered_family_slugs) == 1 assert family["family_slug"] in filtered_family_slugs + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_country_filters__geographies( - label, query, test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_client, + data_db, + monkeypatch, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -78,11 +98,21 @@ def test_keyword_country_filters__geographies( assert len(filtered_family_slugs) == 1 assert family["family_slug"] in filtered_family_slugs + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_region_filters( - label, query, test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_client, + data_db, + monkeypatch, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -118,11 +148,21 @@ def test_keyword_region_filters( ] assert family["family_slug"] in filtered_family_slugs + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_region_and_country_filters( - label, query, test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_client, + data_db, + monkeypatch, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -141,11 +181,21 @@ def test_keyword_region_and_country_filters( assert len(body["families"]) == 1 assert body["families"][0]["family_name"] == "National Energy Strategy" + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_invalid_keyword_filters( - label, query, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -164,13 +214,22 @@ def test_invalid_keyword_filters( expected_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, ) + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize( "year_range", [(None, None), (1900, None), (None, 2020), (1900, 2020)] ) def test_year_range_filtered_in( - year_range, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + year_range, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -185,11 +244,20 @@ def test_year_range_filtered_in( body = _make_search_request(data_client, valid_token, params=params) assert len(body["families"]) > 0 + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("year_range", [(None, 2010), (2024, None)]) def test_year_range_filtered_out( - year_range, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + year_range, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -204,11 +272,21 @@ def test_year_range_filtered_out( body = _make_search_request(data_client, valid_token, params=params) assert len(body["families"]) == 0 + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_multiple_filters( - label, query, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -224,11 +302,13 @@ def test_multiple_filters( } _ = _make_search_request(data_client, valid_token, params) + assert mock_corpora_exist_in_db.assert_called @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_geo_filter_with_exact( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -246,3 +326,5 @@ def test_geo_filter_with_exact( assert len(response["families"]) > 0 for family in response["families"]: assert "ITA" in family["family_geographies"] + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_search_raises_on_token_errors.py b/tests/search/vespa/test_search_raises_on_token_errors.py deleted file mode 100644 index db5b5a43..00000000 --- a/tests/search/vespa/test_search_raises_on_token_errors.py +++ /dev/null @@ -1,74 +0,0 @@ -from unittest.mock import patch - -import jwt -import pytest -from fastapi import status - -from app.api.api_v1.routers import search -from tests.search.vespa.setup_search_tests import ( - _make_search_request, - _populate_db_families, -) - - -@pytest.mark.search -def test_search_with_invalid_corpus_id_in_token( - data_client, data_db, valid_token, monkeypatch, test_vespa -): - """ - GIVEN a list of corpora IDs decoded from an app config token - WHEN one or more of those corpora IDs are not in our database - THEN raise a 400 HTTP error - """ - monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) - _populate_db_families(data_db) - - with patch( - "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=False - ): - response = _make_search_request( - data_client, - valid_token, - params={"query_string": ""}, - expected_status_code=status.HTTP_400_BAD_REQUEST, - ) - - assert response["detail"] == "Error verifying corpora IDs." - - -@pytest.mark.search -@pytest.mark.parametrize( - "side_effect", - [ - jwt.exceptions.InvalidAudienceError, - jwt.exceptions.ExpiredSignatureError, - jwt.exceptions.InvalidTokenError, - ], -) -def test_search_decoding_token_raises_PyJWTError( - side_effect, - data_client, - data_db, - valid_token, - monkeypatch, - test_vespa, -): - """ - GIVEN a request to the search endpoint - WHEN the decode_config_token() function call raises a PyJWTError - THEN raise a 400 HTTP error - """ - monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) - _populate_db_families(data_db) - - with patch( - "app.api.api_v1.routers.search.decode_config_token", side_effect=side_effect - ): - response = _make_search_request( - data_client, - valid_token, - params={"query_string": ""}, - expected_status_code=status.HTTP_400_BAD_REQUEST, - ) - - assert response["detail"] == "Could not decode configuration token" diff --git a/tests/search/vespa/test_this_vespa_search_download.py b/tests/search/vespa/test_this_vespa_search_download.py index 51dfb954..e4447e01 100644 --- a/tests/search/vespa/test_this_vespa_search_download.py +++ b/tests/search/vespa/test_this_vespa_search_download.py @@ -15,9 +15,11 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("exact_match", [True, False]) @pytest.mark.parametrize("query_string", ["", "local"]) def test_csv_content( + mock_corpora_exist_in_db, exact_match, query_string, test_vespa, @@ -56,12 +58,23 @@ def test_csv_content( # TODO: Add collections to test db setup to provide document level coverage + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) @pytest.mark.parametrize("limit", [100, 250, 500]) def test_csv_download_search_variable_limit( - label, query, limit, test_vespa, data_db, monkeypatch, data_client, mocker + mock_corpora_exist_in_db, + label, + query, + limit, + test_vespa, + data_db, + monkeypatch, + data_client, + mocker, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -87,10 +100,13 @@ def test_csv_download_search_variable_limit( for key, value in params.items(): assert actual_params[key] == value + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_csv_download__ignore_extra_fields( - test_vespa, data_db, monkeypatch, data_client, mocker + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, mocker ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -107,3 +123,5 @@ def test_csv_download__ignore_extra_fields( json=params, ) assert download_response.status_code == 200 + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_corpus_filtering.py b/tests/search/vespa/test_vespa_corpus_filtering.py new file mode 100644 index 00000000..62d2b581 --- /dev/null +++ b/tests/search/vespa/test_vespa_corpus_filtering.py @@ -0,0 +1,152 @@ +from typing import Any +from unittest.mock import patch + +import jwt +import pytest +from fastapi import status + +from app.api.api_v1.routers import search +from tests.search.vespa.setup_search_tests import ( + _make_search_request, + _populate_db_families, +) + + +@pytest.mark.search +@pytest.mark.parametrize( + ("corpus_import_id", "corpus_type_name", "expected_hits"), + [ + ("CCLW.corpus.1.0", "UNFCCC Submissions", 1), + ("CCLW.corpus.1.0", None, 1), + (None, "UNFCCC Submissions", 1), + (None, None, 5), + (None, "Laws and Policies", 4), + ("CCLW.corpus.2.0", None, 4), + ], +) +def test_corpus_filtering( + test_vespa, + monkeypatch, + data_client, + data_db, + corpus_import_id: str, + corpus_type_name: str, + expected_hits: int, +): + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + params: dict[str, Any] = {"query_string": "and"} + if corpus_import_id: + params["corpus_import_ids"] = [corpus_import_id] + if corpus_type_name: + params["corpus_type_names"] = [corpus_type_name] + + with patch( + "app.api.api_v1.routers.search.decode_config_token", + return_value=["CCLW.corpus.1.0", "CCLW.corpus.2.0"], + ), patch( + "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True + ): + response = _make_search_request( + data_client, + token="foo", + params=params, + ) + + assert len(response["families"]) > 0 + assert len(response["families"]) == expected_hits + for family in response["families"]: + if corpus_import_id: + assert family["corpus_import_id"] == corpus_import_id + if corpus_type_name: + assert family["corpus_type_name"] == corpus_type_name + + +@pytest.mark.search +def test_search_with_corpus_ids_in_token_not_in_db( + data_client, data_db, valid_token, monkeypatch, test_vespa +): + """ + GIVEN a list of corpora IDs decoded from an app config token + WHEN one or more of those corpora IDs are not in our database + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + with patch( + "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=False + ): + response = _make_search_request( + data_client, + valid_token, + params={"query_string": ""}, + expected_status_code=status.HTTP_400_BAD_REQUEST, + ) + + assert response["detail"] == "Error verifying corpora IDs." + + +@pytest.mark.search +@pytest.mark.parametrize( + "side_effect", + [ + jwt.exceptions.InvalidAudienceError, + jwt.exceptions.ExpiredSignatureError, + jwt.exceptions.InvalidTokenError, + ], +) +def test_search_decoding_token_raises_PyJWTError( + side_effect, + data_client, + data_db, + valid_token, + monkeypatch, + test_vespa, +): + """ + GIVEN a request to the search endpoint + WHEN the decode_config_token() function call raises a PyJWTError + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + with patch( + "app.api.api_v1.routers.search.decode_config_token", side_effect=side_effect + ): + response = _make_search_request( + data_client, + valid_token, + params={"query_string": ""}, + expected_status_code=status.HTTP_400_BAD_REQUEST, + ) + + assert response["detail"] == "Could not decode configuration token" + + +@pytest.mark.search +def test_search_with_invalid_corpus_id_in_search_request_params( + data_client, data_db, valid_token, monkeypatch, test_vespa +): + """ + GIVEN a list of corpora IDs from the search request body params + WHEN those corpora IDs are not a subset of the app token corpora IDs + THEN raise a 403 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + with patch( + "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", + return_value=True, + ), patch("app.api.api_v1.routers.search.validate_corpora_ids", return_value=False): + response = _make_search_request( + data_client, + valid_token, + params={"query_string": ""}, + expected_status_code=status.HTTP_403_FORBIDDEN, + ) + + assert response["detail"] == "Error validating corpora IDs." diff --git a/tests/search/vespa/test_vespa_ids_search.py b/tests/search/vespa/test_vespa_ids_search.py index 0808f99c..1be534e9 100644 --- a/tests/search/vespa/test_vespa_ids_search.py +++ b/tests/search/vespa/test_vespa_ids_search.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from db_client.models.dfce import Slug from db_client.models.dfce.family import FamilyDocument @@ -40,6 +42,7 @@ def _fam_ids_from_response(test_db, response) -> list[str]: return family_ids +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize( "family_ids", [ @@ -50,7 +53,13 @@ def _fam_ids_from_response(test_db, response) -> list[str]: ) @pytest.mark.search def test_family_ids_search( - test_vespa, data_db, monkeypatch, data_client, family_ids, valid_token + mock_corpora_exist_in_db, + test_vespa, + data_db, + monkeypatch, + data_client, + family_ids, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -64,8 +73,10 @@ def test_family_ids_search( got_family_ids = _fam_ids_from_response(data_db, response) assert sorted(got_family_ids) == sorted(family_ids) + assert mock_corpora_exist_in_db.assert_called +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize( "document_ids", [ @@ -80,7 +91,13 @@ def test_family_ids_search( ) @pytest.mark.search def test_document_ids_search( - test_vespa, data_db, monkeypatch, data_client, document_ids, valid_token + mock_corpora_exist_in_db, + test_vespa, + data_db, + monkeypatch, + data_client, + document_ids, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -93,11 +110,13 @@ def test_document_ids_search( got_document_ids = _doc_ids_from_response(data_db, response) assert sorted(got_document_ids) == sorted(document_ids) + assert mock_corpora_exist_in_db.assert_called +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.search def test_document_ids_and_family_ids_search( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -113,11 +132,13 @@ def test_document_ids_and_family_ids_search( response = _make_search_request(data_client, valid_token, params) assert len(response["families"]) == 0 + assert mock_corpora_exist_in_db.assert_called +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.search def test_empty_ids_dont_limit_result( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -136,3 +157,4 @@ def test_empty_ids_dont_limit_result( assert len(got_family_ids) > 1 assert len(got_document_ids) > 1 + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_metadata_filters_search.py b/tests/search/vespa/test_vespa_metadata_filters_search.py index 849cd20f..e8b01901 100644 --- a/tests/search/vespa/test_vespa_metadata_filters_search.py +++ b/tests/search/vespa/test_vespa_metadata_filters_search.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from app.api.api_v1.routers import search @@ -8,6 +10,7 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize( "label,query,metadata_filters", [ @@ -23,6 +26,7 @@ ], ) def test_metadata_filter( + mock_corpora_exist_in_db, label, query, metadata_filters, @@ -54,3 +58,5 @@ def test_metadata_filter( metadata_filter["value"] in f["family_metadata"][metadata_filter["name"]] ) + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_query_ignores.py b/tests/search/vespa/test_vespa_query_ignores.py index 709f3544..a05c8f71 100644 --- a/tests/search/vespa/test_vespa_query_ignores.py +++ b/tests/search/vespa/test_vespa_query_ignores.py @@ -1,4 +1,5 @@ import time +from unittest.mock import patch import pytest @@ -10,7 +11,10 @@ @pytest.mark.search -def test_case_insensitivity(test_vespa, data_db, monkeypatch, data_client, valid_token): +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +def test_case_insensitivity( + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token +): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -18,11 +22,13 @@ def test_case_insensitivity(test_vespa, data_db, monkeypatch, data_client, valid upper_body = _make_search_request(data_client, valid_token, {"query_string": "THE"}) assert lower_body["families"] == upper_body["families"] + assert mock_corpora_exist_in_db.assert_called @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_punctuation_ignored( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -43,9 +49,14 @@ def test_punctuation_ignored( == sorted([f["family_slug"] for f in accent_body["families"]]) ) + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search -def test_accents_ignored(test_vespa, data_db, monkeypatch, data_client, valid_token): +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +def test_accents_ignored( + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token +): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -55,3 +66,5 @@ def test_accents_ignored(test_vespa, data_db, monkeypatch, data_client, valid_to request_time_ms = 1000 * (end - start) assert 0 < body["query_time_ms"] < body["total_time_ms"] < request_time_ms + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_search.py b/tests/search/vespa/test_vespa_search.py index bc2abe7a..0d6c812b 100644 --- a/tests/search/vespa/test_vespa_search.py +++ b/tests/search/vespa/test_vespa_search.py @@ -1,4 +1,4 @@ -from typing import Any +from unittest.mock import patch import pytest from db_client.models.dfce.family import FamilyDocument @@ -17,8 +17,15 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_empty_search_term_performs_browse( - test_vespa, data_client, data_db, mocker, monkeypatch, valid_token + mock_corpora_exist_in_db, + test_vespa, + data_client, + data_db, + mocker, + monkeypatch, + valid_token, ): """Make sure that empty search term returns results in browse mode.""" _populate_db_families(data_db) @@ -34,11 +41,20 @@ def test_empty_search_term_performs_browse( assert query_spy.call_args.kwargs["parameters"].all_results query_spy.assert_called_once() + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("exact_match", [True, False]) def test_search_body_valid( - exact_match, test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, + exact_match, + test_vespa, + data_client, + data_db, + monkeypatch, + valid_token, ): """Test a simple known valid search responds with success.""" monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) @@ -66,10 +82,13 @@ def test_search_body_valid( ] assert isinstance(body["families"], list) + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_no_doc_if_in_postgres_but_not_vespa( - test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, test_vespa, data_client, data_db, monkeypatch, valid_token ): """Test a simple known valid search responds with success.""" monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) @@ -120,11 +139,21 @@ def test_no_doc_if_in_postgres_but_not_vespa( assert len(body["families"]) == 0 + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_benchmark_families_search( - label, query, test_vespa, monkeypatch, data_client, data_db, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + monkeypatch, + data_client, + data_db, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -146,10 +175,13 @@ def test_benchmark_families_search( average = sum(times) / len(times) assert average < REASONABLE_LATENCY_MS + assert mock_corpora_exist_in_db.assert_called + +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.search def test_specific_doc_returned( - test_vespa, monkeypatch, data_client, data_db, valid_token + mock_corpora_exist_in_db, test_vespa, monkeypatch, data_client, data_db, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -167,7 +199,10 @@ def test_specific_doc_returned( family_name = families[0]["family_name"] assert family_name == family_name_query + assert mock_corpora_exist_in_db.assert_called + +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize( ("extra_params", "invalid_field"), [ @@ -179,6 +214,7 @@ def test_specific_doc_returned( ) @pytest.mark.search def test_search_params_backend_limits( + mock_corpora_exist_in_db, test_vespa, monkeypatch, data_client, @@ -201,10 +237,13 @@ def test_search_params_backend_limits( assert "body" in error["loc"], error assert invalid_field in error["loc"], error + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_search_with_deleted_docs( - test_vespa, monkeypatch, data_client, data_db, valid_token + mock_corpora_exist_in_db, test_vespa, monkeypatch, data_client, data_db, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -233,43 +272,4 @@ def test_search_with_deleted_docs( assert start_family_count > one_deleted_count > all_deleted_count assert len(all_deleted_body["families"]) == 0 - -@pytest.mark.search -@pytest.mark.parametrize( - ("corpus_import_id", "corpus_type_name"), - [ - ("CCLW.corpus.1.0", "UNFCCC Submissions"), - ("CCLW.corpus.1.0", None), - (None, "UNFCCC Submissions"), - ], -) -def test_corpus_filtering( - test_vespa, - monkeypatch, - data_client, - data_db, - valid_token, - corpus_import_id: str, - corpus_type_name: str, -): - monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) - _populate_db_families(data_db) - - params: dict[str, Any] = {"query_string": "and"} - if corpus_import_id: - params["corpus_import_ids"] = [corpus_import_id] - if corpus_type_name: - params["corpus_type_names"] = [corpus_type_name] - - response = _make_search_request( - data_client, - token=valid_token, - params=params, - ) - - assert len(response["families"]) > 0 - for family in response["families"]: - if corpus_import_id: - assert family["corpus_import_id"] == corpus_import_id - if corpus_type_name: - assert family["corpus_type_name"] == corpus_type_name + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_search_pagination.py b/tests/search/vespa/test_vespa_search_pagination.py index 44868fda..708f1c9d 100644 --- a/tests/search/vespa/test_vespa_search_pagination.py +++ b/tests/search/vespa/test_vespa_search_pagination.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from app.api.api_v1.routers import search @@ -9,8 +11,9 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_simple_pagination_families( - test_vespa, data_client, data_db, monkeypatch, valid_token + mock_corpora_exist_in_db, test_vespa, data_client, data_db, monkeypatch, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -53,10 +56,13 @@ def test_simple_pagination_families( == "submission-to-the-unfccc-ahead-of-the-first-technical-dialogue_e760" ) + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_continuation_token__families( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) @@ -90,10 +96,13 @@ def test_continuation_token__families( assert sorted(first_family_ids) == sorted(prev_family_ids) + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) def test_continuation_token__passages( - test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) @@ -162,3 +171,5 @@ def test_continuation_token__passages( assert sorted(second_family_second_passages_ids) != sorted( second_family_prev_passages_ids ) + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_vespa_search_result_order.py b/tests/search/vespa/test_vespa_search_result_order.py index 3d8fc41a..ee738804 100644 --- a/tests/search/vespa/test_vespa_search_result_order.py +++ b/tests/search/vespa/test_vespa_search_result_order.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from app.api.api_v1.routers import search @@ -9,9 +11,17 @@ @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_result_order_score( - label, query, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -33,11 +43,21 @@ def test_result_order_score( assert asc_dates[0] < desc_dates[0] assert asc_dates[-1] > desc_dates[-1] + assert mock_corpora_exist_in_db.assert_called + @pytest.mark.search +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_result_order_title( - label, query, test_vespa, data_db, monkeypatch, data_client, valid_token + mock_corpora_exist_in_db, + label, + query, + test_vespa, + data_db, + monkeypatch, + data_client, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -50,3 +70,5 @@ def test_result_order_title( # Scope of test is to confirm this does not cause a failure _ = _make_search_request(data_client, valid_token, params) + + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/search/vespa/test_whole_database_download.py b/tests/search/vespa/test_whole_database_download.py index dde10115..df2ca087 100644 --- a/tests/search/vespa/test_whole_database_download.py +++ b/tests/search/vespa/test_whole_database_download.py @@ -8,7 +8,8 @@ @pytest.mark.search -def test_all_data_download(data_db, data_client, valid_token): +@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +def test_all_data_download(mock_corpora_exist_in_db, data_db, data_client, valid_token): _populate_db_families(data_db) with ( @@ -25,3 +26,5 @@ def test_all_data_download(data_db, data_client, valid_token): "https://cdn.climatepolicyradar.org/" "navigator/dumps/whole_data_dump-2024-03-22.zip" ) + + assert mock_corpora_exist_in_db.assert_called