Skip to content

Commit

Permalink
fix/improve more integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
diversemix committed Sep 26, 2023
1 parent f18142d commit 4192e4d
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 248 deletions.
69 changes: 42 additions & 27 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def _get_query(db: Session) -> Query:
return (
db.query(FamilyDocument, PhysicalDocument, Slug)
.filter(FamilyDocument.physical_document_id == PhysicalDocument.id)
.filter(Slug.family_document_import_id == FamilyDocument.import_id)
.join(
Slug,
Slug.family_document_import_id == FamilyDocument.import_id,
isouter=True,
)
)


Expand All @@ -56,7 +60,7 @@ def _document_to_dto(doc_tuple: DocumentTuple) -> DocumentReadDTO:
status=cast(DocumentStatus, fd.document_status),
role=cast(FamilyDocumentRole, fd.document_role),
type=cast(FamilyDocumentType, fd.document_type),
slug=cast(str, slug.name),
slug=cast(str, slug.name) if slug is not None else "",
physical_id=cast(int, pd.id),
title=cast(str, pd.title),
md5_sum=cast(str, pd.md5_sum),
Expand Down Expand Up @@ -146,7 +150,7 @@ def search(db: Session, search_term: str) -> list[DocumentReadDTO]:
search = or_(PhysicalDocument.title.ilike(term))
found = _get_query(db).filter(search).all()

return [_document_to_dto(*d) for d in found]
return [_document_to_dto(d) for d in found]


def update(db: Session, document: DocumentReadDTO) -> bool:
Expand All @@ -160,30 +164,41 @@ def update(db: Session, document: DocumentReadDTO) -> bool:
"""
# TODO: Implement this:

# new_values = document.model_dump()

# original_document = (
# db.query(FamilyDocument)
# .filter(FamilyDocument.import_id == document.import_id)
# .one_or_none()
# )

# if original_document is None: # Not found the document to update
# _LOGGER.error(f"Unable to find document for update {document}")
# return False

# result = db.execute(
# db_update(FamilyDocument)
# .where(FamilyDocument.import_id == document.import_id)
# .values(
# title=new_values["title"],
# description=new_values["description"],
# )
# )
# if result.rowcount == 0: # type: ignore
# msg = f"Could not update document fields: {document}"
# _LOGGER.error(msg)
# raise RepositoryError(msg)
new_values = document.model_dump()

original_fd = (
db.query(FamilyDocument)
.filter(FamilyDocument.import_id == document.import_id)
.one_or_none()
)

if original_fd is None: # Not found the document to update
_LOGGER.error(f"Unable to find document for update {document.import_id}")
return False

original_pd = (
db.query(PhysicalDocument)
.filter(PhysicalDocument.id == original_fd.physical_document_id)
.one_or_none()
)

if original_pd is None: # Not found the document to update
_LOGGER.error(
f"Unable to find document for update {original_fd.physical_document_id}"
)
return False

result = db.execute(
db_update(PhysicalDocument)
.where(PhysicalDocument.id == original_pd.id)
.values(
title=new_values["title"],
)
)
if result.rowcount == 0: # type: ignore
msg = f"Could not update document fields: {document}"
_LOGGER.error(msg)
raise RepositoryError(msg)

return True

Expand Down
8 changes: 4 additions & 4 deletions integration_tests/collection/test_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi import status
from sqlalchemy.orm import Session
from app.clients.db.models.law_policy.collection import Collection
from integration_tests.setup_db import setup_db
from integration_tests.setup_db import EXPECTED_NUM_COLLECTIONS, setup_db


def test_delete_collection(
Expand All @@ -14,7 +14,7 @@ def test_delete_collection(
)
assert response.status_code == status.HTTP_200_OK
n = test_db.query(Collection).count()
assert n == 1
assert n == EXPECTED_NUM_COLLECTIONS - 1


def test_delete_collection_when_not_authenticated(client: TestClient, test_db: Session):
Expand All @@ -37,7 +37,7 @@ def test_delete_collection_rollback(
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
n = test_db.query(Collection).count()
assert n == 2
assert n == EXPECTED_NUM_COLLECTIONS
assert rollback_collection_repo.delete.call_count == 1


Expand All @@ -52,7 +52,7 @@ def test_delete_collection_when_not_found(
data = response.json()
assert data["detail"] == "Collection not deleted: C.0.0.22"
n = test_db.query(Collection).count()
assert n == 2
assert n == EXPECTED_NUM_COLLECTIONS


def test_delete_collection_when_db_error(
Expand Down
46 changes: 22 additions & 24 deletions integration_tests/document/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,84 +2,82 @@
from fastapi import status
from sqlalchemy.orm import Session

from integration_tests.setup_db import EXPECTED_COLLECTIONS, setup_db
from integration_tests.setup_db import EXPECTED_DOCUMENTS, setup_db


# --- GET ALL


def test_get_all_collections(client: TestClient, test_db: Session, user_header_token):
def test_get_all_documents(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
response = client.get(
"/api/v1/collections",
"/api/v1/documents",
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert type(data) is list
assert len(data) == 2
ids_found = set([f["import_id"] for f in data])
expected_ids = set(["C.0.0.1", "C.0.0.2"])
expected_ids = set(["D.0.0.1", "D.0.0.2"])

assert ids_found.symmetric_difference(expected_ids) == set([])

sdata = sorted(data, key=lambda d: d["import_id"])
assert sdata[0] == EXPECTED_COLLECTIONS[0]
assert sdata[1] == EXPECTED_COLLECTIONS[1]
assert sdata[0] == EXPECTED_DOCUMENTS[0]
assert sdata[1] == EXPECTED_DOCUMENTS[1]


def test_get_all_collections_when_not_authenticated(
client: TestClient, test_db: Session
):
def test_get_all_documents_when_not_authenticated(client: TestClient, test_db: Session):
setup_db(test_db)
response = client.get(
"/api/v1/collections",
"/api/v1/documents",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED


# --- GET


def test_get_collection(client: TestClient, test_db: Session, user_header_token):
def test_get_document(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
response = client.get(
"/api/v1/collections/C.0.0.1",
"/api/v1/documents/D.0.0.1",
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["import_id"] == "C.0.0.1"
assert data == EXPECTED_COLLECTIONS[0]
assert data["import_id"] == "D.0.0.1"
assert data == EXPECTED_DOCUMENTS[0]


def test_get_collection_when_not_authenticated(client: TestClient, test_db: Session):
def test_get_document_when_not_authenticated(client: TestClient, test_db: Session):
setup_db(test_db)
response = client.get(
"/api/v1/collections/C.0.0.1",
"/api/v1/documents/D.0.0.1",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED


def test_get_collection_when_not_found(
def test_get_document_when_not_found(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
response = client.get(
"/api/v1/collections/C.0.0.8",
"/api/v1/documents/D.0.0.8",
headers=user_header_token,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert data["detail"] == "Collection not found: C.0.0.8"
assert data["detail"] == "Document not found: D.0.0.8"


def test_get_collection_when_id_invalid(
def test_get_document_when_id_invalid(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
response = client.get(
"/api/v1/collections/A008",
"/api/v1/documents/A008",
headers=user_header_token,
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
Expand All @@ -88,12 +86,12 @@ def test_get_collection_when_id_invalid(
assert data["detail"] == expected_msg


def test_get_collection_when_db_error(
client: TestClient, test_db: Session, bad_collection_repo, user_header_token
def test_get_document_when_db_error(
client: TestClient, test_db: Session, bad_document_repo, user_header_token
):
setup_db(test_db)
response = client.get(
"/api/v1/collections/A.0.0.8",
"/api/v1/documents/A.0.0.8",
headers=user_header_token,
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
Expand Down
24 changes: 12 additions & 12 deletions integration_tests/document/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from integration_tests.setup_db import setup_db


def test_search_collection(client: TestClient, test_db: Session, user_header_token):
def test_search_document(client: TestClient, test_db: Session, user_header_token):
setup_db(test_db)
response = client.get(
"/api/v1/collections/?q=big",
"/api/v1/documents/?q=big",
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK
Expand All @@ -17,38 +17,38 @@ def test_search_collection(client: TestClient, test_db: Session, user_header_tok
ids_found = set([f["import_id"] for f in data])
assert len(ids_found) == 1

expected_ids = set(["C.0.0.1"])
expected_ids = set(["D.0.0.1"])
assert ids_found.symmetric_difference(expected_ids) == set([])


def test_search_collection_when_not_authorised(client: TestClient, test_db: Session):
def test_search_document_when_not_authorised(client: TestClient, test_db: Session):
setup_db(test_db)
response = client.get(
"/api/v1/collections/?q=orange",
"/api/v1/documents/?q=orange",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED


def test_search_collection_when_nothing_found(
def test_search_document_when_nothing_found(
client: TestClient, test_db: Session, user_header_token
):
setup_db(test_db)
response = client.get(
"/api/v1/collections/?q=chicken",
"/api/v1/documents/?q=chicken",
headers=user_header_token,
)
assert response.status_code == status.HTTP_404_NOT_FOUND
data = response.json()
assert data["detail"] == "Collections not found for term: chicken"
assert data["detail"] == "Documents not found for term: chicken"


def test_search_collection_when_db_error(
client: TestClient, test_db: Session, bad_collection_repo, user_header_token
def test_search_document_when_db_error(
client: TestClient, test_db: Session, bad_document_repo, user_header_token
):
setup_db(test_db)
response = client.get(
"/api/v1/collections/?q=chicken",
"/api/v1/documents/?q=chicken",
headers=user_header_token,
)
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
assert bad_collection_repo.search.call_count == 1
assert bad_document_repo.search.call_count == 1
Loading

0 comments on commit 4192e4d

Please sign in to comment.