Skip to content

Commit

Permalink
Add multi geographies to search API (#266)
Browse files Browse the repository at this point in the history
* feat: add multi geographies test

* Update pyproject.toml

---------

Co-authored-by: Katy Baulch <[email protected]>
  • Loading branch information
jamesgorrie and katybaulch authored Dec 16, 2024
1 parent 812ff4a commit c8df79c
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 30 deletions.
12 changes: 9 additions & 3 deletions app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
"""

import logging
from typing import Annotated

from fastapi import APIRouter, HTTPException, Request, status
from fastapi import APIRouter, HTTPException, Query, Request, status

import app.service.family as family_service
from app.api.api_v1.query_params import (
Expand Down Expand Up @@ -74,7 +75,12 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]:


@r.get("/families/", response_model=list[FamilyReadDTO])
async def search_family(request: Request) -> list[FamilyReadDTO]:
async def search_family(
request: Request,
# We have used the built in parsers here for geography specifically
# so that we do not have to build our own
geography: Annotated[list[str] | None, Query()] = None,
) -> list[FamilyReadDTO]:
"""
Searches for families matching URL parameters ("q" by default).
Expand All @@ -96,7 +102,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:
validate_query_params(query_params, VALID_PARAMS)

try:
families = family_service.search(query_params, request.state.user)
families = family_service.search(query_params, request.state.user, geography)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
15 changes: 8 additions & 7 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]:


def search(
db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int]
db: Session,
search_params: dict[str, Union[str, int]],
org_id: Optional[int],
geography: Optional[list[str]],
) -> list[FamilyReadDTO]:
"""
Gets a list of families from the repository searching given fields.
Expand Down Expand Up @@ -195,13 +198,11 @@ def search(
term = f"%{escape_like(search_params['summary'])}%"
search.append(Family.description.ilike(term))

if "geography" in search_params.keys():
term = cast(str, search_params["geography"])
search.append(
or_(
Geography.display_value == term.title(), Geography.value == term.upper()
)
if geography is not None:
geography_filter = or_(
*[(Geography.display_value == g.title()) for g in geography]
)
search.append(geography_filter)

if "status" in search_params.keys():
term = cast(str, search_params["status"])
Expand Down
5 changes: 4 additions & 1 deletion app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]:

@staticmethod
def search(
db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int]
db: Session,
search_params: dict[str, Union[str, int]],
org_id: Optional[int],
geography: Optional[list[str]],
) -> list[FamilyReadDTO]:
"""Searches the families"""
...
Expand Down
7 changes: 5 additions & 2 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def all(user: UserContext) -> list[FamilyReadDTO]:

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search(
search_params: dict[str, Union[str, int]], user: UserContext
search_params: dict[str, Union[str, int]],
user: UserContext,
geography: Optional[list[str]] = None,
) -> list[FamilyReadDTO]:
"""
Searches for the search term against families on specified fields.
Expand All @@ -76,12 +78,13 @@ def search(
:param dict search_params: Search patterns to match against specified
fields, given as key value pairs in a dictionary.
:param UserContext user: The current user context.
:param Optional[list[str]] geography: geographies to filter on.
:return list[FamilyDTO]: The list of families matching the given
search terms.
"""
with db_session.get_db() as db:
org_id = app_user.restrict_entities_to_user_org(user)
return family_repo.search(db, search_params, org_id)
return family_repo.search(db, search_params, org_id, geography)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.17.22"
version = "2.17.23"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
58 changes: 52 additions & 6 deletions tests/integration_tests/family/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,74 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from tests.integration_tests.setup_db import setup_db
from tests.integration_tests.setup_db import add_data, setup_db


def test_search_geographies(
client: TestClient, data_db: Session, superuser_header_token
):
setup_db(data_db)
add_data(
data_db,
[
{
"import_id": "A.0.0.4",
"title": "title",
"summary": "gregarious magazine rub",
"geography": "ALB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug4",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
{
"import_id": "A.0.0.5",
"title": "title",
"summary": "flour umbrella established",
"geography": "ZMB",
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug5",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
"last_updated_date": "2018-12-24T04:59:33Z",
"documents": ["D.0.0.1", "D.0.0.2"],
"collections": ["C.0.0.4"],
},
],
)

tests_cases = [
("afghanistan", 2),
("zimbabwe", 1),
(["afghanistan"], ["A.0.0.1", "A.0.0.3"]),
(["zimbabwe"], ["A.0.0.2"]),
(["albania", "zambia"], ["A.0.0.4", "A.0.0.5"]),
]

for country, expected_count in tests_cases:
for countries, expected_ids in tests_cases:
geographies_query = "&".join([f"geography={country}" for country in countries])
response = client.get(
f"/api/v1/families/?geography={country}",
f"/api/v1/families/?{geographies_query}",
headers=superuser_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
ids = [result["import_id"] for result in data]
assert isinstance(data, list)
assert len(data) == expected_count
assert ids == expected_ids


def test_search_family_super(
Expand Down
46 changes: 39 additions & 7 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import TypedDict, cast

from db_client.models.dfce.collection import (
Collection,
Expand Down Expand Up @@ -26,7 +26,28 @@
from sqlalchemy import update
from sqlalchemy.orm import Session

EXPECTED_FAMILIES = [

class DBEntry(TypedDict):
import_id: str
title: str
summary: str
geography: str
category: str
status: str
metadata: dict
organisation: str
corpus_import_id: str
corpus_title: str
corpus_type: str
slug: str
events: list[str]
published_date: str | None
last_updated_date: str | None
documents: list[str]
collections: list[str]


EXPECTED_FAMILIES: list[DBEntry] = [
{
"import_id": "A.0.0.1",
"title": "apple",
Expand Down Expand Up @@ -280,6 +301,15 @@
]


def add_data(test_db: Session, data: list[DBEntry]):
org_id = test_db.query(Organisation).filter(Organisation.name == "CCLW").one().id
other_org_id = (
test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one().id
)
_setup_family_data(test_db, org_id, other_org_id, data)
test_db.commit()


def setup_db(test_db: Session, configure_empty: bool = False):
setup_test_data(test_db, configure_empty)

Expand Down Expand Up @@ -480,13 +510,15 @@ def _setup_family_data(
test_db: Session,
default_org_id: int,
other_org_id: int,
initial_data: list[DBEntry] = EXPECTED_FAMILIES,
configure_empty: bool = False,
) -> None:
if configure_empty is True:
return None

for index in range(EXPECTED_NUM_FAMILIES):
data = EXPECTED_FAMILIES[index]
num_families = len(initial_data)
for index in range(num_families):
data = initial_data[index]

geo_id = (
test_db.query(Geography.id)
Expand Down Expand Up @@ -524,8 +556,8 @@ def _setup_family_data(
)

# Now add the metadata onto the families
for index in range(EXPECTED_NUM_FAMILIES):
data = EXPECTED_FAMILIES[index]
for index in range(num_families):
data = initial_data[index]
test_db.add(
FamilyMetadata(
family_import_id=data["import_id"],
Expand All @@ -534,7 +566,7 @@ def _setup_family_data(
)
test_db.add(
Slug(
name=f"Slug{index+1}",
name=data["slug"],
family_import_id=data["import_id"],
)
)
Expand Down
7 changes: 6 additions & 1 deletion tests/mocks/repos/bad_family_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@ def mock_get_all(_):
def mock_get(_, import_id: str) -> Optional[FamilyReadDTO]:
raise RepositoryError("Bad Repo")

def mock_search(_, q: str, org_id: Optional[int]) -> list[FamilyReadDTO]:
def mock_search(
_,
q: str,
org_id: Optional[int],
geography: Optional[list[str]],
) -> list[FamilyReadDTO]:
raise RepositoryError("Bad Repo")

def mock_update(
Expand Down
5 changes: 4 additions & 1 deletion tests/mocks/repos/family_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]:


def search(
db: Session, search_params: dict[str, Union[str, int]], org_id: Optional[int]
db: Session,
search_params: dict[str, Union[str, int]],
org_id: Optional[int],
geography: Optional[list[str]],
) -> list[FamilyReadDTO]:
_maybe_throw()
_maybe_timeout()
Expand Down
4 changes: 3 additions & 1 deletion tests/mocks/services/family_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def mock_get_family(import_id: str) -> Optional[FamilyReadDTO]:
if not family_service.missing:
return create_family_read_dto(import_id, collections=["x.y.z.1", "x.y.z.2"])

def mock_search_families(q_params: dict, user_email: str) -> list[FamilyReadDTO]:
def mock_search_families(
q_params: dict, user_email: str, geography: Optional[list[str]]
) -> list[FamilyReadDTO]:
if q_params["q"] == "empty":
return []

Expand Down

0 comments on commit c8df79c

Please sign in to comment.