Skip to content

Commit

Permalink
PDCT-311 Added endpoint and tests for updating family events.
Browse files Browse the repository at this point in the history
  • Loading branch information
katybaulch committed Oct 19, 2023
1 parent 740548d commit 22eb526
Show file tree
Hide file tree
Showing 12 changed files with 426 additions and 30 deletions.
36 changes: 34 additions & 2 deletions app/api/api_v1/routers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import app.service.event as event_service
from app.errors import RepositoryError, ValidationError
from app.model.event import EventCreateDTO, EventReadDTO
from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO

event_router = r = APIRouter()

Expand Down Expand Up @@ -92,7 +92,7 @@ async def get_event(import_id: str) -> EventReadDTO:


@r.post("/events", response_model=str, status_code=status.HTTP_201_CREATED)
async def create_document(
async def create_event(
new_event: EventCreateDTO,
) -> str:
"""
Expand All @@ -116,3 +116,35 @@ async def create_document(
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message
)


@r.put(
"/events/{import_id}",
response_model=EventReadDTO,
)
async def update_event(
import_id: str,
new_event: EventWriteDTO,
) -> EventReadDTO:
"""
Updates a specific event given the import id.
:param str import_id: Specified import_id.
:raises HTTPException: If the event is not found a 404 is returned.
:return EventDTO: returns a EventDTO of the event updated.
"""
_LOGGER.warning("hit update event endpoint")
try:
event = event_service.update(import_id, new_event)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message
)

if event is None:
detail = f"Event not updated: {import_id}"
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail)

return event
20 changes: 17 additions & 3 deletions app/model/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from typing import Optional
from pydantic import BaseModel

from app.clients.db.models.law_policy.family import (
EventStatus,
)
from app.clients.db.models.law_policy.family import EventStatus


class EventReadDTO(BaseModel):
Expand Down Expand Up @@ -44,3 +42,19 @@ class EventCreateDTO(BaseModel):
# From FamilyDocument
family_import_id: str
family_document_import_id: Optional[str] = None


class EventWriteDTO(BaseModel):
"""
JSON Representation of the DTO for writing an event.
The following fields are immutable:
- family_import_id
- import_id
- family_document_import_id
- event_status
"""

event_title: str
date: datetime
event_type_value: str
50 changes: 46 additions & 4 deletions app/repository/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime
from typing import Optional, Tuple, cast

from sqlalchemy import or_, Column
from sqlalchemy import or_, Column, update as db_update
from sqlalchemy.orm import Query, Session
from sqlalchemy.exc import NoResultFound
from sqlalchemy_utils import escape_like
Expand All @@ -16,8 +16,8 @@
Family,
FamilyDocument,
)
from app.errors import ValidationError
from app.model.event import EventCreateDTO, EventReadDTO
from app.errors import ValidationError, RepositoryError
from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO
from app.repository import family as family_repo
from app.repository.helpers import generate_import_id

Expand Down Expand Up @@ -108,7 +108,7 @@ def get(db: Session, import_id: str) -> Optional[EventReadDTO]:
except NoResultFound as e:
_LOGGER.error(e)
return

_LOGGER.warning(_event_to_dto(family_event_meta))
return _event_to_dto(family_event_meta)


Expand Down Expand Up @@ -172,6 +172,48 @@ def create(db: Session, event: EventCreateDTO) -> str:
return cast(str, new_family_event.import_id)


def update(db: Session, import_id: str, event: EventWriteDTO) -> bool:
"""
Updates a single entry with the new values passed.
:param db Session: the database connection
:param str import_id: The event import id to change.
:param DocumentDTO event: The new values
:return bool: True if new values were set otherwise false.
"""
_LOGGER.warning("hit update event repo")
new_values = event.model_dump()
_LOGGER.warning(f"new_values: {new_values['date']}")

original_fe = (
db.query(FamilyEvent).filter(FamilyEvent.import_id == import_id).one_or_none()
)

if original_fe is None: # Not found the event to update
_LOGGER.error(f"Unable to find event for update {import_id}")
return False

_LOGGER.warning(f"date: {original_fe.date}")
# update_slug = original_pd.title != new_values["title"]

result = db.execute(
db_update(FamilyEvent)
.where(FamilyEvent.import_id == original_fe.import_id)
.values(
title=new_values["event_title"],
event_type_name=new_values["event_type_value"],
date=new_values["date"],
)
)

if result.rowcount == 0: # type: ignore
msg = f"Could not update event fields: {event}"
_LOGGER.error(msg)
raise RepositoryError(msg)

return True


def count(db: Session) -> Optional[int]:
"""
Counts the number of family events in the repository.
Expand Down
37 changes: 33 additions & 4 deletions app/service/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import app.repository.event as event_repo
import app.service.family as family_service
from app.errors import RepositoryError, ValidationError
from app.model.event import EventCreateDTO, EventReadDTO
from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO
from app.service import id


Expand Down Expand Up @@ -79,12 +79,12 @@ def validate_import_id(import_id: str) -> None:
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def create(event: EventCreateDTO, db: Session = db_session.get_db()) -> str:
"""
Creates a new document with the values passed.
Creates a new event with the values passed.
:param documentDTO document: The values for the new document.
:param eventDTO event: The values for the new event.
:raises RepositoryError: raised on a database error
:raises ValidationError: raised should the import_id be invalid.
:return Optional[documentDTO]: The new created document or
:return Optional[eventDTO]: The new created event or
None if unsuccessful.
"""
id.validate(event.family_import_id)
Expand All @@ -98,6 +98,35 @@ def create(event: EventCreateDTO, db: Session = db_session.get_db()) -> str:
return event_repo.create(db, event)


@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def update(
import_id: str, event: EventWriteDTO, db: Session = db_session.get_db()
) -> Optional[EventReadDTO]:
"""
Updates a single event with the values passed.
:param EventWriteDTO event: The DTO with all the values to change (or keep).
:raises RepositoryError: raised on a database error.
:raises ValidationError: raised should the import_id be invalid.
:return Optional[EventReadDTO]: The updated event or None if not updated.
"""
_LOGGER.warning("hit update event service")
validate_import_id(import_id)

# TODO: implement changing of a event's organisation
# org_id = organisation.get_id(db, event.organisation)

try:
if event_repo.update(db, import_id, event):
db.commit()
return get(import_id)

except exc.SQLAlchemyError:
_LOGGER.exception(f"While updating event {import_id}")
raise RepositoryError(f"Error when updating event {import_id}")


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def count() -> Optional[int]:
"""
Expand Down
152 changes: 152 additions & 0 deletions integration_tests/event/test_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from datetime import datetime, timezone
from fastapi.encoders import jsonable_encoder
from fastapi.testclient import TestClient
from fastapi import status
from sqlalchemy.orm import Session
from app.clients.db.models.law_policy.family import EventStatus

from app.clients.db.models.law_policy import FamilyEvent

from integration_tests.setup_db import EXPECTED_EVENTS, setup_db
from unit_tests.helpers.event import create_event_write_dto


def _get_event_tuple(test_db: Session, import_id: str) -> FamilyEvent:
fe: FamilyEvent = (
test_db.query(FamilyEvent).filter(FamilyEvent.import_id == import_id).one()
)
assert fe is not None
return fe


def test_update_event(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
new_event = create_event_write_dto(title="Updated Title")
response = client.put(
"/api/v1/events/E.0.0.2",
json=jsonable_encoder(new_event),
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()

# Check the properties and values of the EventReadDTO object we return to the
# client.
assert data["event_type_value"] == "Amended"
assert data["event_title"] == "Updated Title"
assert isinstance(data["date"], str) is True
assert data["date"] == "2023-01-01T00:00:00Z"

# Get the record in the FamilyEvent table we want to update in the database and
# check the types of the values are correct and that the values have been
# successfully updated.
fe = _get_event_tuple(test_db, "E.0.0.2")
assert isinstance(fe.date, datetime) is True
assert isinstance(fe.status, EventStatus) is True
assert (
isinstance(fe.family_document_import_id, str) is True
or fe.family_document_import_id is None
)
assert (
all(
isinstance(x, str)
for x in [
fe.import_id,
fe.family_import_id,
fe.event_type_name,
fe.title,
]
)
is True
)
assert fe.import_id == "E.0.0.2"
assert fe.event_type_name == "Amended"
assert fe.date == datetime(2023, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc)
assert fe.title == "Updated Title"
assert fe.family_import_id == "A.0.0.1"
assert fe.family_document_import_id is None
assert fe.status == EventStatus.OK


def test_update_event_when_not_authorised(client: TestClient, test_db: Session):
setup_db(test_db)
new_event = create_event_write_dto(
title="Updated Title",
)
response = client.put("/api/v1/events/E.0.0.2", json=jsonable_encoder(new_event))
assert response.status_code == status.HTTP_401_UNAUTHORIZED


def test_update_event_idempotent(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
event = EXPECTED_EVENTS[1]
response = client.put(
f"/api/v1/events/{event['import_id']}",
json=event,
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK

data = response.json()
assert data["event_title"] == EXPECTED_EVENTS[1]["event_title"]

fe = _get_event_tuple(test_db, EXPECTED_EVENTS[1]["import_id"])
assert fe.title == EXPECTED_EVENTS[1]["event_title"]


def test_update_event_rollback(
client: TestClient, test_db: Session, rollback_event_repo, user_header_token
):
setup_db(test_db)
new_event = create_event_write_dto(
title="Updated Title",
)
response = client.put(
"/api/v1/events/E.0.0.2",
json=jsonable_encoder(new_event),
headers=user_header_token,
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE

pd = _get_event_tuple(test_db, "E.0.0.2")
assert pd.title != "Updated Title"

assert rollback_event_repo.update.call_count == 1


def test_update_event_when_not_found(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
new_event = create_event_write_dto(
title="Updated Title",
)
response = client.put(
"/api/v1/events/E.0.0.22",
json=jsonable_encoder(new_event),
headers=user_header_token,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert data["detail"] == "Event not updated: E.0.0.22"


def test_update_event_when_db_error(
client: TestClient, test_db: Session, bad_event_repo, user_header_token
):
setup_db(test_db)

new_event = create_event_write_dto(
title="Updated Title",
)
response = client.put(
"/api/v1/events/E.0.0.2",
json=jsonable_encoder(new_event),
headers=user_header_token,
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json()
assert data["detail"] == "Bad Repo"
assert bad_event_repo.update.call_count == 1
6 changes: 6 additions & 0 deletions integration_tests/mocks/bad_event_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def mock_search(_, q: str) -> list[EventReadDTO]:
def mock_create(_, data: EventCreateDTO) -> Optional[EventReadDTO]:
raise RepositoryError("Bad Repo")

def mock_update(_, import_id, data: EventReadDTO) -> Optional[EventReadDTO]:
raise RepositoryError("Bad Repo")

def mock_get_count(_) -> Optional[int]:
raise RepositoryError("Bad Repo")

Expand All @@ -33,6 +36,9 @@ def mock_get_count(_) -> Optional[int]:
monkeypatch.setattr(repo, "create", mock_create)
mocker.spy(repo, "create")

monkeypatch.setattr(repo, "update", mock_update)
mocker.spy(repo, "update")

monkeypatch.setattr(repo, "count", mock_get_count)
mocker.spy(repo, "count")

Expand Down
Loading

0 comments on commit 22eb526

Please sign in to comment.