Skip to content

Commit

Permalink
Feature/pdct 1274 update the api search method to apply corpus filter…
Browse files Browse the repository at this point in the history
…ing (#350)

* Use host header for audience comparison

* Update custom app documentation to use hostname

* Make audience a string instead of URL

* Use host for audience & add additional corpora ID validation

* Bump to 1.16.4

* Two step validation of corpora IDs in search

* Bump to 1.16.5

* Fix order of patches

* Add app token corpora as default corpora for search

* Mock existence of corpus IDs so search validation passes

* Move corpus filtering tests into separate file

* Update corpora IDs to match search fixture

* Bump to 1.17.0

* Remove redundant check
  • Loading branch information
katybaulch authored Sep 24, 2024
1 parent 9f844ca commit 4072d70
Show file tree
Hide file tree
Showing 15 changed files with 454 additions and 176 deletions.
45 changes: 33 additions & 12 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from io import BytesIO
from typing import Annotated
from typing import Annotated, Sequence, cast

from cpr_sdk.exceptions import QueryError
from cpr_sdk.search_adaptors import VespaSearchAdapter
Expand Down Expand Up @@ -37,7 +37,10 @@
process_result_into_csv,
process_vespa_search_response,
)
from app.db.crud.helpers.validate import verify_any_corpora_ids_in_db
from app.db.crud.helpers.validate import (
validate_corpora_ids,
verify_any_corpora_ids_in_db,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -137,6 +140,17 @@ def search_documents(
the search database. The continuation token can be used to get the next set of
results from the search database. See the request schema for more details.
"""
_LOGGER.info(
"Search request",
extra={
"props": {
"search_request": search_body.model_dump(),
"host": str(host),
"app_token": str(app_token),
}
},
)

try:
allowed_corpora_ids = decode_config_token(app_token, host)
except PyJWTError as e:
Expand All @@ -147,16 +161,6 @@ def search_documents(
headers={"WWW-Authenticate": "Bearer"},
)

_LOGGER.info(
"Search request",
extra={
"props": {
"search_request": search_body.model_dump(),
"allowed_corpora_ids": str(allowed_corpora_ids),
}
},
)

# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
if not verify_any_corpora_ids_in_db(db, allowed_corpora_ids):
Expand All @@ -167,6 +171,23 @@ def search_documents(
headers={"WWW-Authenticate": "Bearer"},
)

# If the search request IDs are null, we want to search using the app token corpora.
if search_body.corpus_import_ids is None:
search_body.corpus_import_ids = cast(Sequence, allowed_corpora_ids)

# For the second validation, the search request corpora IDs are validated against
# the app token corpora IDs (a null 'corpus_import_ids' has already been defaulted
# to the app token corpora above, so there is always something to validate).
# corpus_import_ids must be a subset of app token IDs.
if not validate_corpora_ids(
set(search_body.corpus_import_ids), cast(set, allowed_corpora_ids)
):
msg = "Error validating corpora IDs."
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=msg,
headers={"WWW-Authenticate": "Bearer"},
)

_LOGGER.info(
"Starting search...",
extra={"props": {"search_request": search_body.model_dump()}},
Expand Down
19 changes: 19 additions & 0 deletions app/db/crud/helpers/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,22 @@ def verify_any_corpora_ids_in_db(db: Session, corpora_ids: list[str]) -> bool:
)

return validate_success


def validate_corpora_ids(corpora_ids: set[str], valid_corpora_ids: set[str]) -> bool:
    """Validate all given corpus IDs against a list of allowed corpora.

    An empty collection of corpus IDs is trivially valid, since there is
    nothing that could fall outside the allowed corpora.

    :param set[str] corpora_ids: The corpus import IDs we want to
        validate.
    :param set[str] valid_corpora_ids: The corpus import IDs
        we want to validate against.
    :return bool: Return whether or not all the corpora are valid.
    """
    # Work out the offending IDs once: an empty difference means every requested
    # ID is allowed (i.e. corpora_ids is a subset of valid_corpora_ids).
    # Wrapping in set() keeps this working if a caller hands over a list.
    invalid_corpora = set(corpora_ids).difference(valid_corpora_ids)
    validate_success = not invalid_corpora
    if not validate_success:
        # The trailing space matters: the two literals are concatenated, and
        # without it the invalid IDs run straight into the word "forbidden".
        _LOGGER.warning(
            f"Some corpora in search request params {invalid_corpora} "
            "forbidden to search against."
        )
    return validate_success
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "navigator_backend"
version = "1.16.4"
version = "1.17.0"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
19 changes: 1 addition & 18 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,7 @@ def valid_token():
might be different (e.g., the one for staging). This fixture works
around this.
"""
corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000"
subject = "CCLW"
audience = "localhost:8888"
input_str = f"{corpora_ids};{subject};{audience}"
return create_configuration_token(input_str)


@pytest.fixture
def unfccc_token():
"""Generate valid config token using TOKEN_SECRET_KEY.
Need to generate the config token using the token secret key from
your local env file. For tests in CI, this will be the secret key in
the .env.example file, but for local development this secret key
might be different (e.g., the one for staging). This fixture works
around this.
"""
corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000"
corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0"
subject = "CCLW"
audience = "localhost:8888"
input_str = f"{corpora_ids};{subject};{audience}"
Expand Down
100 changes: 91 additions & 9 deletions tests/search/vespa/test_range_and_keyword_filters_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

import pytest
from db_client.models.dfce import Geography
from fastapi import status
Expand All @@ -12,9 +14,17 @@


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")])
def test_keyword_country_filters__geography(
label, query, test_vespa, data_client, data_db, monkeypatch, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_client,
data_db,
monkeypatch,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand Down Expand Up @@ -43,11 +53,21 @@ def test_keyword_country_filters__geography(
assert len(filtered_family_slugs) == 1
assert family["family_slug"] in filtered_family_slugs

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")])
def test_keyword_country_filters__geographies(
label, query, test_vespa, data_client, data_db, monkeypatch, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_client,
data_db,
monkeypatch,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand Down Expand Up @@ -78,11 +98,21 @@ def test_keyword_country_filters__geographies(
assert len(filtered_family_slugs) == 1
assert family["family_slug"] in filtered_family_slugs

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")])
def test_keyword_region_filters(
label, query, test_vespa, data_client, data_db, monkeypatch, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_client,
data_db,
monkeypatch,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand Down Expand Up @@ -118,11 +148,21 @@ def test_keyword_region_filters(
]
assert family["family_slug"] in filtered_family_slugs

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")])
def test_keyword_region_and_country_filters(
label, query, test_vespa, data_client, data_db, monkeypatch, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_client,
data_db,
monkeypatch,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -141,11 +181,21 @@ def test_keyword_region_and_country_filters(
assert len(body["families"]) == 1
assert body["families"][0]["family_name"] == "National Energy Strategy"

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")])
def test_invalid_keyword_filters(
label, query, test_vespa, data_db, monkeypatch, data_client, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_db,
monkeypatch,
data_client,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -164,13 +214,22 @@ def test_invalid_keyword_filters(
expected_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
)

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize(
"year_range", [(None, None), (1900, None), (None, 2020), (1900, 2020)]
)
def test_year_range_filtered_in(
year_range, test_vespa, data_db, monkeypatch, data_client, valid_token
mock_corpora_exist_in_db,
year_range,
test_vespa,
data_db,
monkeypatch,
data_client,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -185,11 +244,20 @@ def test_year_range_filtered_in(
body = _make_search_request(data_client, valid_token, params=params)
assert len(body["families"]) > 0

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("year_range", [(None, 2010), (2024, None)])
def test_year_range_filtered_out(
year_range, test_vespa, data_db, monkeypatch, data_client, valid_token
mock_corpora_exist_in_db,
year_range,
test_vespa,
data_db,
monkeypatch,
data_client,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -204,11 +272,21 @@ def test_year_range_filtered_out(
body = _make_search_request(data_client, valid_token, params=params)
assert len(body["families"]) == 0

assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
@pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")])
def test_multiple_filters(
label, query, test_vespa, data_db, monkeypatch, data_client, valid_token
mock_corpora_exist_in_db,
label,
query,
test_vespa,
data_db,
monkeypatch,
data_client,
valid_token,
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -224,11 +302,13 @@ def test_multiple_filters(
}

_ = _make_search_request(data_client, valid_token, params)
assert mock_corpora_exist_in_db.assert_called


@pytest.mark.search
@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True)
def test_geo_filter_with_exact(
test_vespa, data_db, monkeypatch, data_client, valid_token
mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token
):
monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa)
_populate_db_families(data_db)
Expand All @@ -246,3 +326,5 @@ def test_geo_filter_with_exact(
assert len(response["families"]) > 0
for family in response["families"]:
assert "ITA" in family["family_geographies"]

assert mock_corpora_exist_in_db.assert_called
Loading

0 comments on commit 4072d70

Please sign in to comment.