diff --git a/app/repository/config.py b/app/repository/config.py index 56a2c6ea..f7f13bf3 100644 --- a/app/repository/config.py +++ b/app/repository/config.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Optional from sqlalchemy.orm import Session from app.clients.db.models.app.users import Organisation from app.clients.db.models.document.physical_document import Language @@ -41,14 +41,16 @@ def _tree_table_to_json( return json_out -def _get_organisation_taxonomy_by_name(db: Session, org_name: str) -> TaxonomyData: +def _get_organisation_taxonomy_by_name( + db: Session, org_name: str +) -> Optional[TaxonomyData]: """ Returns the TaxonomyConfig for the named organisation :param Session db: connection to the database :return TaxonomyConfig: the TaxonomyConfig from the db """ - return ( + metadata = ( db.query(MetadataTaxonomy.valid_metadata) .join( MetadataOrganisation, @@ -56,8 +58,10 @@ def _get_organisation_taxonomy_by_name(db: Session, org_name: str) -> TaxonomyDa ) .join(Organisation, Organisation.id == MetadataOrganisation.organisation_id) .filter_by(name=org_name) - .one()[0] + .one_or_none() ) + if metadata is not None: + return metadata[0] def get(db: Session) -> ConfigReadDTO: @@ -70,10 +74,14 @@ def get(db: Session) -> ConfigReadDTO: # TODO: Return the event types too geographies = _tree_table_to_json(table=Geography, db=db) - taxonomies = { - org.name: _get_organisation_taxonomy_by_name(db=db, org_name=org.name) - for org in db.query(Organisation).all() - } + taxonomies = {} + + # Be resilient to an organisation not having a taxonomy + for org in db.query(Organisation).all(): + tax = _get_organisation_taxonomy_by_name(db=db, org_name=org.name) + if tax is not None: + taxonomies[org.name] = tax + languages = {lang.language_code: lang.name for lang in db.query(Language).all()} return ConfigReadDTO( geographies=geographies, taxonomies=taxonomies, languages=languages diff --git a/app/repository/family.py b/app/repository/family.py index d8ac1e01..0186c81e 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -321,12 +321,14 @@ def get_organisation(db: Session, family_import_id: str) -> Optional[Organisatio :param str family_import_id: The family import_id in question :return Optional[Organisation]: Any associated organisation """ - family_org = db.query(FamilyOrganisation).filter( - FamilyOrganisation.family_import_id == family_import_id - ) - # TODO - can this be improved - we get warnings on integration tests ? - return db.query(Organisation).select_from(family_org).one_or_none() + return ( + db.query(Organisation) + .join(FamilyOrganisation, FamilyOrganisation.organisation_id == Organisation.id) + .filter(FamilyOrganisation.family_import_id == family_import_id) + .group_by(Organisation.id) + .one() + ) def count(db: Session) -> Optional[int]: diff --git a/integration_tests/setup_db.py b/integration_tests/setup_db.py index acfea88d..92851afa 100644 --- a/integration_tests/setup_db.py +++ b/integration_tests/setup_db.py @@ -203,6 +203,13 @@ def _setup_organisation(test_db: Session) -> int: organisation_type="test organisation", ) test_db.add(org) + test_db.add( + Organisation( + name="Another org", + description="because we will have more than one org", + organisation_type="test", + ) + ) test_db.flush() return cast(int, org.id)