Skip to content

Commit

Permalink
Merge branch 'main' into docker-orchestration
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesgorrie authored Dec 9, 2024
2 parents 5333359 + 20265e7 commit 0555c16
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 17 deletions.
12 changes: 10 additions & 2 deletions app/api/api_v1/routers/bulk_import.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from typing import Optional

from fastapi import (
APIRouter,
Expand Down Expand Up @@ -63,11 +64,16 @@ async def bulk_import(
data: UploadFile,
corpus_import_id: str,
background_tasks: BackgroundTasks,
document_limit: Optional[int] = None,
) -> Json:
"""
Bulk import endpoint.
:param UploadFile new_data: file containing json representation of data to import.
:param Request request: The request object containing all request data.
:param UploadFile data: File containing json representation of data to import.
:param str corpus_import_id: The ID of the corpus to import.
:param BackgroundTasks background_tasks: Background tasks to be performed after the request is completed.
:param Optional[int] document_limit: The max number of documents to be saved in this session or None.
:return Json: json representation of the data to import.
"""
_LOGGER.info(
Expand All @@ -79,7 +85,9 @@ async def bulk_import(
data_dict = json.loads(content)
validate_bulk_import_data(data_dict)

background_tasks.add_task(import_data, data_dict, corpus_import_id)
background_tasks.add_task(
import_data, data_dict, corpus_import_id, document_limit
)

return {
"message": "Bulk import request accepted. Check Cloudwatch logs for result."
Expand Down
22 changes: 18 additions & 4 deletions app/service/bulk_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
get_datetime_event_name_for_corpus,
)

DOCUMENT_BULK_IMPORT_LIMIT = 1000
# Any increase to this number should first be discussed with the Platform Team
DEFAULT_DOCUMENT_LIMIT = 1000

_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -234,13 +236,15 @@ def save_families(
def save_documents(
document_data: list[dict[str, Any]],
corpus_import_id: str,
document_limit: int,
db: Optional[Session] = None,
) -> list[str]:
"""
Creates new documents with the values passed.
:param list[dict[str, Any]] document_data: The data to use for creating documents.
:param str corpus_import_id: The import_id of the corpus the documents belong to.
:param int document_limit: The max number of documents to be saved in this session.
:param Optional[Session] db: The database session to use for saving documents or None.
:return list[str]: The new import_ids for the saved documents.
"""
Expand All @@ -256,7 +260,7 @@ def save_documents(
for doc in document_data:
if (
not _exists_in_db(FamilyDocument, doc["import_id"], db)
and total_documents_saved < DOCUMENT_BULK_IMPORT_LIMIT
and total_documents_saved < document_limit
):
_LOGGER.info(f"Importing document {doc['import_id']}")
dto = BulkImportDocumentDTO(**doc).to_document_create_dto()
Expand Down Expand Up @@ -309,12 +313,17 @@ def save_events(


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def import_data(data: dict[str, Any], corpus_import_id: str) -> None:
def import_data(
data: dict[str, Any],
corpus_import_id: str,
document_limit: Optional[int] = None,
) -> None:
"""
Imports data for a given corpus_import_id.
:param dict[str, Any] data: The data to be imported.
:param str corpus_import_id: The import_id of the corpus the data should be imported into.
:param Optional[int] document_limit: The max number of documents to be saved in this session or None.
:raises RepositoryError: raised on a database error.
:raises ValidationError: raised should the data be invalid.
"""
Expand Down Expand Up @@ -348,7 +357,12 @@ def import_data(data: dict[str, Any], corpus_import_id: str) -> None:
result["families"] = save_families(family_data, corpus_import_id, db)
if document_data:
_LOGGER.info("Saving documents")
result["documents"] = save_documents(document_data, corpus_import_id, db)
result["documents"] = save_documents(
document_data,
corpus_import_id,
document_limit or DEFAULT_DOCUMENT_LIMIT,
db,
)
if event_data:
_LOGGER.info("Saving events")
result["events"] = save_events(event_data, corpus_import_id, db)
Expand Down
64 changes: 57 additions & 7 deletions tests/integration_tests/bulk_import/test_bulk_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy import update
from sqlalchemy.orm import Session

from app.service.bulk_import import DEFAULT_DOCUMENT_LIMIT
from tests.helpers.bulk_import import (
build_json_file,
default_collection,
Expand Down Expand Up @@ -149,6 +150,53 @@ def test_import_data_rollback(
assert actual_collection is None


@pytest.mark.s3
def test_bulk_import_saves_default_number_of_documents_if_no_limit_provided_in_request(
data_db: Session, client: TestClient, superuser_header_token
):
input_json = build_json_file(
{
"collections": [default_collection],
"families": [default_family],
"documents": [
{**default_document, "import_id": f"test.new.document.{i}"}
for i in range(DEFAULT_DOCUMENT_LIMIT + 1)
],
"events": [default_event],
}
)

response = client.post(
"/api/v1/bulk-import/UNFCCC.corpus.i00000001.n0000",
files={"data": input_json},
headers=superuser_header_token,
)

assert response.status_code == status.HTTP_202_ACCEPTED
assert response.json() == {
"message": "Bulk import request accepted. Check Cloudwatch logs for result."
}

assert (
"Created"
== data_db.query(FamilyDocument)
.filter(
FamilyDocument.import_id
== f"test.new.document.{DEFAULT_DOCUMENT_LIMIT - 1}"
)
.one_or_none()
.document_status
)

assert (
not data_db.query(FamilyDocument)
.filter(
FamilyDocument.import_id == f"test.new.document.{DEFAULT_DOCUMENT_LIMIT}"
)
.one_or_none()
)


@pytest.mark.s3
def test_bulk_import_idempotency(
caplog,
Expand All @@ -161,8 +209,8 @@ def test_bulk_import_idempotency(
"collections": [default_collection],
"families": [default_family],
"documents": [
{**default_document, "import_id": f"test.new.document.{i}"}
for i in range(1001)
default_document,
{**default_document, "import_id": "test.new.document.1"},
],
"events": [default_event],
}
Expand All @@ -172,6 +220,7 @@ def test_bulk_import_idempotency(
first_response = client.post(
"/api/v1/bulk-import/UNFCCC.corpus.i00000001.n0000",
files={"data": input_json},
params={"document_limit": 1},
headers=superuser_header_token,
)

Expand All @@ -183,28 +232,29 @@ def test_bulk_import_idempotency(
assert (
"Created"
== data_db.query(FamilyDocument)
.filter(FamilyDocument.import_id == "test.new.document.999")
.filter(FamilyDocument.import_id == "test.new.document.0")
.one_or_none()
.document_status
)

assert (
not data_db.query(FamilyDocument)
.filter(FamilyDocument.import_id == "test.new.document.1000")
.filter(FamilyDocument.import_id == "test.new.document.1")
.one_or_none()
)

# simulating pipeline ingest
data_db.execute(
update(FamilyDocument)
.where(FamilyDocument.import_id == "test.new.document.999")
.where(FamilyDocument.import_id == "test.new.document.0")
.values(document_status="Published")
)

with caplog.at_level(logging.ERROR):
second_response = client.post(
"/api/v1/bulk-import/UNFCCC.corpus.i00000001.n0000",
files={"data": input_json},
params={"document_limit": 1},
headers=superuser_header_token,
)

Expand All @@ -217,15 +267,15 @@ def test_bulk_import_idempotency(
assert (
"Published"
== data_db.query(FamilyDocument)
.filter(FamilyDocument.import_id == "test.new.document.999")
.filter(FamilyDocument.import_id == "test.new.document.0")
.one_or_none()
.document_status
)

assert (
"Created"
== data_db.query(FamilyDocument)
.filter(FamilyDocument.import_id == "test.new.document.1000")
.filter(FamilyDocument.import_id == "test.new.document.1")
.one_or_none()
.document_status
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_save_documents_when_data_invalid(validation_service_mock):
test_data = [{"import_id": "invalid"}]

with pytest.raises(ValidationError) as e:
bulk_import_service.save_documents(test_data, "test")
bulk_import_service.save_documents(test_data, "test", 1)
assert "Error" == e.value.message


Expand All @@ -181,8 +181,6 @@ def test_save_documents_when_data_invalid(validation_service_mock):
def test_do_not_save_documents_over_bulk_import_limit(
validation_service_mock, document_repo_mock, monkeypatch
):
monkeypatch.setattr(bulk_import_service, "DOCUMENT_BULK_IMPORT_LIMIT", 1)

test_data = [
{
"import_id": "test.new.document.0",
Expand All @@ -204,7 +202,7 @@ def test_do_not_save_documents_over_bulk_import_limit(
},
]

saved_documents = bulk_import_service.save_documents(test_data, "test")
saved_documents = bulk_import_service.save_documents(test_data, "test", 1)
assert ["test.new.document.0"] == saved_documents


Expand Down

0 comments on commit 0555c16

Please sign in to comment.