diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 3cb5d017..888dec44 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -8,8 +8,9 @@ """ import logging +from typing import Annotated -from fastapi import APIRouter, HTTPException, Request, status +from fastapi import APIRouter, HTTPException, Query, Request, status import app.service.family as family_service from app.api.api_v1.query_params import ( @@ -74,7 +75,12 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]: @r.get("/families/", response_model=list[FamilyReadDTO]) -async def search_family(request: Request) -> list[FamilyReadDTO]: +async def search_family( + request: Request, + # We have used the built in parsers here for geography specifically + # so that we do not have to build our own + geography: Annotated[list[str] | None, Query()] = None, +) -> list[FamilyReadDTO]: """ Searches for families matching URL parameters ("q" by default). @@ -96,7 +102,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]: validate_query_params(query_params, VALID_PARAMS) try: - families = family_service.search(query_params, request.state.user) + families = family_service.search(query_params, request.state.user, geography) except ValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: diff --git a/app/repository/family.py b/app/repository/family.py index cb36c279..698a2cd7 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -167,7 +167,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: def search( - db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int] + db: Session, + search_params: dict[str, Union[str, int]], + org_id: Optional[int], + geography: Optional[list[str]], ) -> list[FamilyReadDTO]: """ Gets a list of families from the repository searching given fields. @@ -195,13 +198,11 @@ def search( term = f"%{escape_like(search_params['summary'])}%" search.append(Family.description.ilike(term)) - if "geography" in search_params.keys(): - term = cast(str, search_params["geography"]) - search.append( - or_( - Geography.display_value == term.title(), Geography.value == term.upper() - ) + if geography is not None: + geography_filter = or_( + *[(Geography.display_value == g.title()) for g in geography] ) + search.append(geography_filter) if "status" in search_params.keys(): term = cast(str, search_params["status"]) diff --git a/app/repository/protocols.py b/app/repository/protocols.py index 7cfbf092..fc08ab7f 100644 --- a/app/repository/protocols.py +++ b/app/repository/protocols.py @@ -28,7 +28,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: @staticmethod def search( - db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int] + db: Session, + search_params: dict[str, Union[str, int]], + org_id: Optional[int], + geography: Optional[list[str]], ) -> list[FamilyReadDTO]: """Searches the families""" ... diff --git a/app/service/family.py b/app/service/family.py index 0eb85cfc..9940857a 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -64,7 +64,9 @@ def all(user: UserContext) -> list[FamilyReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def search( - search_params: dict[str, Union[str, int]], user: UserContext + search_params: dict[str, Union[str, int]], + user: UserContext, + geography: Optional[list[str]] = None, ) -> list[FamilyReadDTO]: """ Searches for the search term against families on specified fields. @@ -76,12 +78,13 @@ def search( :param dict search_params: Search patterns to match against specified fields, given as key value pairs in a dictionary. :param UserContext user: The current user context. + :param Optional[list[str]] geography: geographies to filter on. :return list[FamilyDTO]: The list of families matching the given search terms. """ with db_session.get_db() as db: org_id = app_user.restrict_entities_to_user_org(user) - return family_repo.search(db, search_params, org_id) + return family_repo.search(db, search_params, org_id, geography) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/pyproject.toml b/pyproject.toml index ecf3294f..942c39e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "admin_backend" -version = "2.17.21" +version = "2.17.22" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/integration_tests/family/test_search.py b/tests/integration_tests/family/test_search.py index 0eff7d47..d5d839c9 100644 --- a/tests/integration_tests/family/test_search.py +++ b/tests/integration_tests/family/test_search.py @@ -4,28 +4,74 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from tests.integration_tests.setup_db import setup_db +from tests.integration_tests.setup_db import add_data, setup_db def test_search_geographies( client: TestClient, data_db: Session, superuser_header_token ): setup_db(data_db) + add_data( + data_db, + [ + { + "import_id": "A.0.0.4", + "title": "title", + "summary": "gregarious magazine rub", + "geography": "ALB", + "category": "UNFCCC", + "status": "Created", + "metadata": {"author": "CPR", "author_type": "Party"}, + "organisation": "UNFCCC", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "corpus_title": "UNFCCC Submissions", + "corpus_type": "Intl. agreements", + "slug": "Slug4", + "events": ["E.0.0.3"], + "published_date": "2018-12-24T04:59:33Z", + "last_updated_date": "2018-12-24T04:59:33Z", + "documents": ["D.0.0.1", "D.0.0.2"], + "collections": ["C.0.0.4"], + }, + { + "import_id": "A.0.0.5", + "title": "title", + "summary": "flour umbrella established", + "geography": "ZMB", + "category": "UNFCCC", + "status": "Created", + "metadata": {"author": "CPR", "author_type": "Party"}, + "organisation": "UNFCCC", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "corpus_title": "UNFCCC Submissions", + "corpus_type": "Intl. agreements", + "slug": "Slug5", + "events": ["E.0.0.3"], + "published_date": "2018-12-24T04:59:33Z", + "last_updated_date": "2018-12-24T04:59:33Z", + "documents": ["D.0.0.1", "D.0.0.2"], + "collections": ["C.0.0.4"], + }, + ], + ) tests_cases = [ - ("afghanistan", 2), - ("zimbabwe", 1), + (["afghanistan"], ["A.0.0.1", "A.0.0.3"]), + (["zimbabwe"], ["A.0.0.2"]), + (["albania", "zambia"], ["A.0.0.4", "A.0.0.5"]), ] - for country, expected_count in tests_cases: + for countries, expected_ids in tests_cases: + geographies_query = "&".join([f"geography={country}" for country in countries]) response = client.get( - f"/api/v1/families/?geography={country}", + f"/api/v1/families/?{geographies_query}", headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() + ids = [result["import_id"] for result in data] assert isinstance(data, list) - assert len(data) == expected_count + assert ids == expected_ids def test_search_family_super( diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index 3c730507..fc2b84e9 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import TypedDict, cast from db_client.models.dfce.collection import ( Collection, @@ -26,7 +26,28 @@ from sqlalchemy import update from sqlalchemy.orm import Session -EXPECTED_FAMILIES = [ + +class DBEntry(TypedDict): + import_id: str + title: str + summary: str + geography: str + category: str + status: str + metadata: dict + organisation: str + corpus_import_id: str + corpus_title: str + corpus_type: str + slug: str + events: list[str] + published_date: str | None + last_updated_date: str | None + documents: list[str] + collections: list[str] + + +EXPECTED_FAMILIES: list[DBEntry] = [ { "import_id": "A.0.0.1", "title": "apple", @@ -280,6 +301,15 @@ ] +def add_data(test_db: Session, data: list[DBEntry]): + org_id = test_db.query(Organisation).filter(Organisation.name == "CCLW").one().id + other_org_id = ( + test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one().id + ) + _setup_family_data(test_db, org_id, other_org_id, data) + test_db.commit() + + def setup_db(test_db: Session, configure_empty: bool = False): setup_test_data(test_db, configure_empty) @@ -480,13 +510,15 @@ def _setup_family_data( test_db: Session, default_org_id: int, other_org_id: int, + initial_data: list[DBEntry] = EXPECTED_FAMILIES, configure_empty: bool = False, ) -> None: if configure_empty is True: return None - for index in range(EXPECTED_NUM_FAMILIES): - data = EXPECTED_FAMILIES[index] + num_families = len(initial_data) + for index in range(num_families): + data = initial_data[index] geo_id = ( test_db.query(Geography.id) @@ -524,8 +556,8 @@ def _setup_family_data( ) # Now add the metadata onto the families - for index in range(EXPECTED_NUM_FAMILIES): - data = EXPECTED_FAMILIES[index] + for index in range(num_families): + data = initial_data[index] test_db.add( FamilyMetadata( family_import_id=data["import_id"], @@ -534,7 +566,7 @@ def _setup_family_data( ) test_db.add( Slug( - name=f"Slug{index+1}", + name=data["slug"], family_import_id=data["import_id"], ) ) diff --git a/tests/mocks/repos/bad_family_repo.py b/tests/mocks/repos/bad_family_repo.py index 376e7f2c..4696f3f1 100644 --- a/tests/mocks/repos/bad_family_repo.py +++ b/tests/mocks/repos/bad_family_repo.py @@ -13,7 +13,12 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[FamilyReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str, org_id: Optional[int]) -> list[FamilyReadDTO]: + def mock_search( + _, + q: str, + org_id: Optional[int], + geography: Optional[list[str]], + ) -> list[FamilyReadDTO]: raise RepositoryError("Bad Repo") def mock_update( diff --git a/tests/mocks/repos/family_repo.py b/tests/mocks/repos/family_repo.py index 57eda0bd..2b9fcacc 100644 --- a/tests/mocks/repos/family_repo.py +++ b/tests/mocks/repos/family_repo.py @@ -36,7 +36,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: def search( - db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int] + db: Session, + search_params: dict[str, Union[str, int]], + org_id: Optional[int], + geography: Optional[list[str]], ) -> list[FamilyReadDTO]: _maybe_throw() _maybe_timeout() diff --git a/tests/mocks/services/family_service.py b/tests/mocks/services/family_service.py index 8964df04..a10bba3a 100644 --- a/tests/mocks/services/family_service.py +++ b/tests/mocks/services/family_service.py @@ -34,7 +34,9 @@ def mock_get_family(import_id: str) -> Optional[FamilyReadDTO]: if not family_service.missing: return create_family_read_dto(import_id, collections=["x.y.z.1", "x.y.z.2"]) - def mock_search_families(q_params: dict, user_email: str) -> list[FamilyReadDTO]: + def mock_search_families( + q_params: dict, user_email: str, geography: Optional[list[str]] + ) -> list[FamilyReadDTO]: if q_params["q"] == "empty": return []