From daa458c496c709c5b2cc899bcd5262591f25a386 Mon Sep 17 00:00:00 2001 From: Peter Hooper Date: Thu, 19 Oct 2023 10:07:52 +0100 Subject: [PATCH 1/4] fix query and update tests to regression for it (#25) * fix query and update tests to regression for it * remove comment --- app/repository/config.py | 24 ++++++++++++++++-------- app/repository/family.py | 12 +++++++----- integration_tests/setup_db.py | 7 +++++++ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/app/repository/config.py b/app/repository/config.py index 56a2c6ea..f7f13bf3 100644 --- a/app/repository/config.py +++ b/app/repository/config.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Optional from sqlalchemy.orm import Session from app.clients.db.models.app.users import Organisation from app.clients.db.models.document.physical_document import Language @@ -41,14 +41,16 @@ def _tree_table_to_json( return json_out -def _get_organisation_taxonomy_by_name(db: Session, org_name: str) -> TaxonomyData: +def _get_organisation_taxonomy_by_name( + db: Session, org_name: str +) -> Optional[TaxonomyData]: """ Returns the TaxonomyConfig for the named organisation :param Session db: connection to the database :return TaxonomyConfig: the TaxonomyConfig from the db """ - return ( + metadata = ( db.query(MetadataTaxonomy.valid_metadata) .join( MetadataOrganisation, @@ -56,8 +58,10 @@ def _get_organisation_taxonomy_by_name(db: Session, org_name: str) -> TaxonomyDa ) .join(Organisation, Organisation.id == MetadataOrganisation.organisation_id) .filter_by(name=org_name) - .one()[0] + .one_or_none() ) + if metadata is not None: + return metadata[0] def get(db: Session) -> ConfigReadDTO: @@ -70,10 +74,14 @@ def get(db: Session) -> ConfigReadDTO: # TODO: Return the event types too geographies = _tree_table_to_json(table=Geography, db=db) - taxonomies = { - org.name: _get_organisation_taxonomy_by_name(db=db, org_name=org.name) - for org in db.query(Organisation).all() - } + taxonomies = {} + + # Be resilient to an organisation not having a taxonomy + for org in db.query(Organisation).all(): + tax = _get_organisation_taxonomy_by_name(db=db, org_name=org.name) + if tax is not None: + taxonomies[org.name] = tax + languages = {lang.language_code: lang.name for lang in db.query(Language).all()} return ConfigReadDTO( geographies=geographies, taxonomies=taxonomies, languages=languages diff --git a/app/repository/family.py b/app/repository/family.py index d8ac1e01..0186c81e 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -321,12 +321,14 @@ def get_organisation(db: Session, family_import_id: str) -> Optional[Organisatio :param str family_import_id: The family import_id in question :return Optional[Organisation]: Any associated organisation """ - family_org = db.query(FamilyOrganisation).filter( - FamilyOrganisation.family_import_id == family_import_id - ) - # TODO - can this be improved - we get warnings on integration tests ? - return db.query(Organisation).select_from(family_org).one_or_none() + return ( + db.query(Organisation) + .join(FamilyOrganisation, FamilyOrganisation.organisation_id == Organisation.id) + .filter(FamilyOrganisation.family_import_id == family_import_id) + .group_by(Organisation.id) + .one() + ) def count(db: Session) -> Optional[int]: diff --git a/integration_tests/setup_db.py b/integration_tests/setup_db.py index acfea88d..92851afa 100644 --- a/integration_tests/setup_db.py +++ b/integration_tests/setup_db.py @@ -203,6 +203,13 @@ def _setup_organisation(test_db: Session) -> int: organisation_type="test organisation", ) test_db.add(org) + test_db.add( + Organisation( + name="Another org", + description="because we will have more than one org", + organisation_type="test", + ) + ) test_db.flush() return cast(int, org.id) From 8fbfe9efd90fb5fb86458811dca0b37d0f543349 Mon Sep 17 00:00:00 2001 From: Peter Hooper Date: Thu, 19 Oct 2023 11:22:34 +0100 Subject: [PATCH 2/4] Extend config endpoint with document and event settings (#26) * Extend config endpoint with document and event settings * update unit tests --- app/model/config.py | 16 ++++++++ app/repository/config.py | 45 +++++++++++++++++++-- integration_tests/test_config.py | 15 +++++++ unit_tests/mocks/services/config_service.py | 10 ++++- unit_tests/routers/test_config.py | 2 + 5 files changed, 83 insertions(+), 5 deletions(-) diff --git a/app/model/config.py b/app/model/config.py index 07227e40..68543043 100644 --- a/app/model/config.py +++ b/app/model/config.py @@ -6,9 +6,25 @@ TaxonomyData = Mapping[str, Mapping[str, Union[bool, str, Sequence[str]]]] +class EventConfig(BaseModel): + """Everything you need to know about events.""" + + types: Sequence[str] + + +class DocumentConfig(BaseModel): + """Everything you need to know about documents.""" + + roles: Sequence[str] + types: Sequence[str] + variants: Sequence[str] + + class ConfigReadDTO(BaseModel): """Definition of the new Config which just includes taxonomy.""" geographies: Sequence[dict] taxonomies: Mapping[str, TaxonomyData] languages: Mapping[str, str] + document: DocumentConfig + event: EventConfig diff --git a/app/repository/config.py b/app/repository/config.py index f7f13bf3..b5be4902 100644 --- a/app/repository/config.py +++ b/app/repository/config.py @@ -3,13 +3,19 @@ from sqlalchemy.orm import Session from app.clients.db.models.app.users import Organisation from app.clients.db.models.document.physical_document import Language +from app.clients.db.models.law_policy.family import ( + FamilyDocumentRole, + FamilyDocumentType, + FamilyEventType, + Variant, +) from app.clients.db.models.law_policy.geography import Geography from app.clients.db.models.law_policy.metadata import ( MetadataOrganisation, MetadataTaxonomy, ) from app.clients.db.session import AnyModel -from app.model.config import ConfigReadDTO, TaxonomyData +from app.model.config import ConfigReadDTO, DocumentConfig, EventConfig, TaxonomyData _LOGGER = logging.getLogger(__name__) @@ -72,7 +78,6 @@ def get(db: Session) -> ConfigReadDTO: :return ConfigReadDTO: The config data """ - # TODO: Return the event types too geographies = _tree_table_to_json(table=Geography, db=db) taxonomies = {} @@ -83,6 +88,40 @@ def get(db: Session) -> ConfigReadDTO: taxonomies[org.name] = tax languages = {lang.language_code: lang.name for lang in db.query(Language).all()} + + # Now Document config + doc_config = DocumentConfig( + roles=[ + doc_role.name + for doc_role in db.query(FamilyDocumentRole) + .order_by(FamilyDocumentRole.name) + .all() + ], + types=[ + doc_type.name + for doc_type in db.query(FamilyDocumentType) + .order_by(FamilyDocumentType.name) + .all() + ], + variants=[ + variant.variant_name + for variant in db.query(Variant).order_by(Variant.variant_name).all() + ], + ) + + # Now Event config + event_config = EventConfig( + types=[ + event_type.name + for event_type in db.query(FamilyEventType) + .order_by(FamilyEventType.name) + .all() + ] + ) return ConfigReadDTO( - geographies=geographies, taxonomies=taxonomies, languages=languages + geographies=geographies, + taxonomies=taxonomies, + languages=languages, + document=doc_config, + event=event_config, ) diff --git a/integration_tests/test_config.py b/integration_tests/test_config.py index 085d5b2d..8c97a3b3 100644 --- a/integration_tests/test_config.py +++ b/integration_tests/test_config.py @@ -17,3 +17,18 @@ def test_get_config(client: TestClient, test_db: Session, user_header_token): assert "geographies" in keys assert "taxonomies" in keys assert "languages" in keys + assert "document" in keys + assert "event" in keys + + # Now sanity check the data + assert data["geographies"][1]["node"]["slug"] == "europe-central-asia" + + assert "CCLW" in data["taxonomies"].keys() + + assert "aaa" in data["languages"].keys() + + assert "AMENDMENT" in data["document"]["roles"] + assert "Action Plan" in data["document"]["types"] + assert "Translation" in data["document"]["variants"] + + assert "Appealed" in data["event"]["types"] diff --git a/unit_tests/mocks/services/config_service.py b/unit_tests/mocks/services/config_service.py index bba2dcf2..9b3a3589 100644 --- a/unit_tests/mocks/services/config_service.py +++ b/unit_tests/mocks/services/config_service.py @@ -1,7 +1,7 @@ from pytest import MonkeyPatch from app.errors import RepositoryError -from app.model.config import ConfigReadDTO +from app.model.config import ConfigReadDTO, DocumentConfig, EventConfig def mock_config_service(config_service, monkeypatch: MonkeyPatch, mocker): @@ -13,7 +13,13 @@ def maybe_throw(): def mock_get_config() -> ConfigReadDTO: maybe_throw() - return ConfigReadDTO(geographies=[], taxonomies={}, languages={}) + return ConfigReadDTO( + geographies=[], + taxonomies={}, + languages={}, + document=DocumentConfig(roles=[], types=[], variants=[]), + event=EventConfig(types=[]), + ) monkeypatch.setattr(config_service, "get", mock_get_config) mocker.spy(config_service, "get") diff --git a/unit_tests/routers/test_config.py b/unit_tests/routers/test_config.py index a5209e00..6fe596ef 100644 --- a/unit_tests/routers/test_config.py +++ b/unit_tests/routers/test_config.py @@ -10,6 +10,8 @@ def test_get_when_ok(client: TestClient, user_header_token, config_service_mock) assert "geographies" in keys assert "taxonomies" in keys assert "languages" in keys + assert "event" in keys + assert "document" in keys assert config_service_mock.get.call_count == 1 From 5ba91c66651f8f6017753a288240fb2fa79fbed4 Mon Sep 17 00:00:00 2001 From: Katy Baulch <46493669+katybaulch@users.noreply.github.com> Date: Thu, 19 Oct 2023 18:03:53 +0100 Subject: [PATCH 3/4] PDCT-311 Added endpoint and tests for updating family events. (#27) --- app/api/api_v1/routers/event.py | 35 +++- app/model/event.py | 20 ++- app/repository/event.py | 44 ++++- app/service/event.py | 33 +++- integration_tests/event/test_update.py | 152 ++++++++++++++++++ integration_tests/mocks/bad_event_repo.py | 6 + .../mocks/rollback_event_repo.py | 14 +- unit_tests/helpers/event.py | 14 +- unit_tests/mocks/repos/event_repo.py | 11 +- unit_tests/mocks/services/event_service.py | 34 ++-- unit_tests/routers/test_event.py | 31 +++- unit_tests/service/test_event_service.py | 51 ++++++ 12 files changed, 415 insertions(+), 30 deletions(-) create mode 100644 integration_tests/event/test_update.py diff --git a/app/api/api_v1/routers/event.py b/app/api/api_v1/routers/event.py index 77fd2937..9e633004 100644 --- a/app/api/api_v1/routers/event.py +++ b/app/api/api_v1/routers/event.py @@ -5,7 +5,7 @@ import app.service.event as event_service from app.errors import RepositoryError, ValidationError -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO event_router = r = APIRouter() @@ -92,7 +92,7 @@ async def get_event(import_id: str) -> EventReadDTO: @r.post("/events", response_model=str, status_code=status.HTTP_201_CREATED) -async def create_document( +async def create_event( new_event: EventCreateDTO, ) -> str: """ @@ -116,3 +116,34 @@ async def create_document( raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message ) + + +@r.put( + "/events/{import_id}", + response_model=EventReadDTO, +) +async def update_event( + import_id: str, + new_event: EventWriteDTO, +) -> EventReadDTO: + """ + Updates a specific event given the import id. + + :param str import_id: Specified import_id. + :raises HTTPException: If the event is not found a 404 is returned. + :return EventDTO: returns a EventDTO of the event updated. + """ + try: + event = event_service.update(import_id, new_event) + except ValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) + except RepositoryError as e: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message + ) + + if event is None: + detail = f"Event not updated: {import_id}" + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) + + return event diff --git a/app/model/event.py b/app/model/event.py index dad1ea60..e15823c2 100644 --- a/app/model/event.py +++ b/app/model/event.py @@ -2,9 +2,7 @@ from typing import Optional from pydantic import BaseModel -from app.clients.db.models.law_policy.family import ( - EventStatus, -) +from app.clients.db.models.law_policy.family import EventStatus class EventReadDTO(BaseModel): @@ -44,3 +42,19 @@ class EventCreateDTO(BaseModel): # From FamilyDocument family_import_id: str family_document_import_id: Optional[str] = None + + +class EventWriteDTO(BaseModel): + """ + JSON Representation of the DTO for writing an event. + + The following fields are immutable: + - family_import_id + - import_id + - family_document_import_id + - event_status + """ + + event_title: str + date: datetime + event_type_value: str diff --git a/app/repository/event.py b/app/repository/event.py index daa2a5bf..cb490e97 100644 --- a/app/repository/event.py +++ b/app/repository/event.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Optional, Tuple, cast -from sqlalchemy import or_, Column +from sqlalchemy import or_, Column, update as db_update from sqlalchemy.orm import Query, Session from sqlalchemy.exc import NoResultFound from sqlalchemy_utils import escape_like @@ -16,8 +16,8 @@ Family, FamilyDocument, ) -from app.errors import ValidationError -from app.model.event import EventCreateDTO, EventReadDTO +from app.errors import ValidationError, RepositoryError +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from app.repository import family as family_repo from app.repository.helpers import generate_import_id @@ -108,7 +108,6 @@ def get(db: Session, import_id: str) -> Optional[EventReadDTO]: except NoResultFound as e: _LOGGER.error(e) return - return _event_to_dto(family_event_meta) @@ -172,6 +171,43 @@ def create(db: Session, event: EventCreateDTO) -> str: return cast(str, new_family_event.import_id) +def update(db: Session, import_id: str, event: EventWriteDTO) -> bool: + """ + Updates a single entry with the new values passed. + + :param db Session: the database connection + :param str import_id: The event import id to change. + :param DocumentDTO event: The new values + :return bool: True if new values were set otherwise false. + """ + new_values = event.model_dump() + + original_fe = ( + db.query(FamilyEvent).filter(FamilyEvent.import_id == import_id).one_or_none() + ) + + if original_fe is None: # Not found the event to update + _LOGGER.error(f"Unable to find event for update {import_id}") + return False + + result = db.execute( + db_update(FamilyEvent) + .where(FamilyEvent.import_id == original_fe.import_id) + .values( + title=new_values["event_title"], + event_type_name=new_values["event_type_value"], + date=new_values["date"], + ) + ) + + if result.rowcount == 0: # type: ignore + msg = f"Could not update event fields: {event}" + _LOGGER.error(msg) + raise RepositoryError(msg) + + return True + + def count(db: Session) -> Optional[int]: """ Counts the number of family events in the repository. diff --git a/app/service/event.py b/app/service/event.py index f5414073..c4922405 100644 --- a/app/service/event.py +++ b/app/service/event.py @@ -9,7 +9,7 @@ import app.repository.event as event_repo import app.service.family as family_service from app.errors import RepositoryError, ValidationError -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from app.service import id @@ -79,12 +79,12 @@ def validate_import_id(import_id: str) -> None: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def create(event: EventCreateDTO, db: Session = db_session.get_db()) -> str: """ - Creates a new document with the values passed. + Creates a new event with the values passed. - :param documentDTO document: The values for the new document. + :param eventDTO event: The values for the new event. :raises RepositoryError: raised on a database error :raises ValidationError: raised should the import_id be invalid. - :return Optional[documentDTO]: The new created document or + :return Optional[eventDTO]: The new created event or None if unsuccessful. """ id.validate(event.family_import_id) @@ -98,6 +98,31 @@ def create(event: EventCreateDTO, db: Session = db_session.get_db()) -> str: return event_repo.create(db, event) +@db_session.with_transaction(__name__) +@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) +def update( + import_id: str, event: EventWriteDTO, db: Session = db_session.get_db() +) -> Optional[EventReadDTO]: + """ + Updates a single event with the values passed. + + :param EventWriteDTO event: The DTO with all the values to change (or keep). + :raises RepositoryError: raised on a database error. + :raises ValidationError: raised should the import_id be invalid. + :return Optional[EventReadDTO]: The updated event or None if not updated. + """ + validate_import_id(import_id) + + try: + if event_repo.update(db, import_id, event): + db.commit() + return get(import_id) + + except exc.SQLAlchemyError: + _LOGGER.exception(f"While updating event {import_id}") + raise RepositoryError(f"Error when updating event {import_id}") + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def count() -> Optional[int]: """ diff --git a/integration_tests/event/test_update.py b/integration_tests/event/test_update.py new file mode 100644 index 00000000..3b15031d --- /dev/null +++ b/integration_tests/event/test_update.py @@ -0,0 +1,152 @@ +from datetime import datetime, timezone +from fastapi.encoders import jsonable_encoder +from fastapi.testclient import TestClient +from fastapi import status +from sqlalchemy.orm import Session +from app.clients.db.models.law_policy.family import EventStatus + +from app.clients.db.models.law_policy import FamilyEvent + +from integration_tests.setup_db import EXPECTED_EVENTS, setup_db +from unit_tests.helpers.event import create_event_write_dto + + +def _get_event_tuple(test_db: Session, import_id: str) -> FamilyEvent: + fe: FamilyEvent = ( + test_db.query(FamilyEvent).filter(FamilyEvent.import_id == import_id).one() + ) + assert fe is not None + return fe + + +def test_update_event(client: TestClient, test_db: Session, user_header_token): + setup_db(test_db) + new_event = create_event_write_dto(title="Updated Title") + response = client.put( + "/api/v1/events/E.0.0.2", + json=jsonable_encoder(new_event), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + + # Check the properties and values of the EventReadDTO object we return to the + # client. + assert data["event_type_value"] == "Amended" + assert data["event_title"] == "Updated Title" + assert isinstance(data["date"], str) is True + assert data["date"] == "2023-01-01T00:00:00Z" + + # Get the record in the FamilyEvent table we want to update in the database and + # check the types of the values are correct and that the values have been + # successfully updated. + fe = _get_event_tuple(test_db, "E.0.0.2") + assert isinstance(fe.date, datetime) is True + assert isinstance(fe.status, EventStatus) is True + assert ( + isinstance(fe.family_document_import_id, str) is True + or fe.family_document_import_id is None + ) + assert ( + all( + isinstance(x, str) + for x in [ + fe.import_id, + fe.family_import_id, + fe.event_type_name, + fe.title, + ] + ) + is True + ) + assert fe.import_id == "E.0.0.2" + assert fe.event_type_name == "Amended" + assert fe.date == datetime(2023, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc) + assert fe.title == "Updated Title" + assert fe.family_import_id == "A.0.0.1" + assert fe.family_document_import_id is None + assert fe.status == EventStatus.OK + + +def test_update_event_when_not_authorised(client: TestClient, test_db: Session): + setup_db(test_db) + new_event = create_event_write_dto( + title="Updated Title", + ) + response = client.put("/api/v1/events/E.0.0.2", json=jsonable_encoder(new_event)) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_update_event_idempotent( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + event = EXPECTED_EVENTS[1] + response = client.put( + f"/api/v1/events/{event['import_id']}", + json=event, + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert data["event_title"] == EXPECTED_EVENTS[1]["event_title"] + + fe = _get_event_tuple(test_db, EXPECTED_EVENTS[1]["import_id"]) + assert fe.title == EXPECTED_EVENTS[1]["event_title"] + + +def test_update_event_rollback( + client: TestClient, test_db: Session, rollback_event_repo, user_header_token +): + setup_db(test_db) + new_event = create_event_write_dto( + title="Updated Title", + ) + response = client.put( + "/api/v1/events/E.0.0.2", + json=jsonable_encoder(new_event), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + + pd = _get_event_tuple(test_db, "E.0.0.2") + assert pd.title != "Updated Title" + + assert rollback_event_repo.update.call_count == 1 + + +def test_update_event_when_not_found( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + new_event = create_event_write_dto( + title="Updated Title", + ) + response = client.put( + "/api/v1/events/E.0.0.22", + json=jsonable_encoder(new_event), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert data["detail"] == "Event not updated: E.0.0.22" + + +def test_update_event_when_db_error( + client: TestClient, test_db: Session, bad_event_repo, user_header_token +): + setup_db(test_db) + + new_event = create_event_write_dto( + title="Updated Title", + ) + response = client.put( + "/api/v1/events/E.0.0.2", + json=jsonable_encoder(new_event), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "Bad Repo" + assert bad_event_repo.update.call_count == 1 diff --git a/integration_tests/mocks/bad_event_repo.py b/integration_tests/mocks/bad_event_repo.py index c6188250..d7452226 100644 --- a/integration_tests/mocks/bad_event_repo.py +++ b/integration_tests/mocks/bad_event_repo.py @@ -18,6 +18,9 @@ def mock_search(_, q: str) -> list[EventReadDTO]: def mock_create(_, data: EventCreateDTO) -> Optional[EventReadDTO]: raise RepositoryError("Bad Repo") + def mock_update(_, import_id, data: EventReadDTO) -> Optional[EventReadDTO]: + raise RepositoryError("Bad Repo") + def mock_get_count(_) -> Optional[int]: raise RepositoryError("Bad Repo") @@ -33,6 +36,9 @@ def mock_get_count(_) -> Optional[int]: monkeypatch.setattr(repo, "create", mock_create) mocker.spy(repo, "create") + monkeypatch.setattr(repo, "update", mock_update) + mocker.spy(repo, "update") + monkeypatch.setattr(repo, "count", mock_get_count) mocker.spy(repo, "count") diff --git a/integration_tests/mocks/rollback_event_repo.py b/integration_tests/mocks/rollback_event_repo.py index 4041995b..9a14dd01 100644 --- a/integration_tests/mocks/rollback_event_repo.py +++ b/integration_tests/mocks/rollback_event_repo.py @@ -3,15 +3,23 @@ from sqlalchemy.exc import NoResultFound -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO def mock_rollback_event_repo(event_repo, monkeypatch: MonkeyPatch, mocker): actual_create = event_repo.create + actual_update = event_repo.update - def mock_create_document(db, data: EventCreateDTO) -> Optional[EventReadDTO]: + def mock_create_event(db, data: EventCreateDTO) -> Optional[EventReadDTO]: actual_create(db, data) raise NoResultFound() - monkeypatch.setattr(event_repo, "create", mock_create_document) + def mock_update_event(db, import_id: str, data: EventWriteDTO) -> EventReadDTO: + actual_update(db, import_id, data) + raise NoResultFound() + + monkeypatch.setattr(event_repo, "create", mock_create_event) mocker.spy(event_repo, "create") + + monkeypatch.setattr(event_repo, "update", mock_update_event) + mocker.spy(event_repo, "update") diff --git a/unit_tests/helpers/event.py b/unit_tests/helpers/event.py index e0598788..aebe1551 100644 --- a/unit_tests/helpers/event.py +++ b/unit_tests/helpers/event.py @@ -1,5 +1,5 @@ from app.clients.db.models.law_policy.family import EventStatus -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from datetime import datetime, timezone @@ -29,3 +29,15 @@ def create_event_create_dto( family_document_import_id=None, event_status=EventStatus.OK, ) + + +def create_event_write_dto( + title: str = "title", + event_type_value: str = "Amended", + date: datetime = datetime(2023, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc), +) -> EventWriteDTO: + return EventWriteDTO( + event_title=title, + date=date, + event_type_value=event_type_value, + ) diff --git a/unit_tests/mocks/repos/event_repo.py b/unit_tests/mocks/repos/event_repo.py index e778ed03..6bfa57bd 100644 --- a/unit_tests/mocks/repos/event_repo.py +++ b/unit_tests/mocks/repos/event_repo.py @@ -2,7 +2,7 @@ from pytest import MonkeyPatch from sqlalchemy import exc -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from unit_tests.helpers.event import create_event_read_dto @@ -37,6 +37,12 @@ def mock_create(_, data: EventCreateDTO) -> str: raise exc.NoResultFound() return "test.new.event.0" + def mock_update(_, import_id: str, data: EventWriteDTO) -> EventReadDTO: + maybe_throw() + if event_repo.return_empty: + raise exc.NoResultFound() + return create_event_read_dto("a.b.c.d") + def mock_get_count(_) -> Optional[int]: maybe_throw() if not event_repo.return_empty: @@ -55,5 +61,8 @@ def mock_get_count(_) -> Optional[int]: monkeypatch.setattr(event_repo, "create", mock_create) mocker.spy(event_repo, "create") + monkeypatch.setattr(event_repo, "update", mock_update) + mocker.spy(event_repo, "update") + monkeypatch.setattr(event_repo, "count", mock_get_count) mocker.spy(event_repo, "count") diff --git a/unit_tests/mocks/services/event_service.py b/unit_tests/mocks/services/event_service.py index ffaad008..127018fd 100644 --- a/unit_tests/mocks/services/event_service.py +++ b/unit_tests/mocks/services/event_service.py @@ -2,7 +2,7 @@ from pytest import MonkeyPatch from app.errors import RepositoryError, ValidationError -from app.model.event import EventCreateDTO, EventReadDTO +from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from unit_tests.helpers.event import create_event_read_dto @@ -14,46 +14,58 @@ def maybe_throw(): if event_service.throw_repository_error: raise RepositoryError("bad repo") - def mock_get_all_documents() -> list[EventReadDTO]: + def mock_get_all_events() -> list[EventReadDTO]: maybe_throw() return [create_event_read_dto("test")] - def mock_get_document(import_id: str) -> Optional[EventReadDTO]: + def mock_get_event(import_id: str) -> Optional[EventReadDTO]: maybe_throw() if not event_service.missing: return create_event_read_dto(import_id) - def mock_search_documents(q: str) -> list[EventReadDTO]: + def mock_search_events(q: str) -> list[EventReadDTO]: maybe_throw() if event_service.missing: return [] else: return [create_event_read_dto("search1")] - def mock_create_document(data: EventCreateDTO) -> str: + def mock_create_event(data: EventCreateDTO) -> str: maybe_throw() if not event_service.missing: return "new.event.id.0" raise ValidationError(f"Could not find family for {data.family_import_id}") - def mock_count_collection() -> Optional[int]: + def mock_update_event( + import_id: str, data: EventWriteDTO + ) -> Optional[EventReadDTO]: + maybe_throw() + if not event_service.missing: + return create_event_read_dto( + import_id, "family_import_id", data.event_title + ) + + def mock_count_event() -> Optional[int]: maybe_throw() if event_service.missing: return None return 5 - monkeypatch.setattr(event_service, "get", mock_get_document) + monkeypatch.setattr(event_service, "get", mock_get_event) mocker.spy(event_service, "get") - monkeypatch.setattr(event_service, "all", mock_get_all_documents) + monkeypatch.setattr(event_service, "all", mock_get_all_events) mocker.spy(event_service, "all") - monkeypatch.setattr(event_service, "search", mock_search_documents) + monkeypatch.setattr(event_service, "search", mock_search_events) mocker.spy(event_service, "search") - monkeypatch.setattr(event_service, "create", mock_create_document) + monkeypatch.setattr(event_service, "create", mock_create_event) mocker.spy(event_service, "create") - monkeypatch.setattr(event_service, "count", mock_count_collection) + monkeypatch.setattr(event_service, "update", mock_update_event) + mocker.spy(event_service, "update") + + monkeypatch.setattr(event_service, "count", mock_count_event) mocker.spy(event_service, "count") diff --git a/unit_tests/routers/test_event.py b/unit_tests/routers/test_event.py index 4bd664a8..1884a2d4 100644 --- a/unit_tests/routers/test_event.py +++ b/unit_tests/routers/test_event.py @@ -5,7 +5,7 @@ from unit_tests.helpers.event import ( create_event_create_dto, - # create_event_write_dto, + create_event_write_dto, ) @@ -80,3 +80,32 @@ def test_create_when_family_not_found( data = response.json() assert data["detail"] == "Could not find family for this_family" assert event_service_mock.create.call_count == 1 + + +def test_update_when_ok(client: TestClient, event_service_mock, user_header_token): + new_data = create_event_write_dto("event1") + response = client.put( + "/api/v1/events/E.0.0.1", + json=jsonable_encoder(new_data), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["import_id"] == "E.0.0.1" + assert event_service_mock.update.call_count == 1 + + +def test_update_when_not_found( + client: TestClient, event_service_mock, user_header_token +): + event_service_mock.missing = True + new_data = create_event_write_dto("event1") + response = client.put( + "/api/v1/events/a.b.c.d", + json=jsonable_encoder(new_data), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert data["detail"] == "Event not updated: a.b.c.d" + assert event_service_mock.update.call_count == 1 diff --git a/unit_tests/service/test_event_service.py b/unit_tests/service/test_event_service.py index cbde552c..32f1c770 100644 --- a/unit_tests/service/test_event_service.py +++ b/unit_tests/service/test_event_service.py @@ -2,6 +2,16 @@ import app.service.event as event_service from app.errors import RepositoryError, ValidationError from unit_tests.helpers.event import create_event_create_dto +from app.model.event import EventReadDTO, EventWriteDTO + + +def _to_write_dto(dto: EventReadDTO) -> EventWriteDTO: + return EventWriteDTO( + event_title=dto.event_title, + date=dto.date, + event_type_value=dto.event_type_value, + ) + # --- GET @@ -76,6 +86,47 @@ def test_create_raises_when_invalid_family_id(event_repo_mock): assert event_repo_mock.create.call_count == 0 +# --- UPDATE + + +def test_update( + event_repo_mock, +): + event = event_service.get("a.b.c.d") + assert event is not None + + result = event_service.update(event.import_id, _to_write_dto(event)) + assert result is not None + assert event_repo_mock.update.call_count == 1 + + +def test_update_when_missing( + event_repo_mock, +): + event = event_service.get("a.b.c.d") + assert event is not None + event_repo_mock.return_empty = True + + with pytest.raises(RepositoryError) as e: + event_service.update(event.import_id, _to_write_dto(event)) + assert e.value.message == "Error when updating event a.b.c.d" + assert event_repo_mock.update.call_count == 1 + + +def test_update_raises_when_invalid_id( + event_repo_mock, +): + event = event_service.get("a.b.c.d") + assert event is not None # needed to placate pyright + event.import_id = "invalid" + + with pytest.raises(ValidationError) as e: + event_service.update(event.import_id, _to_write_dto(event)) + expected_msg = f"The import id {event.import_id} is invalid!" + assert e.value.message == expected_msg + assert event_repo_mock.update.call_count == 0 + + # --- COUNT From b1ef7cfef7d1f1513e68e9dfb4eea16e309040f7 Mon Sep 17 00:00:00 2001 From: diversemix Date: Fri, 20 Oct 2023 10:44:02 +0100 Subject: [PATCH 4/4] fix mock issue with config --- unit_tests/mocks/repos/config_repo.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unit_tests/mocks/repos/config_repo.py b/unit_tests/mocks/repos/config_repo.py index d7226764..f59a51ee 100644 --- a/unit_tests/mocks/repos/config_repo.py +++ b/unit_tests/mocks/repos/config_repo.py @@ -3,7 +3,7 @@ from sqlalchemy import exc -from app.model.config import ConfigReadDTO +from app.model.config import ConfigReadDTO, DocumentConfig, EventConfig def mock_config_repo(config_repo, monkeypatch: MonkeyPatch, mocker): @@ -16,7 +16,13 @@ def maybe_throw(): def mock_get(_) -> Optional[ConfigReadDTO]: maybe_throw() - return ConfigReadDTO(geographies=[], taxonomies={}, languages={}) + return ConfigReadDTO( + geographies=[], + taxonomies={}, + languages={}, + document=DocumentConfig(roles=[], types=[], variants=[]), + event=EventConfig(types=[]), + ) monkeypatch.setattr(config_repo, "get", mock_get) mocker.spy(config_repo, "get")