diff --git a/app/model/document.py b/app/model/document.py index cf0ffcde..87ca3ade 100644 --- a/app/model/document.py +++ b/app/model/document.py @@ -29,7 +29,18 @@ class DocumentReadDTO(BaseModel): # languages: list[] -DocumentWriteDTO = DocumentReadDTO +class DocumentWriteDTO(BaseModel): + """Representation of a Document.""" + + # From FamilyDocument + import_id: str + variant_name: str + role: str + type: str + title: str + source_url: str + # TODO: Languages for a document + # languages: list[] class DocumentUploadRequest(BaseModel): diff --git a/app/repository/document.py b/app/repository/document.py index a4b390ed..437d15dd 100644 --- a/app/repository/document.py +++ b/app/repository/document.py @@ -10,7 +10,7 @@ Variant, ) from app.errors import RepositoryError -from app.model.document import DocumentReadDTO +from app.model.document import DocumentReadDTO, DocumentWriteDTO from app.clients.db.models.document.physical_document import ( Language, LanguageSource, @@ -153,7 +153,7 @@ def search(db: Session, search_term: str) -> list[DocumentReadDTO]: return [_document_to_dto(d) for d in found] -def update(db: Session, document: DocumentReadDTO) -> bool: +def update(db: Session, document: DocumentWriteDTO) -> bool: """ Updates a single entry with the new values passed. @@ -188,18 +188,39 @@ def update(db: Session, document: DocumentReadDTO) -> bool: ) return False - result = db.execute( + update_slug = original_pd.title != new_values["title"] + + commands = [ db_update(PhysicalDocument) .where(PhysicalDocument.id == original_pd.id) .values( title=new_values["title"], - ) - ) + source_url=new_values["source_url"], + ), + db_update(FamilyDocument) + .where(FamilyDocument.import_id == original_fd.import_id) + .values( + variant_name=new_values["variant_name"], + document_role=new_values["role"], + document_type=new_values["type"], + ), + ] + + for c in commands: + result = db.execute(c) + if result.rowcount == 0: # type: ignore msg = f"Could not update document fields: {document}" _LOGGER.error(msg) raise RepositoryError(msg) + if update_slug: + db.add( + Slug( + family_document_import_id=original_fd.import_id, + name=generate_slug(db, new_values["title"]), + ) + ) return True diff --git a/app/service/document.py b/app/service/document.py index 8595502b..d6737c85 100644 --- a/app/service/document.py +++ b/app/service/document.py @@ -105,7 +105,7 @@ def update( @db_session.with_transaction(__name__) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def create( - document: DocumentWriteDTO, db: Session = db_session.get_db() + document: DocumentReadDTO, db: Session = db_session.get_db() ) -> Optional[DocumentReadDTO]: """ Creates a new document with the values passed. diff --git a/integration_tests/document/test_update.py b/integration_tests/document/test_update.py index de4647aa..9b8253c9 100644 --- a/integration_tests/document/test_update.py +++ b/integration_tests/document/test_update.py @@ -3,7 +3,8 @@ from fastapi import status from sqlalchemy.orm import Session from app.clients.db.models.document.physical_document import PhysicalDocument -from app.clients.db.models.law_policy.family import FamilyDocument +from app.clients.db.models.law_policy.family import FamilyDocument, Slug +from app.model.document import DocumentWriteDTO from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db from unit_tests.helpers.document import create_document_dto @@ -31,10 +32,13 @@ def _get_doc_tuple( def test_update_document(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) - new_document = create_document_dto( + new_document = DocumentWriteDTO( import_id="D.0.0.2", - family_import_id="A.0.0.3", + variant_name="Translation", + role="SUMMARY", + type="Annex", title="Updated Title", + source_url="Updated Source", ) response = client.put( "/api/v1/documents", @@ -43,10 +47,27 @@ def test_update_document(client: TestClient, test_db: Session, user_header_token ) assert response.status_code == status.HTTP_200_OK data = response.json() + assert data["import_id"] == "D.0.0.2" + assert data["variant_name"] == "Translation" + assert data["role"] == "SUMMARY" + assert data["type"] == "Annex" assert data["title"] == "Updated Title" + assert data["source_url"] == "Updated Source" fd, pd = _get_doc_tuple(test_db, "D.0.0.2") + assert fd.import_id == "D.0.0.2" + assert fd.variant_name == "Translation" + assert fd.document_role == "SUMMARY" + assert fd.document_type == "Annex" assert pd.title == "Updated Title" + assert pd.source_url == "Updated Source" + + # Check slug is updated too + slugs = ( + test_db.query(Slug).filter(Slug.family_document_import_id == "D.0.0.2").all() + ) + last_slug = slugs[-1].name + assert last_slug.startswith("updated-title") def test_update_document_when_not_authorised(client: TestClient, test_db: Session): diff --git a/unit_tests/mocks/services/document_service.py b/unit_tests/mocks/services/document_service.py index 943a94df..da9d916c 100644 --- a/unit_tests/mocks/services/document_service.py +++ b/unit_tests/mocks/services/document_service.py @@ -33,18 +33,16 @@ def mock_search_documents(q: str) -> list[DocumentReadDTO]: def mock_update_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]: maybe_throw() if not document_service.missing: - return create_document_dto( - data.import_id, data.family_import_id, data.title - ) + return create_document_dto(data.import_id, "family_import_id", data.title) - def mock_create_document(data: DocumentWriteDTO) -> Optional[DocumentReadDTO]: + def mock_create_document(data: DocumentReadDTO) -> Optional[DocumentReadDTO]: maybe_throw() if not document_service.missing: return create_document_dto( data.import_id, data.family_import_id, data.title ) - def mock_delete_document(import_id: str) -> bool: + def mock_delete_document(_) -> bool: maybe_throw() return not document_service.missing diff --git a/unit_tests/service/test_document_service.py b/unit_tests/service/test_document_service.py index 5cf05ad5..7825f47d 100644 --- a/unit_tests/service/test_document_service.py +++ b/unit_tests/service/test_document_service.py @@ -1,4 +1,3 @@ -from typing import cast import pytest from app.model.document import DocumentReadDTO, DocumentWriteDTO import app.service.document as doc_service @@ -7,7 +6,14 @@ def _to_write_dto(dto: DocumentReadDTO) -> DocumentWriteDTO: - return cast(DocumentWriteDTO, dto) + return DocumentWriteDTO( + import_id=dto.import_id, + variant_name=dto.variant_name, + role=dto.role, + type=dto.type, + title=dto.title, + source_url=dto.source_url, + ) def test_document_service_get_upload_details(test_s3_client): @@ -140,7 +146,7 @@ def test_update_raises_when_invalid_id( def test_create(document_repo_mock, family_repo_mock): new_document = create_dto(import_id="A.0.0.5") - document = doc_service.create(_to_write_dto(new_document)) + document = doc_service.create(new_document) assert document is not None assert document_repo_mock.create.call_count == 1 assert family_repo_mock.get.call_count == 1 @@ -149,7 +155,7 @@ def test_create(document_repo_mock, family_repo_mock): def test_create_when_db_fails(document_repo_mock, family_repo_mock): new_document = create_dto(import_id="a.b.c.d") document_repo_mock.return_empty = True - document = doc_service.create(_to_write_dto(new_document)) + document = doc_service.create(new_document) assert document is None assert document_repo_mock.create.call_count == 1 assert family_repo_mock.get.call_count == 1 @@ -158,7 +164,7 @@ def test_create_when_db_fails(document_repo_mock, family_repo_mock): def test_create_raises_when_invalid_id(document_repo_mock): new_document = create_dto(import_id="invalid") with pytest.raises(ValidationError) as e: - doc_service.create(_to_write_dto(new_document)) + doc_service.create(new_document) expected_msg = f"The import id {new_document.import_id} is invalid!" assert e.value.message == expected_msg assert document_repo_mock.create.call_count == 0 @@ -167,7 +173,7 @@ def test_create_raises_when_invalid_id(document_repo_mock): def test_create_raises_when_invalid_family_id(document_repo_mock): new_document = create_dto(import_id="a.b.c.d", family_import_id="invalid family") with pytest.raises(ValidationError) as e: - doc_service.create(_to_write_dto(new_document)) + doc_service.create(new_document) expected_msg = f"The import id {new_document.family_import_id} is invalid!" assert e.value.message == expected_msg assert document_repo_mock.create.call_count == 0