diff --git a/app/api/api_v1/query_params.py b/app/api/api_v1/query_params.py new file mode 100644 index 00000000..657288a8 --- /dev/null +++ b/app/api/api_v1/query_params.py @@ -0,0 +1,46 @@ +from typing import Union, cast + +from fastapi import HTTPException, status + + +def get_query_params_as_dict(query_params) -> dict[str, Union[str, int]]: + print(query_params) + return {k: query_params[k] for k in query_params.keys()} + + +def set_default_query_params( + query_params, + default_query_term: str = "", + default_max_results: int = 500, +) -> dict[str, Union[str, int]]: + query_fields = query_params.keys() + + if len(query_fields) < 1: + return {"q": default_query_term, "max_results": default_max_results} + + if "max_results" not in query_fields: + query_params["max_results"] = default_max_results + + return query_params + + +def validate_query_params( + query_params, valid_params: list[str] = ["q", "max_results"] +) -> bool: + query_fields = query_params.keys() + invalid_params = [x for x in query_fields if x not in valid_params] + if any(invalid_params): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Search parameters are invalid: {invalid_params}", + ) + + if not isinstance(query_params["max_results"], int): + try: + query_params.update({"max_results": cast(int, query_params["max_results"])}) + except Exception: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Maximum results must be an integer value", + ) + return True diff --git a/app/api/api_v1/routers/collection.py b/app/api/api_v1/routers/collection.py index b76ab7a8..9f8bbea1 100644 --- a/app/api/api_v1/routers/collection.py +++ b/app/api/api_v1/routers/collection.py @@ -1,14 +1,20 @@ """Endpoints for managing the Collection entity.""" import logging + from fastapi import APIRouter, HTTPException, Request, status -from app.errors import RepositoryError, ValidationError +import app.service.collection as collection_service +from app.api.api_v1.query_params import ( + get_query_params_as_dict, + set_default_query_params, + validate_query_params, +) +from app.errors import RepositoryError, ValidationError from app.model.collection import ( + CollectionCreateDTO, CollectionReadDTO, CollectionWriteDTO, - CollectionCreateDTO, ) -import app.service.collection as collection_service collections_router = r = APIRouter() @@ -69,27 +75,49 @@ async def get_all_collections() -> list[CollectionReadDTO]: "/collections/", response_model=list[CollectionReadDTO], ) -async def search_collection(q: str = "") -> list[CollectionReadDTO]: +async def search_collection(request: Request) -> list[CollectionReadDTO]: """ - Searches for collections matching the "q" URL parameter. - - :param str q: The string to match, defaults to "" - :raises HTTPException: If nothing found a 404 is returned. - :return list[CollectionDTO]: A list of matching collections. + Searches for collections matching URL parameters ("q" by default). + + :param Request request: The fields to match against and the values + to search for. Defaults to searching for "" in collection titles + and summaries. + :raises HTTPException: If invalid fields passed a 400 is returned. + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[CollectionReadDTO]: A list of matching collections + (which can be empty). """ + + query_params = get_query_params_as_dict(request.query_params) + + query_params = set_default_query_params(query_params) + + VALID_PARAMS = ["q", "max_results"] + validate_query_params(query_params, VALID_PARAMS) + try: - collections = collection_service.search(q) + collections = collection_service.search(query_params) + except ValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message ) - - if len(collections) == 0: + except TimeoutError: + msg = ( + "Request timed out fetching matching collections. Try adjusting your query." + ) + _LOGGER.error(msg) raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Collections not found for term: {q}", + status_code=status.HTTP_408_REQUEST_TIMEOUT, + detail=msg, ) + if len(collections) == 0: + _LOGGER.info(f"Collections not found for terms: {query_params}") + return collections diff --git a/app/api/api_v1/routers/document.py b/app/api/api_v1/routers/document.py index a3aa5ff0..25c06a9d 100644 --- a/app/api/api_v1/routers/document.py +++ b/app/api/api_v1/routers/document.py @@ -1,6 +1,14 @@ """Endpoints for managing the Document entity.""" import logging -from fastapi import APIRouter, HTTPException, status + +from fastapi import APIRouter, HTTPException, Request, status + +import app.service.document as document_service +from app.api.api_v1.query_params import ( + get_query_params_as_dict, + set_default_query_params, + validate_query_params, +) from app.errors import RepositoryError, ValidationError from app.model.document import ( DocumentCreateDTO, @@ -8,8 +16,6 @@ DocumentWriteDTO, ) -import app.service.document as document_service - document_router = r = APIRouter() _LOGGER = logging.getLogger(__name__) @@ -69,27 +75,45 @@ async def get_all_documents() -> list[DocumentReadDTO]: "/documents/", response_model=list[DocumentReadDTO], ) -async def search_document(q: str = "") -> list[DocumentReadDTO]: +async def search_document(request: Request) -> list[DocumentReadDTO]: """ - Searches for documents matching the "q" URL parameter. - - :param str q: The string to match, defaults to "" - :raises HTTPException: If nothing found a 404 is returned. - :return list[DocumentDTO]: A list of matching documents. + Searches for documents matching URL parameters ("q" by default). + + :param Request request: The fields to match against and the values + to search for. Defaults to searching for "" in document titles. + :raises HTTPException: If invalid fields passed a 400 is returned. + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[DocumentReadDTO]: A list of matching documents (which + can be empty). """ + query_params = get_query_params_as_dict(request.query_params) + + query_params = set_default_query_params(query_params) + + VALID_PARAMS = ["q", "max_results"] + validate_query_params(query_params, VALID_PARAMS) + try: - documents = document_service.search(q) + documents = document_service.search(query_params) + except ValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message ) - - if len(documents) == 0: + except TimeoutError: + msg = "Request timed out fetching matching documents. Try adjusting your query." + _LOGGER.error(msg) raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Documents not found for term: {q}", + status_code=status.HTTP_408_REQUEST_TIMEOUT, + detail=msg, ) + if len(documents) == 0: + _LOGGER.info(f"Documents not found for terms: {query_params}") + return documents diff --git a/app/api/api_v1/routers/event.py b/app/api/api_v1/routers/event.py index b1f2699e..d1aefebc 100644 --- a/app/api/api_v1/routers/event.py +++ b/app/api/api_v1/routers/event.py @@ -1,9 +1,14 @@ """Endpoints for managing Family Event entities.""" import logging -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, Request, status import app.service.event as event_service +from app.api.api_v1.query_params import ( + get_query_params_as_dict, + set_default_query_params, + validate_query_params, +) from app.errors import RepositoryError, ValidationError from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO @@ -37,27 +42,46 @@ async def get_all_events() -> list[EventReadDTO]: "/events/", response_model=list[EventReadDTO], ) -async def search_event(q: str = "") -> list[EventReadDTO]: +async def search_event(request: Request) -> list[EventReadDTO]: """ - Searches for family events matching the "q" URL parameter. - - :param str q: The string to match, defaults to "" - :raises HTTPException: If nothing found a 404 is returned. - :return list[EventDTO]: A list of matching events. + Searches for family events matching URL parameters ("q" by default). + + :param Request request: The fields to match against and the values + to search for. Defaults to searching for "" in event titles and + type names. + :raises HTTPException: If invalid fields passed a 400 is returned. + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[EventReadDTO]: A list of matching events (which can be + empty). """ + query_params = get_query_params_as_dict(request.query_params) + + query_params = set_default_query_params(query_params) + + VALID_PARAMS = ["q", "max_results"] + validate_query_params(query_params, VALID_PARAMS) + try: - events_found = event_service.search(q) + events_found = event_service.search(query_params) + except ValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message ) - - if not events_found: + except TimeoutError: + msg = "Request timed out fetching matching events. Try adjusting your query." + _LOGGER.error(msg) raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Events not found for term: {q}", + status_code=status.HTTP_408_REQUEST_TIMEOUT, + detail=msg, ) + if len(events_found) == 0: + _LOGGER.info(f"Events not found for terms: {query_params}") + return events_found diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 8d23bbfa..ff20a045 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -7,11 +7,15 @@ implemented directly accesses the "repository" layer. """ import logging -from typing import Union, cast from fastapi import APIRouter, HTTPException, Request, status import app.service.family as family_service +from app.api.api_v1.query_params import ( + get_query_params_as_dict, + set_default_query_params, + validate_query_params, +) from app.errors import RepositoryError, ValidationError from app.model.family import FamilyCreateDTO, FamilyReadDTO, FamilyWriteDTO @@ -77,46 +81,18 @@ async def search_family(request: Request) -> list[FamilyReadDTO]: to search for. Defaults to searching for "" in family titles and summaries. :raises HTTPException: If invalid fields passed a 400 is returned. - :raises HTTPException: If nothing found a 404 is returned. - :return list[FamilyDTO]: A list of matching families. + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[FamilyDTO]: A list of matching families (which can be + empty). """ - query_params: dict[str, Union[str, int]] = { - k: request.query_params[k] for k in request.query_params.keys() - } - - query_fields = query_params.keys() - if len(query_fields) < 1: - query_params = {"q": ""} - - VALID_PARAMS = ["q", "title", "description", "geography", "status", "max_results"] - invalid_params = [x for x in query_fields if x not in VALID_PARAMS] - if any(invalid_params): - msg = f"Search parameters are invalid: {invalid_params}" - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=msg, - ) + query_params = get_query_params_as_dict(request.query_params) + + query_params = set_default_query_params(query_params) - if "q" in query_fields: - if "title" in query_fields: - query_params.pop("title") - if "description" in query_fields: - query_params.pop("description") - - DEFAULT_MAX_RESULTS = 500 - if "max_results" not in query_fields: - query_params["max_results"] = DEFAULT_MAX_RESULTS - else: - if not isinstance(query_params["max_results"], int): - try: - query_params.update( - {"max_results": cast(int, query_params["max_results"])} - ) - except Exception: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Maximum results must be an integer value", - ) + VALID_PARAMS = ["q", "title", "summary", "geography", "status", "max_results"] + validate_query_params(query_params, VALID_PARAMS) try: families = family_service.search(query_params) diff --git a/app/repository/collection.py b/app/repository/collection.py index 34b38565..604e14f2 100644 --- a/app/repository/collection.py +++ b/app/repository/collection.py @@ -2,12 +2,12 @@ import logging from datetime import datetime -from typing import Optional, Tuple, cast +from typing import Optional, Tuple, Union, cast -from sqlalchemy import Column, or_ +from sqlalchemy import Column, and_, or_ from sqlalchemy import delete as db_delete from sqlalchemy import update as db_update -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import NoResultFound, OperationalError from sqlalchemy.orm import Query, Session from sqlalchemy_utils import escape_like @@ -121,17 +121,36 @@ def get(db: Session, import_id: str) -> Optional[CollectionReadDTO]: return _collection_to_dto(db, collection_org) -def search(db: Session, search_term: str) -> list[CollectionReadDTO]: +def search( + db: Session, query_params: dict[str, Union[str, int]] +) -> list[CollectionReadDTO]: """ - Gets a list of collections from the repository searching title and summary. + Gets a list of collections from the repo searching given fields. :param db Session: the database connection - :param str search_term: Any search term to filter on title or summary - :return Optional[list[CollectionResponse]]: A list of matches + :param dict query_params: Any search terms to filter on specified + fields (title & summary by default if 'q' specified). + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[CollectionResponse]: A list of matching collections. """ - term = f"%{escape_like(search_term)}%" - search = or_(Collection.title.ilike(term), Collection.description.ilike(term)) - found = _get_query(db).filter(search).all() + search = [] + if "q" in query_params.keys(): + term = f"%{escape_like(query_params['q'])}%" + search.append( + or_(Collection.title.ilike(term), Collection.description.ilike(term)) + ) + + condition = and_(*search) if len(search) > 1 else search[0] + try: + found = ( + _get_query(db).filter(condition).limit(query_params["max_results"]).all() + ) + except OperationalError as e: + if "canceling statement due to statement timeout" in str(e): + raise TimeoutError + raise RepositoryError(e) return [_collection_to_dto(db, f) for f in found] diff --git a/app/repository/document.py b/app/repository/document.py index ba95d7b3..928b6e1c 100644 --- a/app/repository/document.py +++ b/app/repository/document.py @@ -1,10 +1,10 @@ import logging -from typing import Optional, Tuple, cast +from typing import Optional, Tuple, Union, cast -from sqlalchemy import Column, and_, func, or_ +from sqlalchemy import Column, and_, func from sqlalchemy import insert as db_insert from sqlalchemy import update as db_update -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import NoResultFound, OperationalError from sqlalchemy.orm import Query, Session, aliased from sqlalchemy.sql.functions import concat from sqlalchemy_utils import escape_like @@ -161,17 +161,34 @@ def get(db: Session, import_id: str) -> Optional[DocumentReadDTO]: return DocumentReadDTO(**dict(result)) -def search(db: Session, search_term: str) -> list[DocumentReadDTO]: +def search( + db: Session, query_params: dict[str, Union[str, int]] +) -> list[DocumentReadDTO]: """ - Gets a list of documents from the repository searching title and summary. + Gets a list of documents from the repository searching the title. :param db Session: the database connection - :param str search_term: Any search term to filter on title or summary - :return Optional[list[DocumentResponse]]: A list of matches + :param dict query_params: Any search terms to filter on specified + fields (title by default if 'q' specified). + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[DocumentResponse]: A list of matching documents. """ - term = f"%{escape_like(search_term)}%" - search = or_(PhysicalDocument.title.ilike(term)) - result = _get_query(db).filter(search).all() + search = [] + if "q" in query_params.keys(): + term = f"%{escape_like(query_params['q'])}%" + search.append(PhysicalDocument.title.ilike(term)) + + condition = and_(*search) if len(search) > 1 else search[0] + try: + result = ( + _get_query(db).filter(condition).limit(query_params["max_results"]).all() + ) + except OperationalError as e: + if "canceling statement due to statement timeout" in str(e): + raise TimeoutError + raise RepositoryError(e) return [DocumentReadDTO(**dict(r)) for r in result] diff --git a/app/repository/event.py b/app/repository/event.py index 47ec0082..7e1d67b9 100644 --- a/app/repository/event.py +++ b/app/repository/event.py @@ -2,12 +2,12 @@ import logging from datetime import datetime -from typing import Optional, Tuple, cast +from typing import Optional, Tuple, Union, cast -from sqlalchemy import Column, or_ +from sqlalchemy import Column, and_, or_ from sqlalchemy import delete as db_delete from sqlalchemy import update as db_update -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import NoResultFound, OperationalError from sqlalchemy.orm import Query, Session from sqlalchemy_utils import escape_like @@ -114,27 +114,34 @@ def get(db: Session, import_id: str) -> Optional[EventReadDTO]: return _event_to_dto(family_event_meta) -def search(db: Session, search_term: str) -> Optional[list[EventReadDTO]]: +def search(db: Session, query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]: """ Get family events matching a search term on the event title or type. :param db Session: The database connection. - :param str search_term: Any search term to filter on the event title - or event type name. - :return Optional[list[EventReadDTO]]: A list of matching family - events or none. + :param dict query_params: Any search terms to filter on specified + fields (title & event type name by default if 'q' specified). + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[EventReadDTO]: A list of matching family events. """ - term = f"%{escape_like(search_term)}%" - search = or_(FamilyEvent.title.ilike(term), FamilyEvent.event_type_name.ilike(term)) + search = [] + if "q" in query_params.keys(): + term = f"%{escape_like(query_params['q'])}%" + search.append( + or_(FamilyEvent.title.ilike(term), FamilyEvent.event_type_name.ilike(term)) + ) + condition = and_(*search) if len(search) > 1 else search[0] try: - found = _get_query(db).filter(search).all() - except NoResultFound as e: - _LOGGER.error(e) - return - - if not found: - return [] + found = ( + _get_query(db).filter(condition).limit(query_params["max_results"]).all() + ) + except OperationalError as e: + if "canceling statement due to statement timeout" in str(e): + raise TimeoutError + raise RepositoryError(e) return [_event_to_dto(f) for f in found] diff --git a/app/repository/family.py b/app/repository/family.py index 8bc92139..27e8b1e5 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -167,22 +167,26 @@ def search( Gets a list of families from the repository searching given fields. :param db Session: the database connection - :param dict[str, str] query_params: Any search terms to filter on - specified fields (title & summary by default if 'q' specified). - :return list[FamilyResponse]: A list of matches + :param dict query_params: Any search terms to filter on specified + fields (title & summary by default if 'q' specified). + :raises HTTPException: If a DB error occurs a 503 is returned. + :raises HTTPException: If the search request times out a 408 is + returned. + :return list[FamilyReadDTO]: A list of families matching the search + terms. """ search = [] if "q" in query_params.keys(): term = f"%{escape_like(query_params['q'])}%" search.append(or_(Family.title.ilike(term), Family.description.ilike(term))) - - if "title" in query_params.keys(): - term = f"%{escape_like(query_params['title'])}%" - search.append(Family.title.ilike(term)) - - if "description" in query_params.keys(): - term = f"%{escape_like(query_params['description'])}%" - search.append(Family.description.ilike(term)) + else: + if "title" in query_params.keys(): + term = f"%{escape_like(query_params['title'])}%" + search.append(Family.title.ilike(term)) + + if "summary" in query_params.keys(): + term = f"%{escape_like(query_params['summary'])}%" + search.append(Family.description.ilike(term)) if "geography" in query_params.keys(): term = cast(str, query_params["geography"]) diff --git a/app/service/collection.py b/app/service/collection.py index dc3f46e9..590d6dd3 100644 --- a/app/service/collection.py +++ b/app/service/collection.py @@ -5,9 +5,13 @@ services for validation etc. """ import logging -from typing import Optional +from typing import Optional, Union from pydantic import ConfigDict, validate_call +from sqlalchemy import exc +from sqlalchemy.orm import Session + +import app.clients.db.session as db_session from app.errors import RepositoryError from app.model.collection import ( CollectionCreateDTO, @@ -15,13 +19,7 @@ CollectionWriteDTO, ) from app.repository import collection_repo -import app.clients.db.session as db_session -from sqlalchemy import exc -from sqlalchemy.orm import Session - -from app.service import id -from app.service import app_user - +from app.service import app_user, id _LOGGER = logging.getLogger(__name__) @@ -61,19 +59,21 @@ def all() -> list[CollectionReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(search_term: str) -> list[CollectionReadDTO]: +def search(query_params: dict[str, Union[str, int]]) -> list[CollectionReadDTO]: """ - Searches the title and descriptions of all the collections for the search term. + Searches for the search term against collections on specified fields. + + Where 'q' is used instead of an explicit field name, the titles and + descriptions of all the collections are searched for the given term + only. - :param str search_term: Search pattern to match. - :return list[CollectionDTO]: The list of collections matching the search term. + :param dict query_params: Search patterns to match against specified + fields, given as key value pairs in a dictionary. + :return list[CollectionReadDTO]: The list of collections matching + the given search terms. """ - try: - with db_session.get_db() as db: - return collection_repo.search(db, search_term) - except exc.SQLAlchemyError: - _LOGGER.exception(f"When searching for collections with '{search_term}'") - raise RepositoryError(f"When searching for collections with '{search_term}'") + with db_session.get_db() as db: + return collection_repo.search(db, query_params) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/document.py b/app/service/document.py index dfc5522e..c43c2ebf 100644 --- a/app/service/document.py +++ b/app/service/document.py @@ -1,18 +1,18 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from pydantic import ConfigDict, validate_call -from app.clients.aws.client import get_s3_client -from app.errors import RepositoryError, ValidationError -from app.model.document import DocumentCreateDTO, DocumentReadDTO, DocumentWriteDTO +from sqlalchemy import exc +from sqlalchemy.orm import Session + +import app.clients.db.session as db_session import app.repository.document as document_repo import app.repository.document_file as file_repo -import app.clients.db.session as db_session import app.service.family as family_service +from app.clients.aws.client import get_s3_client +from app.errors import RepositoryError, ValidationError +from app.model.document import DocumentCreateDTO, DocumentReadDTO, DocumentWriteDTO from app.service import id -from sqlalchemy import exc -from sqlalchemy.orm import Session - _LOGGER = logging.getLogger(__name__) @@ -55,15 +55,20 @@ def all() -> list[DocumentReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(search_term: str) -> list[DocumentReadDTO]: +def search(query_params: dict[str, Union[str, int]]) -> list[DocumentReadDTO]: """ - Searches the title and descriptions of all the documents for the search term. + Searches for the search term against documents on specified fields. + + Where 'q' is used instead of an explicit field name, only the titles + of all the documents are searched for the given term. - :param str search_term: Search pattern to match. - :return list[documentDTO]: The list of documents matching the search term. + :param dict query_params: Search patterns to match against specified + fields, given as key value pairs in a dictionary. + :return list[DocumentReadDTO]: The list of documents matching the + given search terms. """ with db_session.get_db() as db: - return document_repo.search(db, search_term) + return document_repo.search(db, query_params) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/event.py b/app/service/event.py index 1dd8b0fb..30bf9d7b 100644 --- a/app/service/event.py +++ b/app/service/event.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Optional, Union from pydantic import ConfigDict, validate_call from sqlalchemy import exc @@ -12,7 +12,6 @@ from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from app.service import id - _LOGGER = logging.getLogger(__name__) @@ -47,19 +46,21 @@ def all() -> list[EventReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(search_term: str) -> Optional[list[EventReadDTO]]: +def search(query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]: """ - Search for all family events that match a search term. + Searches for the search term against events on specified fields. - Specifically searches the event title and event type name for the - search term. + Where 'q' is used instead of an explicit field name, the titles and + event type names of all the events are searched for the given term + only. - :param str search_term: Search pattern to match. - :return Optional[list[EventReadDTO]] The list of events that match - the search term or none. + :param dict query_params: Search patterns to match against specified + fields, given as key value pairs in a dictionary. + :return list[EventReadDTO]: The list of events matching the given + search terms. """ with db_session.get_db() as db: - return event_repo.search(db, search_term) + return event_repo.search(db, query_params) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/family.py b/app/service/family.py index b724a06f..e8c70a55 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -66,10 +66,10 @@ def search(query_params: dict[str, Union[str, int]]) -> list[FamilyReadDTO]: descriptions of all the Families are searched for the given term only. - :param str query_params: Search patterns to match against specified - fields. - :return list[FamilyDTO]: The list of families matching the search - term. + :param dict query_params: Search patterns to match against specified + fields, given as key value pairs in a dictionary. + :return list[FamilyDTO]: The list of families matching the given + search terms. """ with db_session.get_db() as db: return family_repo.search(db, query_params) diff --git a/integration_tests/collection/test_search.py b/integration_tests/collection/test_search.py index c1130655..f14003d8 100644 --- a/integration_tests/collection/test_search.py +++ b/integration_tests/collection/test_search.py @@ -1,13 +1,16 @@ -from fastapi.testclient import TestClient +import logging + from fastapi import status +from fastapi.testclient import TestClient from sqlalchemy.orm import Session + from integration_tests.setup_db import setup_db def test_search_collection(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( - "/api/v1/collections/?q=big", + "/api/v1/collections/?q=description", headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK @@ -15,9 +18,9 @@ def test_search_collection(client: TestClient, test_db: Session, user_header_tok assert type(data) is list ids_found = set([f["import_id"] for f in data]) - assert len(ids_found) == 1 + assert len(ids_found) == 3 - expected_ids = set(["C.0.0.1"]) + expected_ids = set(["C.0.0.1", "C.0.0.2", "C.0.0.3"]) assert ids_found.symmetric_difference(expected_ids) == set([]) @@ -30,16 +33,19 @@ def test_search_collection_when_not_authorised(client: TestClient, test_db: Sess def test_search_collection_when_nothing_found( - client: TestClient, test_db: Session, user_header_token + client: TestClient, test_db: Session, user_header_token, caplog ): setup_db(test_db) - response = client.get( - "/api/v1/collections/?q=chicken", - headers=user_header_token, + with caplog.at_level(logging.INFO): + response = client.get( + "/api/v1/collections/?q=chicken", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + assert ( + "Collections not found for terms: {'q': 'chicken', 'max_results': 500}" + in caplog.text ) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert data["detail"] == "Collections not found for term: chicken" def test_search_collection_when_db_error( @@ -51,4 +57,38 @@ def test_search_collection_when_db_error( headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "Bad Repo" assert bad_collection_repo.search.call_count == 1 + + +def test_search_collections_with_max_results( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/collections/?q=description&max_results=1", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert type(data) is list + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + expected_ids = set(["C.0.0.1"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + +def test_search_collections_when_invalid_params( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/collections/?wrong=param", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["detail"] == "Search parameters are invalid: ['wrong']" diff --git a/integration_tests/document/test_search.py b/integration_tests/document/test_search.py index 01cc2cd4..a482a3f1 100644 --- a/integration_tests/document/test_search.py +++ b/integration_tests/document/test_search.py @@ -1,13 +1,16 @@ -from fastapi.testclient import TestClient +import logging + from fastapi import status +from fastapi.testclient import TestClient from sqlalchemy.orm import Session + from integration_tests.setup_db import setup_db def test_search_document(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( - "/api/v1/documents/?q=big", + "/api/v1/documents/?q=title", headers=user_header_token, ) assert response.status_code == status.HTTP_200_OK @@ -15,9 +18,9 @@ def test_search_document(client: TestClient, test_db: Session, user_header_token assert type(data) is list ids_found = set([f["import_id"] for f in data]) - assert len(ids_found) == 1 + assert len(ids_found) == 2 - expected_ids = set(["D.0.0.1"]) + expected_ids = set(["D.0.0.1", "D.0.0.2"]) assert ids_found.symmetric_difference(expected_ids) == set([]) @@ -30,16 +33,19 @@ def test_search_document_when_not_authorised(client: TestClient, test_db: Sessio def test_search_document_when_nothing_found( - client: TestClient, test_db: Session, user_header_token + client: TestClient, test_db: Session, user_header_token, caplog ): setup_db(test_db) - response = client.get( - "/api/v1/documents/?q=chicken", - headers=user_header_token, + with caplog.at_level(logging.INFO): + response = client.get( + "/api/v1/documents/?q=chicken", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + assert ( + "Documents not found for terms: {'q': 'chicken', 'max_results': 500}" + in caplog.text ) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert data["detail"] == "Documents not found for term: chicken" def test_search_document_when_db_error( @@ -51,4 +57,38 @@ def test_search_document_when_db_error( headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "Bad Repo" assert bad_document_repo.search.call_count == 1 + + +def test_search_document_with_max_results( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/documents/?q=title&max_results=1", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert type(data) is list + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + expected_ids = set(["D.0.0.1"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + +def test_search_document_when_invalid_params( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/documents/?wrong=param", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["detail"] == "Search parameters are invalid: ['wrong']" diff --git a/integration_tests/event/test_search.py b/integration_tests/event/test_search.py index 6e9e639d..36d32a88 100644 --- a/integration_tests/event/test_search.py +++ b/integration_tests/event/test_search.py @@ -1,6 +1,9 @@ -from fastapi.testclient import TestClient +import logging + from fastapi import status +from fastapi.testclient import TestClient from sqlalchemy.orm import Session + from integration_tests.setup_db import setup_db @@ -30,19 +33,21 @@ def test_search_event_when_not_authorised(client: TestClient, test_db: Session): def test_search_event_when_nothing_found( - client: TestClient, test_db: Session, user_header_token + client: TestClient, test_db: Session, user_header_token, caplog ): setup_db(test_db) - response = client.get( - "/api/v1/events/?q=lemon", - headers=user_header_token, + with caplog.at_level(logging.INFO): + response = client.get( + "/api/v1/events/?q=lemon", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + assert ( + "Events not found for terms: {'q': 'lemon', 'max_results': 500}" in caplog.text ) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() - assert data["detail"] == "Events not found for term: lemon" -def test_search_event_when_db_error( +def test_search_document_when_db_error( client: TestClient, test_db: Session, bad_event_repo, user_header_token ): setup_db(test_db) @@ -51,4 +56,38 @@ def test_search_event_when_db_error( headers=user_header_token, ) assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "Bad Repo" assert bad_event_repo.search.call_count == 1 + + +def test_search_document_with_max_results( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/events/?q=Amended&max_results=1", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert type(data) is list + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + expected_ids = set(["E.0.0.2"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + +def test_search_document_when_invalid_params( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/events/?wrong=param", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["detail"] == "Search parameters are invalid: ['wrong']" diff --git a/integration_tests/family/test_search.py b/integration_tests/family/test_search.py index ff3371f9..f5eca59c 100644 --- a/integration_tests/family/test_search.py +++ b/integration_tests/family/test_search.py @@ -7,7 +7,7 @@ from integration_tests.setup_db import setup_db -def test_search_family(client: TestClient, test_db: Session, user_header_token): +def test_search_family_using_q(client: TestClient, test_db: Session, user_header_token): setup_db(test_db) response = client.get( "/api/v1/families/?q=orange", @@ -24,6 +24,25 @@ def test_search_family(client: TestClient, test_db: Session, user_header_token): assert ids_found.symmetric_difference(expected_ids) == set([]) +def test_search_family_with_specific_param( + client: TestClient, test_db: Session, user_header_token +): + setup_db(test_db) + response = client.get( + "/api/v1/families/?summary=apple", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert type(data) is list + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + expected_ids = set(["A.0.0.2"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + def test_search_family_with_max_results( client: TestClient, test_db: Session, user_header_token ): diff --git a/integration_tests/family/test_update.py b/integration_tests/family/test_update.py index 1e3fd05f..3ebc907c 100644 --- a/integration_tests/family/test_update.py +++ b/integration_tests/family/test_update.py @@ -218,7 +218,7 @@ def test_update_family_when_user_org_different_to_family_org( test_db.query(Family).filter(Family.import_id == "A.0.0.2").one() ) assert db_family.title == "apple orange banana" - assert db_family.description == "" + assert db_family.description == "apple" assert db_family.geography_id == 1 assert db_family.family_category == "UNFCCC" diff --git a/integration_tests/setup_db.py b/integration_tests/setup_db.py index fc4490c8..81e0c783 100644 --- a/integration_tests/setup_db.py +++ b/integration_tests/setup_db.py @@ -49,7 +49,7 @@ { "import_id": "A.0.0.2", "title": "apple orange banana", - "summary": "", + "summary": "apple", "geography": "South Asia", "category": "UNFCCC", "status": "Created", diff --git a/unit_tests/mocks/repos/collection_repo.py b/unit_tests/mocks/repos/collection_repo.py index a90a5f37..d4e81254 100644 --- a/unit_tests/mocks/repos/collection_repo.py +++ b/unit_tests/mocks/repos/collection_repo.py @@ -1,7 +1,9 @@ from typing import Optional -from pytest import MonkeyPatch +from pytest import MonkeyPatch from sqlalchemy import exc + +from app.errors import RepositoryError from app.model.collection import CollectionReadDTO from unit_tests.helpers.collection import create_collection_read_dto @@ -14,11 +16,16 @@ def mock_collection_repo(collection_repo, monkeypatch: MonkeyPatch, mocker): collection_repo.invalid_org = False collection_repo.missing = False collection_repo.throw_repository_error = False + collection_repo.throw_timeout_error = False collection_repo.alternative_org = False def maybe_throw(): if collection_repo.throw_repository_error: - raise exc.SQLAlchemyError("bad repo") + raise RepositoryError("bad repo") + + def maybe_timeout(): + if collection_repo.throw_timeout_error: + raise TimeoutError def mock_get_all(_) -> list[CollectionReadDTO]: return [ @@ -32,6 +39,7 @@ def mock_get(_, import_id: str) -> Optional[CollectionReadDTO]: def mock_search(_, q: str) -> list[CollectionReadDTO]: maybe_throw() + maybe_timeout() if not collection_repo.return_empty: return [create_collection_read_dto("search1")] return [] diff --git a/unit_tests/mocks/repos/document_repo.py b/unit_tests/mocks/repos/document_repo.py index 18700254..6a2209c3 100644 --- a/unit_tests/mocks/repos/document_repo.py +++ b/unit_tests/mocks/repos/document_repo.py @@ -1,7 +1,9 @@ from typing import Optional -from pytest import MonkeyPatch +from pytest import MonkeyPatch from sqlalchemy import exc + +from app.errors import RepositoryError from app.model.document import DocumentCreateDTO, DocumentReadDTO from unit_tests.helpers.document import create_document_read_dto @@ -9,10 +11,15 @@ def mock_document_repo(document_repo, monkeypatch: MonkeyPatch, mocker): document_repo.return_empty = False document_repo.throw_repository_error = False + document_repo.throw_timeout_error = False def maybe_throw(): if document_repo.throw_repository_error: - raise exc.SQLAlchemyError("bad repo") + raise RepositoryError("bad repo") + + def maybe_timeout(): + if document_repo.throw_timeout_error: + raise TimeoutError def mock_get_all(_) -> list[DocumentReadDTO]: values = [] @@ -27,6 +34,7 @@ def mock_get(_, import_id: str) -> Optional[DocumentReadDTO]: def mock_search(_, q: str) -> list[DocumentReadDTO]: maybe_throw() + maybe_timeout() if not document_repo.return_empty: return [create_document_read_dto("search1")] return [] diff --git a/unit_tests/mocks/repos/event_repo.py b/unit_tests/mocks/repos/event_repo.py index ffc2cabc..624da206 100644 --- a/unit_tests/mocks/repos/event_repo.py +++ b/unit_tests/mocks/repos/event_repo.py @@ -1,7 +1,9 @@ from typing import Optional -from pytest import MonkeyPatch +from pytest import MonkeyPatch from sqlalchemy import exc + +from app.errors import RepositoryError from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from unit_tests.helpers.event import create_event_read_dto @@ -9,10 +11,15 @@ def mock_event_repo(event_repo, monkeypatch: MonkeyPatch, mocker): event_repo.return_empty = False event_repo.throw_repository_error = False + event_repo.throw_timeout_error = False def maybe_throw(): if event_repo.throw_repository_error: - raise exc.SQLAlchemyError("bad repo") + raise RepositoryError("bad repo") + + def maybe_timeout(): + if event_repo.throw_timeout_error: + raise TimeoutError def mock_get_all(_) -> list[EventReadDTO]: values = [] @@ -25,8 +32,9 @@ def mock_get(_, import_id: str) -> Optional[EventReadDTO]: dto = create_event_read_dto(import_id) return dto - def mock_search(_, q: str) -> list[EventReadDTO]: + def mock_search(_, q: dict) -> list[EventReadDTO]: maybe_throw() + maybe_timeout() if not event_repo.return_empty: return [create_event_read_dto("search1")] return [] diff --git a/unit_tests/mocks/repos/family_repo.py b/unit_tests/mocks/repos/family_repo.py index 943b3746..ee140048 100644 --- a/unit_tests/mocks/repos/family_repo.py +++ b/unit_tests/mocks/repos/family_repo.py @@ -35,7 +35,9 @@ def search( _maybe_timeout() if getattr(family_repo, "return_empty"): return [] - return [create_family_dto("search1")] + if "title" in query_params.keys(): + return [create_family_dto("search1")] + return [create_family_dto("search1"), create_family_dto("search2")] def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> bool: diff --git a/unit_tests/mocks/services/collection_service.py b/unit_tests/mocks/services/collection_service.py index f37f7f7f..34ccf880 100644 --- a/unit_tests/mocks/services/collection_service.py +++ b/unit_tests/mocks/services/collection_service.py @@ -1,7 +1,8 @@ from typing import Optional + from pytest import MonkeyPatch -from app.errors import RepositoryError +from app.errors import RepositoryError from app.model.collection import CollectionReadDTO, CollectionWriteDTO from unit_tests.helpers.collection import create_collection_read_dto @@ -12,12 +13,17 @@ def mock_collection_service(collection_service, monkeypatch: MonkeyPatch, mocker): collection_service.missing = False collection_service.throw_repository_error = False + collection_service.throw_timeout_error = False collection_service.invalid_org = False def maybe_throw(): if collection_service.throw_repository_error: raise RepositoryError("bad repo") + def maybe_timeout(): + if collection_service.throw_timeout_error: + raise TimeoutError + def mock_get_all_collections(): maybe_throw() return [create_collection_read_dto("test")] @@ -27,12 +33,12 @@ def mock_get_collection(import_id: str) -> Optional[CollectionReadDTO]: if not collection_service.missing: return create_collection_read_dto(import_id) - def mock_search_collections(q: str) -> list[CollectionReadDTO]: + def mock_search_collections(q_params: dict) -> list[CollectionReadDTO]: maybe_throw() + maybe_timeout() if collection_service.missing: return [] - else: - return [create_collection_read_dto("search1")] + return [create_collection_read_dto("search1")] def mock_update_collection( import_id: str, data: CollectionWriteDTO diff --git a/unit_tests/mocks/services/document_service.py b/unit_tests/mocks/services/document_service.py index 7e7e72ed..c69d4b93 100644 --- a/unit_tests/mocks/services/document_service.py +++ b/unit_tests/mocks/services/document_service.py @@ -1,7 +1,8 @@ from typing import Optional + from pytest import MonkeyPatch -from app.errors import RepositoryError, ValidationError +from app.errors import RepositoryError, ValidationError from app.model.document import DocumentCreateDTO, DocumentReadDTO, DocumentWriteDTO from unit_tests.helpers.document import create_document_read_dto @@ -10,11 +11,16 @@ def mock_document_service(document_service, monkeypatch: MonkeyPatch, mocker): document_service.missing = False document_service.throw_repository_error = False document_service.throw_validation_error = False + document_service.throw_timeout_error = False def maybe_throw(): if document_service.throw_repository_error: raise RepositoryError("bad repo") + def maybe_timeout(): + if document_service.throw_timeout_error: + raise TimeoutError + def mock_get_all_documents() -> list[DocumentReadDTO]: maybe_throw() return [create_document_read_dto("test")] @@ -24,12 +30,13 @@ def mock_get_document(import_id: str) -> Optional[DocumentReadDTO]: if not document_service.missing: return create_document_read_dto(import_id) - def mock_search_documents(q: str) -> list[DocumentReadDTO]: - maybe_throw() + def mock_search_documents(q_params: dict) -> list[DocumentReadDTO]: if document_service.missing: return [] - else: - return [create_document_read_dto("search1")] + + maybe_throw() + maybe_timeout() + return [create_document_read_dto("search1")] def mock_update_document( import_id: str, data: DocumentWriteDTO diff --git a/unit_tests/mocks/services/event_service.py b/unit_tests/mocks/services/event_service.py index cd72fac9..db634b2e 100644 --- a/unit_tests/mocks/services/event_service.py +++ b/unit_tests/mocks/services/event_service.py @@ -1,7 +1,8 @@ from typing import Optional + from pytest import MonkeyPatch -from app.errors import RepositoryError, ValidationError +from app.errors import RepositoryError, ValidationError from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO from unit_tests.helpers.event import create_event_read_dto @@ -9,11 +10,16 @@ def mock_event_service(event_service, monkeypatch: MonkeyPatch, mocker): event_service.missing = False event_service.throw_repository_error = False + event_service.throw_timeout_error = False def maybe_throw(): if event_service.throw_repository_error: raise RepositoryError("bad repo") + def maybe_timeout(): + if event_service.throw_timeout_error: + raise TimeoutError + def mock_get_all_events() -> list[EventReadDTO]: maybe_throw() return [create_event_read_dto("test")] @@ -23,8 +29,9 @@ def mock_get_event(import_id: str) -> Optional[EventReadDTO]: if not event_service.missing: return create_event_read_dto(import_id) - def mock_search_events(q: str) -> list[EventReadDTO]: + def mock_search_events(q: dict) -> list[EventReadDTO]: maybe_throw() + maybe_timeout() if event_service.missing: return [] else: diff --git a/unit_tests/routers/test_collection.py b/unit_tests/routers/test_collection.py index 0d3c7743..c6e8e044 100644 --- a/unit_tests/routers/test_collection.py +++ b/unit_tests/routers/test_collection.py @@ -3,6 +3,8 @@ This uses a service mock and ensures each endpoint calls into the service. """ +import logging + import pytest from fastapi import status from fastapi.testclient import TestClient @@ -53,15 +55,59 @@ def test_search_when_ok(client: TestClient, collection_service_mock, user_header assert collection_service_mock.search.call_count == 1 -def test_search_when_not_found( +def test_search_when_invalid_params( client: TestClient, collection_service_mock, user_header_token ): - collection_service_mock.missing = True - response = client.get("/api/v1/collections/?q=stuff", headers=user_header_token) - assert response.status_code == status.HTTP_404_NOT_FOUND + response = client.get("/api/v1/collections/?wrong=yes", headers=user_header_token) + assert response.status_code == status.HTTP_400_BAD_REQUEST data = response.json() - assert data["detail"] == "Collections not found for term: stuff" + assert data["detail"] == "Search parameters are invalid: ['wrong']" + assert collection_service_mock.search.call_count == 0 + + +def test_search_when_db_error( + client: TestClient, collection_service_mock, user_header_token +): + collection_service_mock.throw_repository_error = True + response = client.get("/api/v1/collections/?q=error", headers=user_header_token) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "bad repo" + assert collection_service_mock.search.call_count == 1 + + +def test_search_when_request_timeout( + client: TestClient, + collection_service_mock, + user_header_token, + caplog, +): + collection_service_mock.throw_timeout_error = True + with caplog.at_level(logging.INFO): + response = client.get( + "/api/v1/collections/?q=timeout", headers=user_header_token + ) + assert response.status_code == status.HTTP_408_REQUEST_TIMEOUT assert collection_service_mock.search.call_count == 1 + assert ( + "Request timed out fetching matching collections. Try adjusting your query." + in caplog.text + ) + + +def test_search_when_not_found( + client: TestClient, collection_service_mock, user_header_token, caplog +): + collection_service_mock.missing = True + with caplog.at_level(logging.INFO): + response = client.get("/api/v1/collections/?q=stuff", headers=user_header_token) + assert response.status_code == status.HTTP_200_OK + response.json() + assert collection_service_mock.search.call_count == 1 + assert ( + "Collections not found for terms: {'q': 'stuff', 'max_results': 500}" + in caplog.text + ) def test_update_when_ok(client: TestClient, collection_service_mock, user_header_token): diff --git a/unit_tests/routers/test_document.py b/unit_tests/routers/test_document.py index 2625ff5b..6cda7049 100644 --- a/unit_tests/routers/test_document.py +++ b/unit_tests/routers/test_document.py @@ -1,3 +1,5 @@ +import logging + import pytest from fastapi import status from fastapi.testclient import TestClient @@ -48,17 +50,56 @@ def test_search_when_ok(client: TestClient, document_service_mock, user_header_t assert document_service_mock.search.call_count == 1 -def test_search_when_not_found( +def test_search_when_invalid_params( client: TestClient, document_service_mock, user_header_token ): - document_service_mock.missing = True - response = client.get("/api/v1/documents/?q=stuff", headers=user_header_token) - assert response.status_code == status.HTTP_404_NOT_FOUND + response = client.get("/api/v1/documents/?wrong=yes", headers=user_header_token) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["detail"] == "Search parameters are invalid: ['wrong']" + assert document_service_mock.search.call_count == 0 + + +def test_search_when_db_error( + client: TestClient, document_service_mock, user_header_token +): + document_service_mock.throw_repository_error = True + response = client.get("/api/v1/documents/?q=error", headers=user_header_token) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE data = response.json() - assert data["detail"] == "Documents not found for term: stuff" + assert data["detail"] == "bad repo" assert document_service_mock.search.call_count == 1 +def test_search_when_request_timeout( + client: TestClient, document_service_mock, user_header_token, caplog +): + document_service_mock.throw_timeout_error = True + with caplog.at_level(logging.INFO): + response = client.get("/api/v1/documents/?q=timeout", headers=user_header_token) + assert response.status_code == status.HTTP_408_REQUEST_TIMEOUT + assert document_service_mock.search.call_count == 1 + assert ( + "Request timed out fetching matching documents. Try adjusting your query." + in caplog.text + ) + + +def test_search_when_not_found( + client: TestClient, document_service_mock, user_header_token, caplog +): + document_service_mock.missing = True + with caplog.at_level(logging.INFO): + response = client.get("/api/v1/documents/?q=stuff", headers=user_header_token) + assert response.status_code == status.HTTP_200_OK + response.json() + assert document_service_mock.search.call_count == 1 + assert ( + "Documents not found for terms: {'q': 'stuff', 'max_results': 500}" + in caplog.text + ) + + def test_update_when_ok(client: TestClient, document_service_mock, user_header_token): new_data = create_document_write_dto("doc1").model_dump() response = client.put( diff --git a/unit_tests/routers/test_event.py b/unit_tests/routers/test_event.py index 7f3ecd14..6aff1cd2 100644 --- a/unit_tests/routers/test_event.py +++ b/unit_tests/routers/test_event.py @@ -1,3 +1,5 @@ +import logging + import pytest from fastapi import status from fastapi.encoders import jsonable_encoder @@ -47,17 +49,58 @@ def test_search_when_ok(client: TestClient, event_service_mock, user_header_toke assert event_service_mock.search.call_count == 1 -def test_search_when_not_found( +def test_search_when_invalid_params( client: TestClient, event_service_mock, user_header_token ): - event_service_mock.missing = True - response = client.get("/api/v1/events/?q=stuff", headers=user_header_token) - assert response.status_code == status.HTTP_404_NOT_FOUND + response = client.get("/api/v1/events/?wrong=yes", headers=user_header_token) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() + assert data["detail"] == "Search parameters are invalid: ['wrong']" + assert event_service_mock.search.call_count == 0 + + +def test_search_when_db_error( + client: TestClient, event_service_mock, user_header_token +): + event_service_mock.throw_repository_error = True + response = client.get("/api/v1/events/?q=error", headers=user_header_token) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE data = response.json() - assert data["detail"] == "Events not found for term: stuff" + assert data["detail"] == "bad repo" assert event_service_mock.search.call_count == 1 +def test_search_when_request_timeout( + client: TestClient, + event_service_mock, + user_header_token, + caplog, +): + event_service_mock.throw_timeout_error = True + with caplog.at_level(logging.INFO): + response = client.get("/api/v1/events/?q=timeout", headers=user_header_token) + assert response.status_code == status.HTTP_408_REQUEST_TIMEOUT + assert event_service_mock.search.call_count == 1 + assert ( + "Request timed out fetching matching events. Try adjusting your query." + in caplog.text + ) + + +def test_search_when_not_found( + client: TestClient, event_service_mock, user_header_token, caplog +): + event_service_mock.missing = True + with caplog.at_level(logging.INFO): + response = client.get("/api/v1/events/?q=stuff", headers=user_header_token) + assert response.status_code == status.HTTP_200_OK + response.json() + assert event_service_mock.search.call_count == 1 + assert ( + "Events not found for terms: {'q': 'stuff', 'max_results': 500}" in caplog.text + ) + + def test_create_when_ok(client: TestClient, event_service_mock, user_header_token): new_data = create_event_create_dto("event1").model_dump() response = client.post( diff --git a/unit_tests/service/test_collection_service.py b/unit_tests/service/test_collection_service.py index 3484c49f..c9c8fb39 100644 --- a/unit_tests/service/test_collection_service.py +++ b/unit_tests/service/test_collection_service.py @@ -1,12 +1,12 @@ import pytest +import app.service.collection as collection_service from app.errors import RepositoryError, ValidationError from app.model.collection import ( CollectionCreateDTO, CollectionReadDTO, CollectionWriteDTO, ) -import app.service.collection as collection_service from unit_tests.helpers.collection import create_collection_write_dto from unit_tests.mocks.repos.collection_repo import ( create_collection_read_dto as create_dto, @@ -57,15 +57,29 @@ def test_get_raises_if_invalid_id(collection_repo_mock): def test_search(collection_repo_mock): - result = collection_service.search("two") + result = collection_service.search({"q": "two"}) assert result is not None assert len(result) == 1 assert collection_repo_mock.search.call_count == 1 -def test_search_when_missing(collection_repo_mock): +def test_search_db_error(collection_repo_mock): + collection_repo_mock.throw_repository_error = True + with pytest.raises(RepositoryError): + collection_service.search({"q": "error"}) + assert collection_repo_mock.search.call_count == 1 + + +def test_search_request_timeout(collection_repo_mock): + collection_repo_mock.throw_timeout_error = True + with pytest.raises(TimeoutError): + collection_service.search({"q": "timeout"}) + assert collection_repo_mock.search.call_count == 1 + + +def test_search_missing(collection_repo_mock): collection_repo_mock.return_empty = True - result = collection_service.search("empty") + result = collection_service.search({"q": "empty"}) assert result is not None assert len(result) == 0 assert collection_repo_mock.search.call_count == 1 diff --git a/unit_tests/service/test_document_service.py b/unit_tests/service/test_document_service.py index 30362566..54a46591 100644 --- a/unit_tests/service/test_document_service.py +++ b/unit_tests/service/test_document_service.py @@ -1,7 +1,8 @@ import pytest -from app.model.document import DocumentReadDTO, DocumentWriteDTO + import app.service.document as doc_service from app.errors import RepositoryError, ValidationError +from app.model.document import DocumentReadDTO, DocumentWriteDTO from unit_tests.helpers.document import create_document_create_dto @@ -62,15 +63,29 @@ def test_get_raises_when_invalid_id(document_repo_mock): def test_search(document_repo_mock): - result = doc_service.search("two") + result = doc_service.search({"q": "two"}) assert result is not None assert len(result) == 1 assert document_repo_mock.search.call_count == 1 -def test_search_when_missing(document_repo_mock): +def test_search_db_error(document_repo_mock): + document_repo_mock.throw_repository_error = True + with pytest.raises(RepositoryError): + doc_service.search({"q": "error"}) + assert document_repo_mock.search.call_count == 1 + + +def test_search_request_timeout(document_repo_mock): + document_repo_mock.throw_timeout_error = True + with pytest.raises(TimeoutError): + doc_service.search({"q": "timeout"}) + assert document_repo_mock.search.call_count == 1 + + +def test_search_missing(document_repo_mock): document_repo_mock.return_empty = True - result = doc_service.search("empty") + result = doc_service.search({"q": "empty"}) assert result is not None assert len(result) == 0 assert document_repo_mock.search.call_count == 1 diff --git a/unit_tests/service/test_event_service.py b/unit_tests/service/test_event_service.py index 7f8a9e32..24ce3402 100644 --- a/unit_tests/service/test_event_service.py +++ b/unit_tests/service/test_event_service.py @@ -1,8 +1,9 @@ import pytest + import app.service.event as event_service from app.errors import RepositoryError, ValidationError -from unit_tests.helpers.event import create_event_create_dto from app.model.event import EventReadDTO, EventWriteDTO +from unit_tests.helpers.event import create_event_create_dto def _to_write_dto(dto: EventReadDTO) -> EventWriteDTO: @@ -42,15 +43,29 @@ def test_get_raises_when_invalid_id(event_repo_mock): def test_search(event_repo_mock): - result = event_service.search("two") + result = event_service.search({"q": "two"}) assert result is not None assert len(result) == 1 assert event_repo_mock.search.call_count == 1 -def test_search_when_missing(event_repo_mock): +def test_search_db_error(event_repo_mock): + event_repo_mock.throw_repository_error = True + with pytest.raises(RepositoryError): + event_service.search({"q": "error"}) + assert event_repo_mock.search.call_count == 1 + + +def test_search_request_timeout(event_repo_mock): + event_repo_mock.throw_timeout_error = True + with pytest.raises(TimeoutError): + event_service.search({"q": "timeout"}) + assert event_repo_mock.search.call_count == 1 + + +def test_search_missing(event_repo_mock): event_repo_mock.return_empty = True - result = event_service.search("empty") + result = event_service.search({"q": "empty"}) assert result is not None assert len(result) == 0 assert event_repo_mock.search.call_count == 1 diff --git a/unit_tests/service/test_family_service.py b/unit_tests/service/test_family_service.py index b18a6733..f2760654 100644 --- a/unit_tests/service/test_family_service.py +++ b/unit_tests/service/test_family_service.py @@ -72,6 +72,13 @@ def test_get_raises_when_invalid_id(family_repo_mock): def test_search(family_repo_mock): result = family_service.search({"q": "two"}) assert result is not None + assert len(result) == 2 + assert family_repo_mock.search.call_count == 1 + + +def test_search_on_specific_field(family_repo_mock): + result = family_service.search({"title": "one"}) + assert result is not None assert len(result) == 1 assert family_repo_mock.search.call_count == 1