diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 3cb5d017..e4911baf 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -92,7 +92,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]: query_params = set_default_query_params(query_params) - VALID_PARAMS = ["q", "title", "summary", "geography", "status", "max_results"] + VALID_PARAMS = ["q", "title", "summary", "geographies", "status", "max_results"] validate_query_params(query_params, VALID_PARAMS) try: diff --git a/app/model/family.py b/app/model/family.py index 0ca016a7..708b4291 100644 --- a/app/model/family.py +++ b/app/model/family.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, Union +from typing import Optional from pydantic import BaseModel @@ -12,7 +12,7 @@ class FamilyReadDTO(BaseModel): import_id: str title: str summary: str - geography: str + geographies: list[str] category: str status: str metadata: Json @@ -41,7 +41,7 @@ class FamilyWriteDTO(BaseModel): title: str summary: str - geography: str + geographies: list[str] category: str metadata: Json collections: list[str] @@ -60,7 +60,7 @@ class FamilyCreateDTO(BaseModel): import_id: Optional[str] = None title: str summary: str - geography: Union[str, list[str]] + geographies: list[str] category: str metadata: Json collections: list[str] diff --git a/app/repository/family.py b/app/repository/family.py index ffe42476..25f0533f 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -22,7 +22,7 @@ from db_client.models.organisation.users import Organisation from sqlalchemy import Column, and_ from sqlalchemy import delete as db_delete -from sqlalchemy import desc, func, or_ +from sqlalchemy import desc, or_ from sqlalchemy import update as db_update from sqlalchemy.exc import NoResultFound, OperationalError from sqlalchemy.orm import Query, Session @@ -34,30 +34,24 @@ _LOGGER = logging.getLogger(__name__) -FamilyGeoMetaOrg = Tuple[Family, str, FamilyMetadata, Corpus, Organisation] +FamilyGeoMetaOrg = Tuple[Family, Geography, FamilyMetadata, Corpus, Organisation] def _get_query(db: Session) -> Query: # NOTE: SqlAlchemy will make a complete hash of query generation # if columns are used in the query() call. Therefore, entire # objects are returned. - geo_subquery = ( - db.query( - func.min(Geography.value).label("value"), - FamilyGeography.family_import_id, - ) - .join(FamilyGeography, FamilyGeography.geography_id == Geography.id) - .filter(FamilyGeography.family_import_id == Family.import_id) - .group_by(Geography.value, FamilyGeography.family_import_id) - ).subquery("geo_subquery") - return ( - db.query(Family, geo_subquery.c.value, FamilyMetadata, Corpus, Organisation) # type: ignore + db.query(Family, Geography, FamilyMetadata, Corpus, Organisation) # type: ignore + .join(FamilyGeography, FamilyGeography.family_import_id == Family.import_id) + .join( + Geography, + Geography.id == FamilyGeography.geography_id, + ) .join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id) .join(FamilyCorpus, FamilyCorpus.family_import_id == Family.import_id) .join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id) .join(Organisation, Corpus.organisation_id == Organisation.id) - .filter(geo_subquery.c.family_import_id == Family.import_id) # type: ignore ) @@ -72,7 +66,7 @@ def _family_to_dto( import_id=str(fam.import_id), title=str(fam.title), summary=str(fam.description), - geography=geo_value, + geographies=[str(geo_value.display_value)], category=str(fam.family_category), status=str(fam.family_status), metadata=metadata, @@ -100,7 +94,6 @@ def _update_intention( db: Session, import_id: str, family: FamilyWriteDTO, - geo_id: int, original_family: Family, ): original_collections = [ @@ -111,17 +104,17 @@ def _update_intention( ] update_collections = set(original_collections) != set(family.collections) update_title = cast(str, original_family.title) != family.title - # TODO: PDCT-1406: Properly implement multi-geography support - update_geo = ( - db.query(FamilyGeography) - .filter(FamilyGeography.family_import_id == import_id) - .one() - .geography_id - != geo_id - ) + original_geographies = [ + geography.collection_import_id + for geography in db.query(Geography).filter( + original_family.import_id == Geography.family_import_id + ) + ] + update_geographies = set(original_geographies) != set(family.geographies) + update_basics = ( update_title - or update_geo + or update_geographies or original_family.description != family.summary or original_family.family_category != family.category ) @@ -131,7 +124,13 @@ def _update_intention( .one() ) update_metadata = existing_metadata.value != family.metadata - return update_title, update_basics, update_metadata, update_collections + return ( + update_title, + update_basics, + update_metadata, + update_collections, + update_geographies, + ) def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]: @@ -201,13 +200,9 @@ def search( term = f"%{escape_like(search_params['summary'])}%" search.append(Family.description.ilike(term)) - if "geography" in search_params.keys(): - term = cast(str, search_params["geography"]) - search.append( - or_( - Geography.display_value == term.title(), Geography.value == term.upper() - ) - ) + if "geographies" in search_params.keys(): + term = cast(str, search_params["geographies"]) + search.append(Geography.display_value == term.title()) if "status" in search_params.keys(): term = cast(str, search_params["status"]) @@ -231,7 +226,7 @@ def search( return [_family_to_dto(db, f) for f in found] -def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> bool: +def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool: """ Updates a single entry with the new values passed. @@ -257,7 +252,8 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> update_basics, update_metadata, update_collections, - ) = _update_intention(db, import_id, family, geo_id, original_family) + update_geographies, + ) = _update_intention(db, import_id, family, original_family) # Return if nothing to do if not (update_title or update_basics or update_metadata or update_collections): @@ -276,12 +272,6 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> ) ) updates = result.rowcount # type: ignore - # TODO: PDCT-1406: Properly implement multi-geography support - result = db.execute( - db_update(FamilyGeography) - .where(FamilyGeography.family_import_id == import_id) - .values(geography_id=geo_id) - ) updates += result.rowcount # type: ignore if updates == 0: # type: ignore @@ -316,6 +306,40 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> db.add(new_slug) _LOGGER.info(f"Added a new slug for {import_id} of {new_slug.name}") + # update geographies if geographies changed + if update_geographies: + original_geographies = set( + [ + geography.geography_id + for geography in db.query(FamilyGeography).filter( + original_family.import_id == FamilyGeography.family_import_id + ) + ] + ) + + # Remove any collections that were originally associated with the family but + # now aren't. + geographies_to_remove = set(original_geographies) - set(family.geographies) + for geography in geographies_to_remove: + result = db.execute( + db_delete(Geography).where(FamilyGeography.geography_id == geography) + ) + + if result.rowcount == 0: # type: ignore + msg = f"Could not remove family {import_id} from collection {geography}" + _LOGGER.error(msg) + raise RepositoryError(msg) + + # Add any collections that weren't originally associated with the family. + geographies_to_add = set(family.geographies) - set(original_geographies) + for geography in geographies_to_add: + db.flush() + new_geography = FamilyGeography( + family_import_id=import_id, + geography_id=geography, + ) + db.add(new_geography) + # Update collections if collections changed. if update_collections: original_collections = set( diff --git a/app/repository/geography.py b/app/repository/geography.py index 5ab606ba..7a8234d3 100644 --- a/app/repository/geography.py +++ b/app/repository/geography.py @@ -6,3 +6,7 @@ def get_id_from_value(db: Session, geo_string: str) -> Optional[int]: return db.query(Geography.id).filter_by(value=geo_string).scalar() + + +def get_ids_from_values(db: Session, geo_strings: list[str]) -> Optional[list[int]]: + return db.query(Geography.id).filter(Geography.value.in_(geo_strings)).scalar() diff --git a/app/repository/protocols.py b/app/repository/protocols.py index 7cfbf092..004209b1 100644 --- a/app/repository/protocols.py +++ b/app/repository/protocols.py @@ -34,9 +34,7 @@ def search( ... @staticmethod - def update( - db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int - ) -> bool: + def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool: """Updates a family""" ... diff --git a/app/service/family.py b/app/service/family.py index 0eb85cfc..137dc779 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -127,7 +127,7 @@ def update( db = db_session.get_db() # Validate geography - geo_id = geography.get_id(db, family_dto.geography) + # geo_ids = geography_repo.get_ids_from_values(db, family_dto.geographies) # Validate family belongs to same org as current user. entity_org_id: int = corpus.get_corpus_org_id(family.corpus_import_id, db) @@ -153,7 +153,7 @@ def update( raise ValidationError(msg) try: - if family_repo.update(db, import_id, family_dto, geo_id): + if family_repo.update(db, import_id, family_dto): db.commit() else: db.rollback() @@ -185,10 +185,10 @@ def create( # Validate geographies geo_ids = [] - if isinstance(family.geography, str): - geo_ids.append(geography.get_id(db, family.geography)) - elif isinstance(family.geography, list): - for geo_id in family.geography: + if isinstance(family.geographies, str): + geo_ids.append(geography.get_id(db, family.geographies)) + elif isinstance(family.geographies, list): + for geo_id in family.geographies: geo_ids.append(geography.get_id(db, geo_id)) # Validate category diff --git a/app/service/ingest.py b/app/service/ingest.py index 84288ca2..d3ce6652 100644 --- a/app/service/ingest.py +++ b/app/service/ingest.py @@ -219,7 +219,7 @@ def save_families( **fam, corpus_import_id=corpus_import_id ).to_family_create_dto(corpus_import_id) geo_ids = [] - for geo in dto.geography: + for geo in dto.geographies: geo_ids.append(geography.get_id(db, geo)) import_id = family_repository.create(db, dto, geo_ids, org_id) family_import_ids.append(import_id) diff --git a/tests/integration_tests/family/test_search.py b/tests/integration_tests/family/test_search.py index c3e7d361..6b85524e 100644 --- a/tests/integration_tests/family/test_search.py +++ b/tests/integration_tests/family/test_search.py @@ -7,6 +7,25 @@ from tests.integration_tests.setup_db import setup_db +def test_search_geographies( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/families/?geographies=zimbabwe", + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + # expected_ids = set(["A.0.0.2", "A.0.0.3"]) + # assert ids_found.symmetric_difference(expected_ids) == set([]) + + def test_search_family_super( client: TestClient, data_db: Session, superuser_header_token ): diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index ad112ee2..076231b7 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -31,7 +31,7 @@ "import_id": "A.0.0.1", "title": "apple", "summary": "", - "geography": "Other", + "geographies": ["AFG"], "category": "UNFCCC", "status": "Created", "metadata": { @@ -57,7 +57,7 @@ "import_id": "A.0.0.2", "title": "apple orange banana", "summary": "apple", - "geography": "Other", + "geographies": ["ZWE"], "category": "UNFCCC", "status": "Created", "metadata": { @@ -83,7 +83,7 @@ "import_id": "A.0.0.3", "title": "title", "summary": "orange peas", - "geography": "Other", + "geographies": ["AFG"], "category": "UNFCCC", "status": "Created", "metadata": {"author": "CPR", "author_type": "Party"}, @@ -490,7 +490,7 @@ def _setup_family_data( geo_id = ( test_db.query(Geography.id) - .filter(Geography.value == data["geography"]) + .filter(Geography.value == data["geographies"][0]) .scalar() )