Skip to content

Commit

Permalink
feat: make integration test data more dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesgorrie committed Dec 9, 2024
1 parent 31dc0dd commit 8992703
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 47 deletions.
2 changes: 2 additions & 0 deletions app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]:
@r.get("/families/", response_model=list[FamilyReadDTO])
async def search_family(
request: Request,
# We have used the built in parsers here for geography specifically
# so that we do not have to build our own
geography: Annotated[list[str] | None, Query()] = None,
) -> list[FamilyReadDTO]:
"""
Expand Down
45 changes: 44 additions & 1 deletion tests/integration_tests/family/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,56 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from tests.integration_tests.setup_db import setup_db
from tests.integration_tests.setup_db import add_data, setup_db


def test_search_geographies(
client: TestClient, data_db: Session, superuser_header_token
):
setup_db(data_db)
add_data(
data_db,
[
{
"import_id": "A.0.0.4",
"title": "title",
"summary": "gregarious magazine rub",
"geography": "ALB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug4",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
{
"import_id": "A.0.0.5",
"title": "title",
"summary": "flour umbrella established",
"geography": "ZMB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug5",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
],
)

tests_cases = [
(["afghanistan"], ["A.0.0.1", "A.0.0.3"]),
Expand Down
84 changes: 39 additions & 45 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import TypedDict, cast

from db_client.models.dfce.collection import (
Collection,
Expand Down Expand Up @@ -26,7 +26,28 @@
from sqlalchemy import update
from sqlalchemy.orm import Session

EXPECTED_FAMILIES = [

class DBEntry(TypedDict):
import_id: str
title: str
summary: str
geography: str
category: str
status: str
metadata: dict
organisation: str
corpus_import_id: str
corpus_title: str
corpus_type: str
slug: str
events: list[str]
published_date: str | None
last_updated_date: str | None
documents: list[str]
collections: list[str]


EXPECTED_FAMILIES: list[DBEntry] = [
{
"import_id": "A.0.0.1",
"title": "apple",
Expand Down Expand Up @@ -98,44 +119,6 @@
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
{
"import_id": "A.0.0.4",
"title": "title",
"summary": "gregarious magazine rub",
"geography": "ALB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug3",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
{
"import_id": "A.0.0.5",
"title": "title",
"summary": "flour umbrella established",
"geography": "ZMB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug3",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
]
EXPECTED_NUM_FAMILIES = len(EXPECTED_FAMILIES)

Expand Down Expand Up @@ -318,6 +301,15 @@
]


def add_data(test_db: Session, data: list[DBEntry]):
org_id = test_db.query(Organisation).filter(Organisation.name == "CCLW").one().id
other_org_id = (
test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one().id
)
_setup_family_data(test_db, org_id, other_org_id, data)
test_db.commit()


def setup_db(test_db: Session, configure_empty: bool = False):
setup_test_data(test_db, configure_empty)

Expand Down Expand Up @@ -518,13 +510,15 @@ def _setup_family_data(
test_db: Session,
default_org_id: int,
other_org_id: int,
initial_data: list[DBEntry] = EXPECTED_FAMILIES,
configure_empty: bool = False,
) -> None:
if configure_empty is True:
return None

for index in range(EXPECTED_NUM_FAMILIES):
data = EXPECTED_FAMILIES[index]
num_families = len(initial_data)
for index in range(num_families):
data = initial_data[index]

geo_id = (
test_db.query(Geography.id)
Expand Down Expand Up @@ -562,8 +556,8 @@ def _setup_family_data(
)

# Now add the metadata onto the families
for index in range(EXPECTED_NUM_FAMILIES):
data = EXPECTED_FAMILIES[index]
for index in range(num_families):
data = initial_data[index]
test_db.add(
FamilyMetadata(
family_import_id=data["import_id"],
Expand All @@ -572,7 +566,7 @@ def _setup_family_data(
)
test_db.add(
Slug(
name=f"Slug{index+1}",
name=data["slug"],
family_import_id=data["import_id"],
)
)
Expand Down
7 changes: 6 additions & 1 deletion tests/mocks/repos/bad_family_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ def mock_get_all(_):
def mock_get(_, import_id: str) -> Optional[FamilyReadDTO]:
raise RepositoryError("Bad Repo")

def mock_search(_, q: str, org_id: Optional[int]) -> list[FamilyReadDTO]:
def mock_search(
_,
q: str,
org_id: Optional[int],
geography: list[str] | None,
) -> list[FamilyReadDTO]:
raise RepositoryError("Bad Repo")

def mock_update(
Expand Down

0 comments on commit 8992703

Please sign in to comment.