diff --git a/app/model/bulk_import.py b/app/model/bulk_import.py
index f14d9497..f1bce8c0 100644
--- a/app/model/bulk_import.py
+++ b/app/model/bulk_import.py
@@ -1,13 +1,14 @@
 from datetime import datetime
-from typing import Optional
+from typing import Dict, List, Optional, Union
 
-from pydantic import AnyHttpUrl, BaseModel
+from pydantic import AnyHttpUrl, BaseModel, RootModel
 
 from app.model.collection import CollectionCreateDTO
 from app.model.document import DocumentCreateDTO
 from app.model.event import EventCreateDTO
 from app.model.family import FamilyCreateDTO
-from app.model.general import Json
+
+Metadata = RootModel[Dict[str, Union[str, List[str]]]]
 
 
 class BulkImportCollectionDTO(BaseModel):
@@ -38,7 +39,7 @@ class BulkImportFamilyDTO(BaseModel):
     summary: str
     geographies: list[str]
     category: str
-    metadata: Json
+    metadata: Metadata
     collections: list[str]
     corpus_import_id: str
 
@@ -54,7 +55,7 @@ def to_family_create_dto(self, corpus_import_id: str) -> FamilyCreateDTO:
             summary=self.summary,
             geography=self.geographies,
             category=self.category,
-            metadata=self.metadata,
+            metadata=self.metadata.model_dump(),
             collections=self.collections,
             corpus_import_id=corpus_import_id,
         )
@@ -91,7 +92,7 @@ class BulkImportDocumentDTO(BaseModel):
     import_id: str
     family_import_id: str
     variant_name: Optional[str] = None
-    metadata: Json
+    metadata: Metadata
     title: str
     source_url: Optional[AnyHttpUrl] = None
     user_language_name: Optional[str] = None
@@ -107,7 +108,7 @@ def to_document_create_dto(self) -> DocumentCreateDTO:
             import_id=self.import_id,
             family_import_id=self.family_import_id,
             variant_name=self.variant_name,
-            metadata=self.metadata,
+            metadata=self.metadata.model_dump(),
             title=self.title,
             source_url=self.source_url,
             user_language_name=self.user_language_name,
diff --git a/pyproject.toml b/pyproject.toml
index 0f5ba2ff..cc17c644 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "admin_backend"
-version = "2.17.19"
+version = "2.17.20"
 description = ""
 authors = ["CPR-dev-team "]
 packages = [{ include = "app" }, { include = "tests" }]
diff --git a/tests/unit_tests/service/bulk_import/test_bulk_import_service.py b/tests/unit_tests/service/bulk_import/test_bulk_import_service.py
index ebcecbc3..1c427a80 100644
--- a/tests/unit_tests/service/bulk_import/test_bulk_import_service.py
+++ b/tests/unit_tests/service/bulk_import/test_bulk_import_service.py
@@ -7,6 +7,7 @@
 
 import app.service.bulk_import as bulk_import_service
 from app.errors import ValidationError
+from tests.helpers.bulk_import import default_document, default_family
 
 
 @patch("app.service.bulk_import.uuid4", Mock(return_value="1111-1111"))
@@ -104,6 +105,30 @@ def test_slack_notification_sent_on_error(caplog, basic_s3_client, corpus_repo_m
     )
 
 
+@pytest.mark.parametrize(
+    "test_data",
+    [
+        {"families": [{**default_family, "metadata": {"key": [1]}}]},
+        {"families": [{**default_family, "metadata": {"key": None}}]},
+        {"families": [{**default_family, "metadata": {"key": 1}}]},
+        {"documents": [{**default_document, "metadata": {"key": 1}}]},
+    ],
+)
+@patch.dict(os.environ, {"BULK_IMPORT_BUCKET": "test_bucket"})
+@patch("app.service.bulk_import._exists_in_db", Mock(return_value=False))
+def test_import_data_when_metadata_contains_non_string_values(
+    test_data,
+    corpus_repo_mock,
+    validation_service_mock,
+    caplog,
+    basic_s3_client,
+):
+    with caplog.at_level(logging.ERROR):
+        bulk_import_service.import_data(test_data, "test")
+
+    assert "Input should be a valid string" in caplog.text
+
+
 @patch.dict(os.environ, {"BULK_IMPORT_BUCKET": "test_bucket"})
 def test_import_data_when_data_invalid(caplog, basic_s3_client):
     test_data = {
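
The diff swaps the loose `Json` type for a pydantic `RootModel` so that bulk-import metadata values must be strings or lists of strings. Below is a minimal sketch of how that model behaves, assuming pydantic v2 (implied by the use of `RootModel` and `model_dump()` above); the `topic`/`author` keys and their values are made-up illustration data, not part of the repository.

```python
# Sketch only: mirrors the Metadata alias added in app/model/bulk_import.py.
from typing import Dict, List, Union

from pydantic import RootModel, ValidationError

Metadata = RootModel[Dict[str, Union[str, List[str]]]]

# Valid: every value is a string or a list of strings.
ok = Metadata.model_validate({"topic": ["Adaptation"], "author": "CPR"})
print(ok.model_dump())  # plain dict, as passed on to the *CreateDTOs

# Invalid: an int value fails both union branches, and the error text
# contains "Input should be a valid string".
try:
    Metadata.model_validate({"key": 1})
except ValidationError as exc:
    print(exc)
```

The parametrized test added in the diff relies on exactly that behaviour: it feeds non-string metadata values through `import_data` and asserts that "Input should be a valid string" appears in the error log.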