Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesgorrie committed Nov 28, 2024
1 parent 5061ac3 commit 08be6bf
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 60 deletions.
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:

query_params = set_default_query_params(query_params)

VALID_PARAMS = ["q", "title", "summary", "geography", "status", "max_results"]
VALID_PARAMS = ["q", "title", "summary", "geographies", "status", "max_results"]
validate_query_params(query_params, VALID_PARAMS)

try:
Expand Down
8 changes: 4 additions & 4 deletions app/model/family.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, Union
from typing import Optional

from pydantic import BaseModel

Expand All @@ -12,7 +12,7 @@ class FamilyReadDTO(BaseModel):
import_id: str
title: str
summary: str
geography: str
geographies: list[str]
category: str
status: str
metadata: Json
Expand Down Expand Up @@ -41,7 +41,7 @@ class FamilyWriteDTO(BaseModel):

title: str
summary: str
geography: str
geographies: list[str]
category: str
metadata: Json
collections: list[str]
Expand All @@ -60,7 +60,7 @@ class FamilyCreateDTO(BaseModel):
import_id: Optional[str] = None
title: str
summary: str
geography: Union[str, list[str]]
geographies: list[str]
category: str
metadata: Json
collections: list[str]
Expand Down
106 changes: 65 additions & 41 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from db_client.models.organisation.users import Organisation
from sqlalchemy import Column, and_
from sqlalchemy import delete as db_delete
from sqlalchemy import desc, func, or_
from sqlalchemy import desc, or_
from sqlalchemy import update as db_update
from sqlalchemy.exc import NoResultFound, OperationalError
from sqlalchemy.orm import Query, Session
Expand All @@ -34,30 +34,24 @@

_LOGGER = logging.getLogger(__name__)

FamilyGeoMetaOrg = Tuple[Family, str, FamilyMetadata, Corpus, Organisation]
FamilyGeoMetaOrg = Tuple[Family, Geography, FamilyMetadata, Corpus, Organisation]


def _get_query(db: Session) -> Query:
# NOTE: SqlAlchemy will make a complete hash of query generation
# if columns are used in the query() call. Therefore, entire
# objects are returned.
geo_subquery = (
db.query(
func.min(Geography.value).label("value"),
FamilyGeography.family_import_id,
)
.join(FamilyGeography, FamilyGeography.geography_id == Geography.id)
.filter(FamilyGeography.family_import_id == Family.import_id)
.group_by(Geography.value, FamilyGeography.family_import_id)
).subquery("geo_subquery")

return (
db.query(Family, geo_subquery.c.value, FamilyMetadata, Corpus, Organisation) # type: ignore
db.query(Family, Geography, FamilyMetadata, Corpus, Organisation) # type: ignore
.join(FamilyGeography, FamilyGeography.family_import_id == Family.import_id)
.join(
Geography,
Geography.id == FamilyGeography.geography_id,
)
.join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id)
.join(FamilyCorpus, FamilyCorpus.family_import_id == Family.import_id)
.join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id)
.join(Organisation, Corpus.organisation_id == Organisation.id)
.filter(geo_subquery.c.family_import_id == Family.import_id) # type: ignore
)


Expand All @@ -72,7 +66,7 @@ def _family_to_dto(
import_id=str(fam.import_id),
title=str(fam.title),
summary=str(fam.description),
geography=geo_value,
geographies=[str(geo_value.display_value)],
category=str(fam.family_category),
status=str(fam.family_status),
metadata=metadata,
Expand Down Expand Up @@ -100,7 +94,6 @@ def _update_intention(
db: Session,
import_id: str,
family: FamilyWriteDTO,
geo_id: int,
original_family: Family,
):
original_collections = [
Expand All @@ -111,17 +104,17 @@ def _update_intention(
]
update_collections = set(original_collections) != set(family.collections)
update_title = cast(str, original_family.title) != family.title
# TODO: PDCT-1406: Properly implement multi-geography support
update_geo = (
db.query(FamilyGeography)
.filter(FamilyGeography.family_import_id == import_id)
.one()
.geography_id
!= geo_id
)
original_geographies = [
geography.collection_import_id
for geography in db.query(Geography).filter(
original_family.import_id == Geography.family_import_id
)
]
update_geographies = set(original_geographies) != set(family.geographies)

update_basics = (
update_title
or update_geo
or update_geographies
or original_family.description != family.summary
or original_family.family_category != family.category
)
Expand All @@ -131,7 +124,13 @@ def _update_intention(
.one()
)
update_metadata = existing_metadata.value != family.metadata
return update_title, update_basics, update_metadata, update_collections
return (
update_title,
update_basics,
update_metadata,
update_collections,
update_geographies,
)


def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
Expand Down Expand Up @@ -201,13 +200,9 @@ def search(
term = f"%{escape_like(search_params['summary'])}%"
search.append(Family.description.ilike(term))

if "geography" in search_params.keys():
term = cast(str, search_params["geography"])
search.append(
or_(
Geography.display_value == term.title(), Geography.value == term.upper()
)
)
if "geographies" in search_params.keys():
term = cast(str, search_params["geographies"])
search.append(Geography.display_value == term.title())

if "status" in search_params.keys():
term = cast(str, search_params["status"])
Expand All @@ -231,7 +226,7 @@ def search(
return [_family_to_dto(db, f) for f in found]


def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> bool:
def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool:
"""
Updates a single entry with the new values passed.
Expand All @@ -257,7 +252,8 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
update_basics,
update_metadata,
update_collections,
) = _update_intention(db, import_id, family, geo_id, original_family)
update_geographies,
) = _update_intention(db, import_id, family, original_family)

# Return if nothing to do
if not (update_title or update_basics or update_metadata or update_collections):
Expand All @@ -276,12 +272,6 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
)
)
updates = result.rowcount # type: ignore
# TODO: PDCT-1406: Properly implement multi-geography support
result = db.execute(
db_update(FamilyGeography)
.where(FamilyGeography.family_import_id == import_id)
.values(geography_id=geo_id)
)

updates += result.rowcount # type: ignore
if updates == 0: # type: ignore
Expand Down Expand Up @@ -316,6 +306,40 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
db.add(new_slug)
_LOGGER.info(f"Added a new slug for {import_id} of {new_slug.name}")

# update geographies if geographies changed
if update_geographies:
original_geographies = set(
[
geography.geography_id
for geography in db.query(FamilyGeography).filter(
original_family.import_id == FamilyGeography.family_import_id
)
]
)

# Remove any collections that were originally associated with the family but
# now aren't.
geographies_to_remove = set(original_geographies) - set(family.geographies)
for geography in geographies_to_remove:
result = db.execute(
db_delete(Geography).where(FamilyGeography.geography_id == geography)
)

if result.rowcount == 0: # type: ignore
msg = f"Could not remove family {import_id} from collection {geography}"
_LOGGER.error(msg)
raise RepositoryError(msg)

# Add any collections that weren't originally associated with the family.
geographies_to_add = set(family.geographies) - set(original_geographies)
for geography in geographies_to_add:
db.flush()
new_geography = FamilyGeography(
family_import_id=import_id,
geography_id=geography,
)
db.add(new_geography)

# Update collections if collections changed.
if update_collections:
original_collections = set(
Expand Down
4 changes: 4 additions & 0 deletions app/repository/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@

def get_id_from_value(db: Session, geo_string: str) -> Optional[int]:
return db.query(Geography.id).filter_by(value=geo_string).scalar()


def get_ids_from_values(db: Session, geo_strings: list[str]) -> Optional[list[int]]:
return db.query(Geography.id).filter(Geography.value.in_(geo_strings)).scalar()
4 changes: 1 addition & 3 deletions app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def search(
...

@staticmethod
def update(
db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int
) -> bool:
def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool:
"""Updates a family"""
...

Expand Down
12 changes: 6 additions & 6 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def update(
db = db_session.get_db()

# Validate geography
geo_id = geography.get_id(db, family_dto.geography)
# geo_ids = geography_repo.get_ids_from_values(db, family_dto.geographies)

# Validate family belongs to same org as current user.
entity_org_id: int = corpus.get_corpus_org_id(family.corpus_import_id, db)
Expand All @@ -153,7 +153,7 @@ def update(
raise ValidationError(msg)

try:
if family_repo.update(db, import_id, family_dto, geo_id):
if family_repo.update(db, import_id, family_dto):
db.commit()
else:
db.rollback()
Expand Down Expand Up @@ -185,10 +185,10 @@ def create(

# Validate geographies
geo_ids = []
if isinstance(family.geography, str):
geo_ids.append(geography.get_id(db, family.geography))
elif isinstance(family.geography, list):
for geo_id in family.geography:
if isinstance(family.geographies, str):
geo_ids.append(geography.get_id(db, family.geographies))
elif isinstance(family.geographies, list):
for geo_id in family.geographies:
geo_ids.append(geography.get_id(db, geo_id))

# Validate category
Expand Down
2 changes: 1 addition & 1 deletion app/service/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def save_families(
**fam, corpus_import_id=corpus_import_id
).to_family_create_dto(corpus_import_id)
geo_ids = []
for geo in dto.geography:
for geo in dto.geographies:
geo_ids.append(geography.get_id(db, geo))
import_id = family_repository.create(db, dto, geo_ids, org_id)
family_import_ids.append(import_id)
Expand Down
19 changes: 19 additions & 0 deletions tests/integration_tests/family/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@
from tests.integration_tests.setup_db import setup_db


def test_search_geographies(
client: TestClient, data_db: Session, superuser_header_token
):
setup_db(data_db)
response = client.get(
"/api/v1/families/?geographies=zimbabwe",
headers=superuser_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)

ids_found = set([f["import_id"] for f in data])
assert len(ids_found) == 1

# expected_ids = set(["A.0.0.2", "A.0.0.3"])
# assert ids_found.symmetric_difference(expected_ids) == set([])


def test_search_family_super(
client: TestClient, data_db: Session, superuser_header_token
):
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import_id": "A.0.0.1",
"title": "apple",
"summary": "",
"geography": "Other",
"geographies": ["AFG"],
"category": "UNFCCC",
"status": "Created",
"metadata": {
Expand All @@ -57,7 +57,7 @@
"import_id": "A.0.0.2",
"title": "apple orange banana",
"summary": "apple",
"geography": "Other",
"geographies": ["ZWE"],
"category": "UNFCCC",
"status": "Created",
"metadata": {
Expand All @@ -83,7 +83,7 @@
"import_id": "A.0.0.3",
"title": "title",
"summary": "orange peas",
"geography": "Other",
"geographies": ["AFG"],
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
Expand Down Expand Up @@ -490,7 +490,7 @@ def _setup_family_data(

geo_id = (
test_db.query(Geography.id)
.filter(Geography.value == data["geography"])
.filter(Geography.value == data["geographies"][0])
.scalar()
)

Expand Down

0 comments on commit 08be6bf

Please sign in to comment.