Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PDCT 419 Fully implement Document writes #14

Merged
merged 2 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion app/model/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,18 @@ class DocumentReadDTO(BaseModel):
# languages: list[]


DocumentWriteDTO = DocumentReadDTO
class DocumentWriteDTO(BaseModel):
"""Representation of a Document."""

# From FamilyDocument
import_id: str
variant_name: str
role: str
type: str
title: str
source_url: str
# TODO: Languages for a document
# languages: list[]


class DocumentUploadRequest(BaseModel):
Expand Down
31 changes: 26 additions & 5 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Variant,
)
from app.errors import RepositoryError
from app.model.document import DocumentReadDTO
from app.model.document import DocumentReadDTO, DocumentWriteDTO
from app.clients.db.models.document.physical_document import (
Language,
LanguageSource,
Expand Down Expand Up @@ -153,7 +153,7 @@ def search(db: Session, search_term: str) -> list[DocumentReadDTO]:
return [_document_to_dto(d) for d in found]


def update(db: Session, document: DocumentReadDTO) -> bool:
def update(db: Session, document: DocumentWriteDTO) -> bool:
"""
Updates a single entry with the new values passed.

Expand Down Expand Up @@ -188,18 +188,39 @@ def update(db: Session, document: DocumentReadDTO) -> bool:
)
return False

result = db.execute(
update_slug = original_pd.title != new_values["title"]

commands = [
db_update(PhysicalDocument)
.where(PhysicalDocument.id == original_pd.id)
.values(
title=new_values["title"],
)
)
source_url=new_values["source_url"],
),
db_update(FamilyDocument)
.where(FamilyDocument.import_id == original_fd.import_id)
.values(
variant_name=new_values["variant_name"],
document_role=new_values["role"],
document_type=new_values["type"],
),
]

for c in commands:
result = db.execute(c)

if result.rowcount == 0: # type: ignore
msg = f"Could not update document fields: {document}"
_LOGGER.error(msg)
raise RepositoryError(msg)

if update_slug:
db.add(
Slug(
family_document_import_id=original_fd.import_id,
name=generate_slug(db, new_values["title"]),
)
)
return True


Expand Down
2 changes: 1 addition & 1 deletion app/service/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update(
@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def create(
document: DocumentWriteDTO, db: Session = db_session.get_db()
document: DocumentReadDTO, db: Session = db_session.get_db()
) -> Optional[DocumentReadDTO]:
"""
Creates a new document with the values passed.
Expand Down
27 changes: 24 additions & 3 deletions integration_tests/document/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from fastapi import status
from sqlalchemy.orm import Session
from app.clients.db.models.document.physical_document import PhysicalDocument
from app.clients.db.models.law_policy.family import FamilyDocument
from app.clients.db.models.law_policy.family import FamilyDocument, Slug
from app.model.document import DocumentWriteDTO

from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db
from unit_tests.helpers.document import create_document_dto
Expand Down Expand Up @@ -31,10 +32,13 @@ def _get_doc_tuple(

def test_update_document(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
new_document = create_document_dto(
new_document = DocumentWriteDTO(
import_id="D.0.0.2",
family_import_id="A.0.0.3",
variant_name="Translation",
role="SUMMARY",
type="Annex",
title="Updated Title",
source_url="Updated Source",
)
response = client.put(
"/api/v1/documents",
Expand All @@ -43,10 +47,27 @@ def test_update_document(client: TestClient, test_db: Session, user_header_token
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["import_id"] == "D.0.0.2"
assert data["variant_name"] == "Translation"
assert data["role"] == "SUMMARY"
assert data["type"] == "Annex"
assert data["title"] == "Updated Title"
assert data["source_url"] == "Updated Source"

fd, pd = _get_doc_tuple(test_db, "D.0.0.2")
assert fd.import_id == "D.0.0.2"
assert fd.variant_name == "Translation"
assert fd.document_role == "SUMMARY"
assert fd.document_type == "Annex"
assert pd.title == "Updated Title"
assert pd.source_url == "Updated Source"

# Check slug is updated too
slugs = (
test_db.query(Slug).filter(Slug.family_document_import_id == "D.0.0.2").all()
)
last_slug = slugs[-1].name
assert last_slug.startswith("updated-title")


def test_update_document_when_not_authorised(client: TestClient, test_db: Session):
Expand Down
8 changes: 3 additions & 5 deletions unit_tests/mocks/services/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@ def mock_search_documents(q: str) -> list[DocumentReadDTO]:
def mock_update_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]:
maybe_throw()
if not document_service.missing:
return create_document_dto(
data.import_id, data.family_import_id, data.title
)
return create_document_dto(data.import_id, "family_import_id", data.title)

def mock_create_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]:
def mock_create_document(data: DocumentReadDTO) -> Optional[DocumentReadDTO]:
maybe_throw()
if not document_service.missing:
return create_document_dto(
data.import_id, data.family_import_id, data.title
)

def mock_delete_document(import_id: str) -> bool:
def mock_delete_document(_) -> bool:
maybe_throw()
return not document_service.missing

Expand Down
18 changes: 12 additions & 6 deletions unit_tests/service/test_document_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import cast
import pytest
from app.model.document import DocumentReadDTO, DocumentWriteDTO
import app.service.document as doc_service
Expand All @@ -7,7 +6,14 @@


def _to_write_dto(dto: DocumentReadDTO) -> DocumentWriteDTO:
return cast(DocumentWriteDTO, dto)
return DocumentWriteDTO(
import_id=dto.import_id,
variant_name=dto.variant_name,
role=dto.role,
type=dto.type,
title=dto.title,
source_url=dto.source_url,
)


def test_document_service_get_upload_details(test_s3_client):
Expand Down Expand Up @@ -140,7 +146,7 @@ def test_update_raises_when_invalid_id(

def test_create(document_repo_mock, family_repo_mock):
new_document = create_dto(import_id="A.0.0.5")
document = doc_service.create(_to_write_dto(new_document))
document = doc_service.create(new_document)
assert document is not None
assert document_repo_mock.create.call_count == 1
assert family_repo_mock.get.call_count == 1
Expand All @@ -149,7 +155,7 @@ def test_create(document_repo_mock, family_repo_mock):
def test_create_when_db_fails(document_repo_mock, family_repo_mock):
new_document = create_dto(import_id="a.b.c.d")
document_repo_mock.return_empty = True
document = doc_service.create(_to_write_dto(new_document))
document = doc_service.create(new_document)
assert document is None
assert document_repo_mock.create.call_count == 1
assert family_repo_mock.get.call_count == 1
Expand All @@ -158,7 +164,7 @@ def test_create_when_db_fails(document_repo_mock, family_repo_mock):
def test_create_raises_when_invalid_id(document_repo_mock):
new_document = create_dto(import_id="invalid")
with pytest.raises(ValidationError) as e:
doc_service.create(_to_write_dto(new_document))
doc_service.create(new_document)
expected_msg = f"The import id {new_document.import_id} is invalid!"
assert e.value.message == expected_msg
assert document_repo_mock.create.call_count == 0
Expand All @@ -167,7 +173,7 @@ def test_create_raises_when_invalid_id(document_repo_mock):
def test_create_raises_when_invalid_family_id(document_repo_mock):
new_document = create_dto(import_id="a.b.c.d", family_import_id="invalid family")
with pytest.raises(ValidationError) as e:
doc_service.create(_to_write_dto(new_document))
doc_service.create(new_document)
expected_msg = f"The import id {new_document.family_import_id} is invalid!"
assert e.value.message == expected_msg
assert document_repo_mock.create.call_count == 0