diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 2c5f1004..888dec44 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -77,6 +77,8 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]: @r.get("/families/", response_model=list[FamilyReadDTO]) async def search_family( request: Request, + # We have used the built in parsers here for geography specifically + # so that we do not have to build our own geography: Annotated[list[str] | None, Query()] = None, ) -> list[FamilyReadDTO]: """ diff --git a/tests/integration_tests/family/test_search.py b/tests/integration_tests/family/test_search.py index 22be4486..d5d839c9 100644 --- a/tests/integration_tests/family/test_search.py +++ b/tests/integration_tests/family/test_search.py @@ -4,13 +4,56 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from tests.integration_tests.setup_db import setup_db +from tests.integration_tests.setup_db import add_data, setup_db def test_search_geographies( client: TestClient, data_db: Session, superuser_header_token ): setup_db(data_db) + add_data( + data_db, + [ + { + "import_id": "A.0.0.4", + "title": "title", + "summary": "gregarious magazine rub", + "geography": "ALB", + "category": "UNFCCC", + "status": "Created", + "metadata": {"author": "CPR", "author_type": "Party"}, + "organisation": "UNFCCC", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "corpus_title": "UNFCCC Submissions", + "corpus_type": "Intl. agreements", + "slug": "Slug4", + "events": ["E.0.0.3"], + "published_date": "2018-12-24T04:59:33Z", + "last_updated_date": "2018-12-24T04:59:33Z", + "documents": ["D.0.0.1", "D.0.0.2"], + "collections": ["C.0.0.4"], + }, + { + "import_id": "A.0.0.5", + "title": "title", + "summary": "flour umbrella established", + "geography": "ZMB", + "category": "UNFCCC", + "status": "Created", + "metadata": {"author": "CPR", "author_type": "Party"}, + "organisation": "UNFCCC", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "corpus_title": "UNFCCC Submissions", + "corpus_type": "Intl. agreements", + "slug": "Slug5", + "events": ["E.0.0.3"], + "published_date": "2018-12-24T04:59:33Z", + "last_updated_date": "2018-12-24T04:59:33Z", + "documents": ["D.0.0.1", "D.0.0.2"], + "collections": ["C.0.0.4"], + }, + ], + ) tests_cases = [ (["afghanistan"], ["A.0.0.1", "A.0.0.3"]), diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index 58c062a7..fc2b84e9 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import TypedDict, cast from db_client.models.dfce.collection import ( Collection, @@ -26,7 +26,28 @@ from sqlalchemy import update from sqlalchemy.orm import Session -EXPECTED_FAMILIES = [ + +class DBEntry(TypedDict): + import_id: str + title: str + summary: str + geography: str + category: str + status: str + metadata: dict + organisation: str + corpus_import_id: str + corpus_title: str + corpus_type: str + slug: str + events: list[str] + published_date: str | None + last_updated_date: str | None + documents: list[str] + collections: list[str] + + +EXPECTED_FAMILIES: list[DBEntry] = [ { "import_id": "A.0.0.1", "title": "apple", @@ -98,44 +119,6 @@ "documents": ["D.0.0.1", "D.0.0.2"], "collections": ["C.0.0.4"], }, - { - "import_id": "A.0.0.4", - "title": "title", - "summary": "gregarious magazine rub", - "geography": "ALB", - "category": "UNFCCC", - "status": "Created", - "metadata": {"author": "CPR", "author_type": "Party"}, - "organisation": "UNFCCC", - "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", - "corpus_title": "UNFCCC Submissions", - "corpus_type": "Intl. agreements", - "slug": "Slug3", - "events": ["E.0.0.3"], - "published_date": "2018-12-24T04:59:33Z", - "last_updated_date": "2018-12-24T04:59:33Z", - "documents": ["D.0.0.1", "D.0.0.2"], - "collections": ["C.0.0.4"], - }, - { - "import_id": "A.0.0.5", - "title": "title", - "summary": "flour umbrella established", - "geography": "ZMB", - "category": "UNFCCC", - "status": "Created", - "metadata": {"author": "CPR", "author_type": "Party"}, - "organisation": "UNFCCC", - "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", - "corpus_title": "UNFCCC Submissions", - "corpus_type": "Intl. agreements", - "slug": "Slug3", - "events": ["E.0.0.3"], - "published_date": "2018-12-24T04:59:33Z", - "last_updated_date": "2018-12-24T04:59:33Z", - "documents": ["D.0.0.1", "D.0.0.2"], - "collections": ["C.0.0.4"], - }, ] EXPECTED_NUM_FAMILIES = len(EXPECTED_FAMILIES) @@ -318,6 +301,15 @@ ] +def add_data(test_db: Session, data: list[DBEntry]): + org_id = test_db.query(Organisation).filter(Organisation.name == "CCLW").one().id + other_org_id = ( + test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one().id + ) + _setup_family_data(test_db, org_id, other_org_id, data) + test_db.commit() + + def setup_db(test_db: Session, configure_empty: bool = False): setup_test_data(test_db, configure_empty) @@ -518,13 +510,15 @@ def _setup_family_data( test_db: Session, default_org_id: int, other_org_id: int, + initial_data: list[DBEntry] = EXPECTED_FAMILIES, configure_empty: bool = False, ) -> None: if configure_empty is True: return None - for index in range(EXPECTED_NUM_FAMILIES): - data = EXPECTED_FAMILIES[index] + num_families = len(initial_data) + for index in range(num_families): + data = initial_data[index] geo_id = ( test_db.query(Geography.id) @@ -562,8 +556,8 @@ def _setup_family_data( ) # Now add the metadata onto the families - for index in range(EXPECTED_NUM_FAMILIES): - data = EXPECTED_FAMILIES[index] + for index in range(num_families): + data = initial_data[index] test_db.add( FamilyMetadata( family_import_id=data["import_id"], @@ -572,7 +566,7 @@ def _setup_family_data( ) test_db.add( Slug( - name=f"Slug{index+1}", + name=data["slug"], family_import_id=data["import_id"], ) ) diff --git a/tests/mocks/repos/bad_family_repo.py b/tests/mocks/repos/bad_family_repo.py index 376e7f2c..59082199 100644 --- a/tests/mocks/repos/bad_family_repo.py +++ b/tests/mocks/repos/bad_family_repo.py @@ -13,7 +13,12 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[FamilyReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str, org_id: Optional[int]) -> list[FamilyReadDTO]: + def mock_search( + _, + q: str, + org_id: Optional[int], + geography: list[str] | None, + ) -> list[FamilyReadDTO]: raise RepositoryError("Bad Repo") def mock_update(