Skip to content

Commit

Permalink
refactor: update subquery to aggregate geography values
Browse files Browse the repository at this point in the history
- remove commented-out code
  • Loading branch information
Osneil Drakes authored and Osneil Drakes committed Dec 23, 2024
1 parent edc17a7 commit 09090b4
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,23 @@

_LOGGER = logging.getLogger(__name__)

FamilyGeoMetaOrg = Tuple[Family, Geography, FamilyMetadata, Corpus, Organisation]
FamilyGeoMetaOrg = Tuple[
Family,
list[str],
FamilyMetadata,
Corpus,
Organisation,
]


def _get_query(db: Session) -> Query:
# NOTE: SQLAlchemy will make a complete hash of query generation
# if columns are used in the query() call. Therefore, entire
# objects are returned.
# return (
# db.query(Family, Geography, FamilyMetadata, Corpus, Organisation)
# .join(FamilyGeography, FamilyGeography.family_import_id == Family.import_id)
# .join(
# Geography,
# Geography.id == FamilyGeography.geography_id,
# )
# .join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id)
# .join(FamilyCorpus, FamilyCorpus.family_import_id == Family.import_id)
# .join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id)
# .join(Organisation, Corpus.organisation_id == Organisation.id)
# .distinct(Family.import_id)
# )

geography_subquery = (
db.query(
FamilyGeography.family_import_id,
func.array_agg(Geography.id).label("geography_ids"),
func.array_agg(Geography.value).label("geography_values"),
)
.join(Geography, Geography.id == FamilyGeography.geography_id)
Expand All @@ -69,7 +61,6 @@ def _get_query(db: Session) -> Query:
query = (
db.query(
Family,
geography_subquery.c.geography_ids,
geography_subquery.c.geography_values,
FamilyMetadata,
Corpus,
Expand All @@ -86,7 +77,6 @@ def _get_query(db: Session) -> Query:
.group_by(
Family.import_id,
Family.title,
geography_subquery.c.geography_ids, # Include aggregated fields
geography_subquery.c.geography_values,
FamilyMetadata.family_import_id,
Corpus.import_id,
Expand All @@ -106,26 +96,23 @@ def _get_query(db: Session) -> Query:
def _family_to_dto(
db: Session, fam_geo_meta_corp_org: FamilyGeoMetaOrg
) -> FamilyReadDTO:
fam, geo_value, meta, corpus, org = fam_geo_meta_corp_org
(
fam,
geo_values,
meta,
corpus,
org,
) = fam_geo_meta_corp_org

metadata = cast(dict, meta.value)
org = cast(str, org.name)

geographies = [
str(g.value)
for g in db.query(Geography)
.join(FamilyGeography, FamilyGeography.geography_id == Geography.id)
.filter(FamilyGeography.family_import_id == fam.import_id)
.all()
]

# If you want to ensure only one geography is returned, you can handle it here

return FamilyReadDTO(
import_id=str(fam.import_id),
title=str(fam.title),
summary=str(fam.description),
geography=str(geo_value.value),
geographies=geographies,
geography=str(geo_values[0]),
geographies=[str(value) for value in geo_values],
category=str(fam.family_category),
status=str(fam.family_status),
metadata=metadata,
Expand Down

0 comments on commit 09090b4

Please sign in to comment.