Skip to content

Commit

Permalink
PDCT 419 Fully implement Document writes (#14)
Browse files Browse the repository at this point in the history
* Create DocumentWriteDTO to be used for update

* Fully implement document updates
  • Loading branch information
diversemix authored Sep 27, 2023
1 parent 4192e4d commit 1650d28
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 21 deletions.
13 changes: 12 additions & 1 deletion app/model/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@ class DocumentReadDTO(BaseModel):
# languages: list[]


DocumentWriteDTO = DocumentReadDTO
class DocumentWriteDTO(BaseModel):
"""Representation of a Document."""

# From FamilyDocument
import_id: str
variant_name: str
role: str
type: str
title: str
source_url: str
# TODO: Languages for a document
# languages: list[]


class DocumentUploadRequest(BaseModel):
Expand Down
31 changes: 26 additions & 5 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Variant,
)
from app.errors import RepositoryError
from app.model.document import DocumentReadDTO
from app.model.document import DocumentReadDTO, DocumentWriteDTO
from app.clients.db.models.document.physical_document import (
Language,
LanguageSource,
Expand Down Expand Up @@ -153,7 +153,7 @@ def search(db: Session, search_term: str) -> list[DocumentReadDTO]:
return [_document_to_dto(d) for d in found]


def update(db: Session, document: DocumentReadDTO) -> bool:
def update(db: Session, document: DocumentWriteDTO) -> bool:
"""
Updates a single entry with the new values passed.
Expand Down Expand Up @@ -188,18 +188,39 @@ def update(db: Session, document: DocumentReadDTO) -> bool:
)
return False

result = db.execute(
update_slug = original_pd.title != new_values["title"]

commands = [
db_update(PhysicalDocument)
.where(PhysicalDocument.id == original_pd.id)
.values(
title=new_values["title"],
)
)
source_url=new_values["source_url"],
),
db_update(FamilyDocument)
.where(FamilyDocument.import_id == original_fd.import_id)
.values(
variant_name=new_values["variant_name"],
document_role=new_values["role"],
document_type=new_values["type"],
),
]

for c in commands:
result = db.execute(c)

if result.rowcount == 0: # type: ignore
msg = f"Could not update document fields: {document}"
_LOGGER.error(msg)
raise RepositoryError(msg)

if update_slug:
db.add(
Slug(
family_document_import_id=original_fd.import_id,
name=generate_slug(db, new_values["title"]),
)
)
return True


Expand Down
2 changes: 1 addition & 1 deletion app/service/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update(
@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def create(
document: DocumentWriteDTO, db: Session = db_session.get_db()
document: DocumentReadDTO, db: Session = db_session.get_db()
) -> Optional[DocumentReadDTO]:
"""
Creates a new document with the values passed.
Expand Down
27 changes: 24 additions & 3 deletions integration_tests/document/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from fastapi import status
from sqlalchemy.orm import Session
from app.clients.db.models.document.physical_document import PhysicalDocument
from app.clients.db.models.law_policy.family import FamilyDocument
from app.clients.db.models.law_policy.family import FamilyDocument, Slug
from app.model.document import DocumentWriteDTO

from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db
from unit_tests.helpers.document import create_document_dto
Expand Down Expand Up @@ -31,10 +32,13 @@ def _get_doc_tuple(

def test_update_document(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
new_document = create_document_dto(
new_document = DocumentWriteDTO(
import_id="D.0.0.2",
family_import_id="A.0.0.3",
variant_name="Translation",
role="SUMMARY",
type="Annex",
title="Updated Title",
source_url="Updated Source",
)
response = client.put(
"/api/v1/documents",
Expand All @@ -43,10 +47,27 @@ def test_update_document(client: TestClient, test_db: Session, user_header_token
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["import_id"] == "D.0.0.2"
assert data["variant_name"] == "Translation"
assert data["role"] == "SUMMARY"
assert data["type"] == "Annex"
assert data["title"] == "Updated Title"
assert data["source_url"] == "Updated Source"

fd, pd = _get_doc_tuple(test_db, "D.0.0.2")
assert fd.import_id == "D.0.0.2"
assert fd.variant_name == "Translation"
assert fd.document_role == "SUMMARY"
assert fd.document_type == "Annex"
assert pd.title == "Updated Title"
assert pd.source_url == "Updated Source"

# Check slug is updated too
slugs = (
test_db.query(Slug).filter(Slug.family_document_import_id == "D.0.0.2").all()
)
last_slug = slugs[-1].name
assert last_slug.startswith("updated-title")


def test_update_document_when_not_authorised(client: TestClient, test_db: Session):
Expand Down
8 changes: 3 additions & 5 deletions unit_tests/mocks/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@ def mock_search_documents(q: str) -> list[DocumentReadDTO]:
def mock_update_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]:
maybe_throw()
if not document_service.missing:
return create_document_dto(
data.import_id, data.family_import_id, data.title
)
return create_document_dto(data.import_id, "family_import_id", data.title)

def mock_create_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]:
def mock_create_document(data: DocumentReadDTO) -> Optional[DocumentReadDTO]:
maybe_throw()
if not document_service.missing:
return create_document_dto(
data.import_id, data.family_import_id, data.title
)

def mock_delete_document(import_id: str) -> bool:
def mock_delete_document(_) -> bool:
maybe_throw()
return not document_service.missing

Expand Down
18 changes: 12 additions & 6 deletions unit_tests/service/test_document_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import cast
import pytest
from app.model.document import DocumentReadDTO, DocumentWriteDTO
import app.service.document as doc_service
Expand All @@ -7,7 +6,14 @@


def _to_write_dto(dto: DocumentReadDTO) -> DocumentWriteDTO:
return cast(DocumentWriteDTO, dto)
return DocumentWriteDTO(
import_id=dto.import_id,
variant_name=dto.variant_name,
role=dto.role,
type=dto.type,
title=dto.title,
source_url=dto.source_url,
)


def test_document_service_get_upload_details(test_s3_client):
Expand Down Expand Up @@ -140,7 +146,7 @@ def test_update_raises_when_invalid_id(

def test_create(document_repo_mock, family_repo_mock):
new_document = create_dto(import_id="A.0.0.5")
document = doc_service.create(_to_write_dto(new_document))
document = doc_service.create(new_document)
assert document is not None
assert document_repo_mock.create.call_count == 1
assert family_repo_mock.get.call_count == 1
Expand All @@ -149,7 +155,7 @@ def test_create(document_repo_mock, family_repo_mock):
def test_create_when_db_fails(document_repo_mock, family_repo_mock):
new_document = create_dto(import_id="a.b.c.d")
document_repo_mock.return_empty = True
document = doc_service.create(_to_write_dto(new_document))
document = doc_service.create(new_document)
assert document is None
assert document_repo_mock.create.call_count == 1
assert family_repo_mock.get.call_count == 1
Expand All @@ -158,7 +164,7 @@ def test_create_when_db_fails(document_repo_mock, family_repo_mock):
def test_create_raises_when_invalid_id(document_repo_mock):
new_document = create_dto(import_id="invalid")
with pytest.raises(ValidationError) as e:
doc_service.create(_to_write_dto(new_document))
doc_service.create(new_document)
expected_msg = f"The import id {new_document.import_id} is invalid!"
assert e.value.message == expected_msg
assert document_repo_mock.create.call_count == 0
Expand All @@ -167,7 +173,7 @@ def test_create_raises_when_invalid_id(document_repo_mock):
def test_create_raises_when_invalid_family_id(document_repo_mock):
new_document = create_dto(import_id="a.b.c.d", family_import_id="invalid family")
with pytest.raises(ValidationError) as e:
doc_service.create(_to_write_dto(new_document))
doc_service.create(new_document)
expected_msg = f"The import id {new_document.family_import_id} is invalid!"
assert e.value.message == expected_msg
assert document_repo_mock.create.call_count == 0

0 comments on commit 1650d28

Please sign in to comment.