diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py
index 87d5a28d..c621f006 100644
--- a/app/api/api_v1/routers/search.py
+++ b/app/api/api_v1/routers/search.py
@@ -157,7 +157,7 @@ def search_documents(
     )
 
     if not validate_corpora_ids(db, allowed_corpora_ids):
-        msg = "One or more of the given corpora do not exist in the database."
+        msg = "Error validating corpora IDs."
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=msg,
diff --git a/app/db/crud/helpers/validate.py b/app/db/crud/helpers/validate.py
index 05a158c9..b017bf6b 100644
--- a/app/db/crud/helpers/validate.py
+++ b/app/db/crud/helpers/validate.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Optional, cast
 
 from db_client.models.dfce.family import Corpus
 from sqlalchemy import distinct, select
@@ -7,16 +8,23 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-def validate_corpora_ids(db: Session, allowed_corpora_ids: list[str]) -> bool:
-    """Validate whether all given corpus IDs exist in the DB.
+def validate_corpora_ids(
+    db: Session, corpora_ids: list[str], allowed_corpora_ids: Optional[list[str]] = None
+) -> bool:
+    """Validate all given corpus IDs against a list of allowed corpora.
 
     :param Session db: The DB session to connect to.
-    :param list[str] allowed_corpora_ids: The corpus import IDs we want
-        to validate.
-    :return bool: Return whether or not all the corpora exist in the DB.
+    :param list[str] corpora_ids: The corpus import IDs we want to
+        validate.
+    :param Optional[list[str]] allowed_corpora_ids: The corpus import
+        IDs to validate against; if None, every corpus import ID in the DB.
+    :return bool: Whether every ID in corpora_ids is in the allowed list.
     """
-    existing_corpora_in_db = db.scalars(select(distinct(Corpus.import_id))).all()
-    validate_success = all(
-        corpus in allowed_corpora_ids for corpus in existing_corpora_in_db
-    )
+    if allowed_corpora_ids is None:
+        allowed_corpora_ids = cast(
+            list, db.scalars(select(distinct(Corpus.import_id))).all()
+        )
+        _LOGGER.info(allowed_corpora_ids)  # TODO remove in part 2.
+
+    validate_success = all(corpus in allowed_corpora_ids for corpus in corpora_ids)
     return validate_success
diff --git a/pyproject.toml b/pyproject.toml
index a3ffe3c9..cc74cc11 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "navigator_backend"
-version = "1.16.2"
+version = "1.16.3"
 description = ""
 authors = ["CPR-dev-team "]
 packages = [{ include = "app" }, { include = "tests" }]
diff --git a/tests/search/vespa/test_search_raises_on_token_errors.py b/tests/search/vespa/test_search_raises_on_token_errors.py
index 7bee883d..6c0073d8 100644
--- a/tests/search/vespa/test_search_raises_on_token_errors.py
+++ b/tests/search/vespa/test_search_raises_on_token_errors.py
@@ -33,10 +33,7 @@ def test_search_with_invalid_corpus_id_in_token(
         expected_status_code=status.HTTP_400_BAD_REQUEST,
     )
 
-    assert (
-        response["detail"]
-        == "One or more of the given corpora do not exist in the database."
-    )
+    assert response["detail"] == "Error validating corpora IDs."
 
 
 @pytest.mark.search