diff --git a/.env.example b/.env.example index 2486258c..ad775221 100644 --- a/.env.example +++ b/.env.example @@ -42,6 +42,7 @@ DOCUMENT_BUCKET=cpr-document-queue # CDN URL CDN_URL=https://cdn.climatepolicyradar.org CDN_DOMAIN=cdn.climatepolicyradar.org +DOCUMENT_CACHE_BUCKET=some_cdn_bucket # Pipeline settings PIPELINE_BUCKET="test_pipeline_bucket" diff --git a/app/api/api_v1/routers/search.py b/app/api/api_v1/routers/search.py index c518772a..34f6e8b1 100644 --- a/app/api/api_v1/routers/search.py +++ b/app/api/api_v1/routers/search.py @@ -8,40 +8,34 @@ import logging from io import BytesIO -from typing import Annotated, Sequence, cast +from typing import Annotated, Optional, Sequence, cast from cpr_sdk.exceptions import QueryError from cpr_sdk.search_adaptors import VespaSearchAdapter from fastapi import APIRouter, Body, Depends, Header, HTTPException, Request, status from fastapi.responses import StreamingResponse -from jwt import PyJWTError -from pydantic_core import Url from sqlalchemy.orm import Session from starlette.responses import RedirectResponse from app.api.api_v1.schemas.search import SearchRequestBody, SearchResponse -from app.core.aws import S3Document, get_s3_client +from app.core.aws import S3Client, S3Document, get_s3_client from app.core.config import ( AWS_REGION, CDN_DOMAIN, - DOC_CACHE_BUCKET, + DOCUMENT_CACHE_BUCKET, INGEST_TRIGGER_ROOT, PIPELINE_BUCKET, PUBLIC_APP_URL, VESPA_SECRETS_LOCATION, VESPA_URL, ) -from app.core.custom_app import decode_config_token +from app.core.custom_app import AppTokenFactory from app.core.download import create_data_download_zip_archive from app.core.search import ( create_vespa_search_params, process_result_into_csv, process_vespa_search_response, ) -from app.db.crud.helpers.validate import ( - validate_corpora_ids, - verify_any_corpora_ids_in_db, -) from app.db.session import get_db _LOGGER = logging.getLogger(__name__) @@ -51,6 +45,7 @@ cert_directory=VESPA_SECRETS_LOCATION, ) + search_router = APIRouter() @@ -140,57 +135,33 @@ def search_documents( the search database. The continuation token can be used to get the next set of results from the search database. See the request schema for more details. """ - origin = request.headers.get("origin") - if origin is not None: - origin = Url(origin).host - _LOGGER.info( "Search request", extra={ "props": { "search_request": search_body.model_dump(), - "origin": origin, "app_token": str(app_token), } }, ) - try: - allowed_corpora_ids = decode_config_token(app_token, origin) - except PyJWTError as e: - _LOGGER.error(e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Could not decode configuration token", - headers={"WWW-Authenticate": "Bearer"}, - ) - + # Decode the app token and validate it. + # # First corpora validation is app token against DB. At least one of the app token # corpora IDs must be present in the DB to continue the search request. - if not verify_any_corpora_ids_in_db(db, allowed_corpora_ids): - msg = "Error verifying corpora IDs." - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=msg, - headers={"WWW-Authenticate": "Bearer"}, - ) + token = AppTokenFactory() + token.decode_and_validate(db, request, app_token) # If the search request IDs are null, we want to search using the app token corpora. if search_body.corpus_import_ids is None: - search_body.corpus_import_ids = cast(Sequence, allowed_corpora_ids) + search_body.corpus_import_ids = cast(Sequence, token.allowed_corpora_ids) # For the second validation, search request corpora Ids are validated against the # app token corpora IDs if the search request param 'corpus_import_ids' is not None. # corpus_import_ids must be a subset of app token IDs. - if not validate_corpora_ids( - set(search_body.corpus_import_ids), cast(set, allowed_corpora_ids) - ): - msg = "Error validating corpora IDs." - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=msg, - headers={"WWW-Authenticate": "Bearer"}, - ) + token.validate_subset( + set(search_body.corpus_import_ids), cast(set, token.allowed_corpora_ids) + ) _LOGGER.info( "Starting search...", @@ -203,22 +174,45 @@ def search_documents( def download_search_documents( request: Request, search_body: SearchRequestBody, + app_token: Annotated[str, Header()], db=Depends(get_db), ) -> StreamingResponse: """Download a CSV containing details of documents matching the search criteria.""" + token = AppTokenFactory() _LOGGER.info( "Search download request", extra={ "props": { "search_request": search_body.model_dump(), + "app_token": str(app_token), } }, ) + + # Decode the app token and validate it. + # + # First corpora validation is app token against DB. At least one of the app token + # corpora IDs must be present in the DB to continue the search request. + token = AppTokenFactory() + token.decode_and_validate(db, request, app_token) + + # If the search request IDs are null, we want to search using the app token corpora. + if search_body.corpus_import_ids is None: + search_body.corpus_import_ids = cast(Sequence, token.allowed_corpora_ids) + + # For the second validation, search request corpora Ids are validated against the + # app token corpora IDs if the search request param 'corpus_import_ids' is not None. + # corpus_import_ids must be a subset of app token IDs. + token.validate_subset( + set(search_body.corpus_import_ids), cast(set, token.allowed_corpora_ids) + ) + is_browse = not bool(search_body.query_string) _LOGGER.info( "Starting search...", + extra={"props": {"search_request": search_body.model_dump()}}, ) search_response = _search_request( db=db, @@ -236,18 +230,47 @@ def download_search_documents( ) +def _get_s3_doc_url_from_cdn( + s3_client: S3Client, s3_document: S3Document, data_dump_s3_key: str +) -> Optional[str]: + redirect_url = None + if s3_client.document_exists(s3_document): + _LOGGER.info("Redirecting to CDN data dump location...") + redirect_url = f"https://{CDN_DOMAIN}/{data_dump_s3_key}" + return redirect_url + + @search_router.get("/searches/download-all-data", include_in_schema=False) -def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse: +def download_all_search_documents( + request: Request, app_token: Annotated[str, Header()], db=Depends(get_db) +) -> RedirectResponse: """Download a CSV containing details of all the documents in the corpus.""" - _LOGGER.info("Whole data download request") + token = AppTokenFactory() - if PIPELINE_BUCKET is None or PUBLIC_APP_URL is None or DOC_CACHE_BUCKET is None: + _LOGGER.info( + "Whole data download request", + extra={ + "props": { + "app_token": str(app_token), + } + }, + ) + + # Decode the app token and validate it. + token = AppTokenFactory() + token.decode_and_validate(db, request, app_token) + + if ( + PIPELINE_BUCKET is None + or PUBLIC_APP_URL is None + or DOCUMENT_CACHE_BUCKET is None + ): if PIPELINE_BUCKET is None: _LOGGER.error("{PIPELINE_BUCKET} is not set") if PUBLIC_APP_URL is None: _LOGGER.error("{PUBLIC_APP_URL} is not set") - if DOC_CACHE_BUCKET is None: - _LOGGER.error("{DOC_CACHE_BUCKET} is not set") + if DOCUMENT_CACHE_BUCKET is None: + _LOGGER.error("{DOCUMENT_CACHE_BUCKET} is not set") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Missing required environment variables", @@ -259,7 +282,9 @@ def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse: ) s3_prefix = "navigator/dumps" - data_dump_s3_key = f"{s3_prefix}/whole_data_dump-{latest_ingest_start}.zip" + data_dump_s3_key = ( + f"{s3_prefix}/{token.sub}-whole_data_dump-{latest_ingest_start}.zip" + ) valid_credentials = s3_client.is_connected() if not valid_credentials: @@ -268,23 +293,25 @@ def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse: status_code=status.HTTP_401_UNAUTHORIZED, detail="Error connecting to AWS" ) - s3_document = S3Document(DOC_CACHE_BUCKET, AWS_REGION, data_dump_s3_key) + s3_document = S3Document(DOCUMENT_CACHE_BUCKET, AWS_REGION, data_dump_s3_key) if valid_credentials is True and (not s3_client.document_exists(s3_document)): aws_env = "production" if "dev" not in PUBLIC_APP_URL else "staging" _LOGGER.info( - f"Generating {aws_env} dump for ingest cycle w/c {latest_ingest_start}..." + f"Generating {token.sub} {aws_env} dump for ingest cycle w/c {latest_ingest_start}..." ) # After writing to a file buffer the position stays at the end whereas when you # upload a buffer, it starts from the position it is currently in. We need to # add the seek(0) to reset the buffer position to the beginning before writing # to S3 to avoid creating an empty file. - zip_buffer = create_data_download_zip_archive(latest_ingest_start, db) + zip_buffer = create_data_download_zip_archive( + latest_ingest_start, token.allowed_corpora_ids, db + ) zip_buffer.seek(0) try: response = s3_client.upload_fileobj( - bucket=DOC_CACHE_BUCKET, + bucket=DOCUMENT_CACHE_BUCKET, key=data_dump_s3_key, content_type="application/zip", fileobj=zip_buffer, @@ -292,18 +319,19 @@ def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse: if response is False: _LOGGER.error("Failed to upload archive to s3: %s", response) else: - _LOGGER.info(f"Finished uploading data archive to {DOC_CACHE_BUCKET}") + _LOGGER.info( + f"Finished uploading data archive to {DOCUMENT_CACHE_BUCKET}" + ) except Exception as e: _LOGGER.error(e) - s3_document = S3Document(DOC_CACHE_BUCKET, AWS_REGION, data_dump_s3_key) - if s3_client.document_exists(s3_document): - _LOGGER.info("Redirecting to CDN data dump location...") - redirect_url = f"https://{CDN_DOMAIN}/{data_dump_s3_key}" + s3_document = S3Document(DOCUMENT_CACHE_BUCKET, AWS_REGION, data_dump_s3_key) + redirect_url = _get_s3_doc_url_from_cdn(s3_client, s3_document, data_dump_s3_key) + if redirect_url is not None: return RedirectResponse(redirect_url, status_code=status.HTTP_303_SEE_OTHER) _LOGGER.info( - f"Can't find data dump for {latest_ingest_start} in {DOC_CACHE_BUCKET}" + f"Can't find data dump for {latest_ingest_start} in {DOCUMENT_CACHE_BUCKET}" ) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) diff --git a/app/core/config.py b/app/core/config.py index 5a4159a1..72e1e532 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -17,7 +17,7 @@ # Whole database dump INGEST_CYCLE_START = os.getenv("INGEST_CYCLE_START") -DOC_CACHE_BUCKET = os.getenv("DOCUMENT_CACHE_BUCKET") +DOCUMENT_CACHE_BUCKET = os.getenv("DOCUMENT_CACHE_BUCKET") PIPELINE_BUCKET: str = os.getenv("PIPELINE_BUCKET", "not_set") INGEST_TRIGGER_ROOT: Final = "input" DEVELOPMENT_MODE: bool = os.getenv("DEVELOPMENT_MODE", "False").lower() == "true" diff --git a/app/core/custom_app.py b/app/core/custom_app.py index 62b20a86..17c03a66 100644 --- a/app/core/custom_app.py +++ b/app/core/custom_app.py @@ -1,118 +1,263 @@ import logging import os from datetime import datetime -from typing import Optional +from typing import Optional, cast import jwt from dateutil.relativedelta import relativedelta +from db_client.models.dfce.family import Corpus +from fastapi import HTTPException, Request, status +from jwt import PyJWTError +from pydantic_core import Url +from sqlalchemy import distinct, select +from sqlalchemy.orm import Session from app.api.api_v1.schemas.custom_app import CustomAppConfigDTO from app.core import security _LOGGER = logging.getLogger(__name__) - TOKEN_SECRET_KEY = os.environ["TOKEN_SECRET_KEY"] -ISSUER = "Climate Policy Radar" - -# TODO: revisit/configure access token expiry -CUSTOM_APP_TOKEN_EXPIRE_YEARS = 10 # token valid for 10 years -EXPECTED_ARGS_LENGTH = 3 - - -def _contains_special_chars(input: str) -> bool: - """Check if string contains any non alpha numeric characters. - - :param str input: A string to check. - :return bool: True if string contains special chars, False otherwise. - """ - if any(not char.isalnum() for char in input): - return True - return False - - -def _parse_and_sort_corpora_ids(corpora_ids_str: str) -> list[str]: - """Parse and sort the comma separated string of corpora IDs. - - :param str corpora_ids_str: A comma separated string containing the - corpus import IDs that the custom app should show. - :return list[str]: A list of corpora IDs sorted alphanumerically. - """ - corpora_ids = corpora_ids_str.split(",") - corpora_ids.sort() - return corpora_ids - - -def create_configuration_token(input: str, years: Optional[int] = None) -> str: - """Create a custom app configuration token. - - :param str input: A semi-colon delimited string containing in this - order: - 1. A comma separated string containing the corpus import IDs - that the custom app should show. - 2. A string containing the name of the theme. - 3. A string containing the hostname of the custom app. - :return str: A JWT token containing the encoded allowed corpora. - """ - expiry_years = years or CUSTOM_APP_TOKEN_EXPIRE_YEARS - issued_at = datetime.utcnow() - expire = issued_at + relativedelta(years=expiry_years) - - parts = input.split(";") - if len(parts) != EXPECTED_ARGS_LENGTH or any(len(part) < 1 for part in parts): - _LOGGER.error("Expected exactly 3 arguments") - raise ValueError - - corpora_ids, subject, audience = parts - - config = CustomAppConfigDTO( - allowed_corpora_ids=_parse_and_sort_corpora_ids(corpora_ids), - subject=subject, - issuer=ISSUER, - audience=audience, - expiry=expire, - issued_at=int( - datetime.timestamp(issued_at.replace(microsecond=0)) - ), # No microseconds - ) - - if _contains_special_chars(config.subject): - _LOGGER.error( - "Subject must not contain any special characters, including spaces" + + +class AppTokenFactory: + def __init__(self) -> None: + # TODO: revisit/configure access token expiry + self.custom_app_token_expire_years: int = 10 # token valid for 10 years + self.expected_args_length: int = 3 + + # These will only be populated after a token has been decoded. + self.allowed_corpora_ids: list[str] = [] + self.exp: Optional[datetime] = None + self.iat: Optional[datetime] = None + self.iss: str = "Climate Policy Radar" + self.sub: Optional[str] = None + self.aud: Optional[str] = None + + @staticmethod + def _contains_special_chars(input: str) -> bool: + """Check if string contains any non alpha numeric characters. + + :param str input: A string to check. + :return bool: True if string contains special chars, False otherwise. + """ + if any(not char.isalnum() for char in input): + return True + return False + + @staticmethod + def _parse_and_sort_corpora_ids(corpora_ids_str: str) -> list[str]: + """Parse and sort the comma separated string of corpora IDs. + + :param str corpora_ids_str: A comma separated string containing the + corpus import IDs that the custom app should show. + :return list[str]: A list of corpora IDs sorted alphanumerically. + """ + corpora_ids = corpora_ids_str.split(",") + corpora_ids.sort() + return corpora_ids + + def create_configuration_token( + self, input: str, years: Optional[int] = None + ) -> str: + """Create a custom app configuration token. + + :param str input: A semi-colon delimited string containing in this + order: + 1. A comma separated string containing the corpus import IDs + that the custom app should show. + 2. A string containing the name of the theme. + 3. A string containing the hostname of the custom app. + :return str: A JWT token containing the encoded allowed corpora. + """ + expiry_years = years or self.custom_app_token_expire_years + issued_at = datetime.utcnow() + expire = issued_at + relativedelta(years=expiry_years) + + parts = input.split(";") + if len(parts) != self.expected_args_length or any( + len(part) < 1 for part in parts + ): + _LOGGER.error(f"Expected exactly {self.expected_args_length} arguments") + raise ValueError + + corpora_ids, subject, audience = parts + + config = CustomAppConfigDTO( + allowed_corpora_ids=self._parse_and_sort_corpora_ids(corpora_ids), + subject=subject, + issuer=self.iss, + audience=audience, + expiry=expire, + issued_at=int( + datetime.timestamp(issued_at.replace(microsecond=0)) + ), # No microseconds ) - raise ValueError - - msg = "Creating custom app configuration token that expires on " - msg += f"{expire.strftime('%a %d %B %Y at %H:%M:%S:%f')} " - msg += f"for the following corpora: {corpora_ids}" - print(msg) - - to_encode = { - "allowed_corpora_ids": config.allowed_corpora_ids, - "exp": config.expiry, - "iat": config.issued_at, - "iss": config.issuer, - "sub": config.subject, - "aud": str(config.audience), - } - return jwt.encode(to_encode, TOKEN_SECRET_KEY, algorithm=security.ALGORITHM) - - -def decode_config_token(token: str, audience: Optional[str]) -> list[str]: - """Decodes a configuration token. - - :param str token : A JWT token that has been encoded with a list of - allowed corpora ids that the custom app should show, an expiry - date and an issued at date. - :return list[str]: A decoded list of valid corpora ids. - """ - decoded_token = jwt.decode( - token, - TOKEN_SECRET_KEY, - algorithms=[security.ALGORITHM], - issuer=ISSUER, - audience=audience, - options={"verify_aud": False}, - ) - corpora_ids: list = decoded_token.get("allowed_corpora_ids") - - return corpora_ids + + if self._contains_special_chars(config.subject): + _LOGGER.error( + "Subject must not contain any special characters, including spaces" + ) + raise ValueError + + msg = "Creating custom app configuration token that expires on " + msg += f"{expire.strftime('%a %d %B %Y at %H:%M:%S:%f')} " + msg += f"for the following corpora: {corpora_ids}" + print(msg) + + to_encode = { + "allowed_corpora_ids": config.allowed_corpora_ids, + "exp": config.expiry, + "iat": config.issued_at, + "iss": config.issuer, + "sub": config.subject, + "aud": str(config.audience), + } + return jwt.encode(to_encode, TOKEN_SECRET_KEY, algorithm=security.ALGORITHM) + + @staticmethod + def get_origin(request: Request) -> Optional[str]: + """Get the origin from the request headers.""" + origin = request.headers.get("origin") + + if origin is not None: + origin = Url(origin).host + return origin + + def verify_corpora_in_db(self, db: Session, any_exist: bool = True) -> bool: + """Validate given corpus IDs against the existing corpora in DB. + + :param Session db: A session to query against. + :param bool any_exist: Whether to check any or all corpora are + valid. True by default. + :return bool: Return whether or not the corpora are valid. + """ + if self.allowed_corpora_ids is None: + return False + + corpora_ids_from_db = cast( + list, db.scalars(select(distinct(Corpus.import_id))).all() + ) + + if any_exist: + validate_success = any( + corpus in corpora_ids_from_db for corpus in self.allowed_corpora_ids + ) + else: + validate_success = all( + corpus in corpora_ids_from_db for corpus in self.allowed_corpora_ids + ) + + if validate_success: + not_in_db = set(self.allowed_corpora_ids).difference(corpora_ids_from_db) + if not_in_db != set(): + _LOGGER.warning(f"Some corpora in app token {not_in_db} not in DB") + + return validate_success + + @staticmethod + def validate_corpora_ids( + corpora_ids: set[str], valid_corpora_ids: set[str] + ) -> bool: + """Validate all given corpus IDs against a list of allowed corpora. + + :param set[str] corpora_ids: The corpus import IDs we want to + validate. + :param set[str] valid_corpora_ids: The corpus import IDs + we want to validate against. + :return bool: Return whether or not all the corpora are valid. + """ + validate_success = corpora_ids.issubset(valid_corpora_ids) + if not validate_success: + invalid_corpora = set(corpora_ids).difference(valid_corpora_ids) + _LOGGER.warning( + f"Some corpora in search request params {invalid_corpora}" + "forbidden to search against." + ) + return validate_success + + def decode(self, token: str, audience: Optional[str]) -> list[str]: + """Decodes a configuration token. + + :param str token : A JWT token that has been encoded with a list of + allowed corpora ids that the custom app should show, an expiry + date and an issued at date. + :param Optional[str] audience: An audience to verify against. + :return list[str]: A decoded list of valid corpora ids. + """ + try: + decoded_token = jwt.decode( + token, + TOKEN_SECRET_KEY, + algorithms=[security.ALGORITHM], + issuer=self.iss, + audience=audience, + options={"verify_aud": False}, + ) + except PyJWTError as e: + _LOGGER.error(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Could not decode configuration token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + self.allowed_corpora_ids = decoded_token.get("allowed_corpora_ids") + self.aud = decoded_token.get("aud") + self.exp = decoded_token.get("exp") + self.iat = decoded_token.get("iat") + self.iss = decoded_token.get("iss") + self.sub = decoded_token.get("sub") + + return decoded_token + + def validate(self, db: Session, any_exist: bool = True) -> None: + """Validate that any or all corpora IDs exist in the database. + + :param Session db: A session to query against. + :param bool any_exist: Whether to check any or all corpora are + valid. True by default. + """ + if not self.verify_corpora_in_db(db, any_exist): + msg = "Error verifying corpora IDs." + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=msg, + headers={"WWW-Authenticate": "Bearer"}, + ) + + def validate_subset( + self, corpora_ids: set[str], valid_corpora_ids: set[str] + ) -> None: + """Validate that a set of corpora IDs is a subset of another. + + :param set[str] corpora_ids: The corpus import IDs we want to + validate. + :param set[str] valid_corpora_ids: The corpus import IDs + we want to validate against. + """ + if not self.validate_corpora_ids(corpora_ids, valid_corpora_ids): + msg = "Error validating corpora IDs." + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=msg, + headers={"WWW-Authenticate": "Bearer"}, + ) + + def decode_and_validate( + self, db: Session, request: Request, token: str, any_exist: bool = True + ): + """Decode a JWT app token and validate 1+ corpora IDs exist. + + :param Session db: A session to query against. + :param Request request: A request object containing headers. + :param bool any_exist: Whether to check any or all corpora are + valid. True by default. + """ + origin = self.get_origin(request) + + # Decode the app token and validate it. + self.decode(token, origin) + + # First corpora validation is app token against DB. At least one of the app token + # corpora IDs must be present in the DB to continue the search request. + self.validate(db, any_exist) diff --git a/app/core/download.py b/app/core/download.py index 0c084985..485f38ee 100644 --- a/app/core/download.py +++ b/app/core/download.py @@ -1,6 +1,7 @@ """Functions to support browsing the RDS document structure""" import zipfile +from functools import lru_cache from io import BytesIO, StringIO from logging import getLogger from typing import Optional @@ -13,22 +14,46 @@ _LOGGER = getLogger(__name__) -def create_query(ingest_cycle_start: str) -> str: - """Browse RDS""" +@lru_cache() +def _get_query_template(): + with open("./app/core/download.sql", "r") as file: + return file.read() - # Read the download.sql file - if create_query.cache is None: # type: ignore - with open("./app/core/download.sql", "r") as file: - create_query.cache = file.read() # type: ignore - return create_query.cache.replace("{ingest_cycle_start}", ingest_cycle_start) # type: ignore +def create_query(ingest_cycle_start: str, allowed_corpora_ids: list[str]) -> str: + """Create download whole database query, replacing variables. - -create_query.cache = None # type: ignore - - -def get_whole_database_dump(ingest_cycle_start: str, db=Depends(get_db)): - query = create_query(ingest_cycle_start) + :param str ingest_cycle_start: The current ingest cycle date. + :param list[str] allowed_corpora_ids: The corpora from which we + should allow the data to be dumped. + :return str: The SQL query to perform on the database session. + """ + template_query = _get_query_template() + + # TODO: Hide MCF from data dump until after COP. + mcf_corpora_ids = [ + "MCF.corpus.AF.n0000", + "MCF.corpus.CIF.n0000", + "MCF.corpus.GCF.n0000", + "MCF.corpus.GEF.n0000", + ] + corpora_ids = [ + corpus_id + for corpus_id in allowed_corpora_ids + if corpus_id not in mcf_corpora_ids + ] + corpora_ids = "'" + "','".join(corpora_ids) + "'" + return template_query.replace( # type: ignore + "{ingest_cycle_start}", ingest_cycle_start + ).replace( + "{allowed_corpora_ids}", corpora_ids + ) # type: ignore + + +def get_whole_database_dump( + ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db) +): + query = create_query(ingest_cycle_start, allowed_corpora_ids) with db.connection() as conn: df = pd.read_sql(query, conn.connection) return df @@ -63,8 +88,10 @@ def convert_dump_to_csv(df: pd.DataFrame): return csv_buffer -def generate_data_dump_as_csv(ingest_cycle_start: str, db=Depends(get_db)): - df = get_whole_database_dump(ingest_cycle_start, db) +def generate_data_dump_as_csv( + ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db) +): + df = get_whole_database_dump(ingest_cycle_start, allowed_corpora_ids, db) csv = convert_dump_to_csv(df) csv.seek(0) return csv @@ -88,9 +115,11 @@ def generate_data_dump_readme(ingest_cycle_start: str): return file_buffer -def create_data_download_zip_archive(ingest_cycle_start: str, db=Depends(get_db)): +def create_data_download_zip_archive( + ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db) +): readme_buffer = generate_data_dump_readme(ingest_cycle_start) - csv_buffer = generate_data_dump_as_csv(ingest_cycle_start, db) + csv_buffer = generate_data_dump_as_csv(ingest_cycle_start, allowed_corpora_ids, db) zip_buffer = BytesIO() with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: diff --git a/app/core/download.sql b/app/core/download.sql index 601508dd..7218c968 100644 --- a/app/core/download.sql +++ b/app/core/download.sql @@ -229,6 +229,5 @@ LEFT JOIN most_recent_doc_slugs ds on ds.family_document_import_id = d.import_id LEFT JOIN most_recent_family_slugs fs on fs.family_import_id = f.import_id LEFT JOIN event_dates fp on fp.family_import_id = f.import_id -WHERE d.last_modified < '{ingest_cycle_start}' -ORDER BY -d.last_modified desc, d.created desc, d.ctid desc, n1.family_import_id +WHERE d.last_modified < '{ingest_cycle_start}' AND fc.corpus_import_id in ({allowed_corpora_ids}) +ORDER BY d.last_modified desc, d.created desc, d.ctid desc, n1.family_import_id diff --git a/pyproject.toml b/pyproject.toml index 60695d5c..1518679b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "1.17.4" +version = "1.17.5" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/conftest.py b/tests/conftest.py index 4844a1a8..c205b18b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,9 +15,9 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import create_database, database_exists, drop_database -from app.core import security +from app.core import custom_app, security from app.core.aws import S3Client, get_s3_client -from app.core.custom_app import create_configuration_token +from app.core.custom_app import AppTokenFactory from app.db.session import get_db from app.main import app @@ -35,7 +35,7 @@ def mock_aws_creds(): def s3_document_bucket_names() -> dict: return { "queue": os.environ.get("DOCUMENT_BUCKET", "cpr-document-queue"), - "cdn": os.environ.get("DOC_CACHE_BUCKET", "test_cdn_bucket"), + "cdn": os.environ.get("DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"), "pipeline": os.environ.get("PIPELINE_BUCKET", "test_pipeline_bucket"), } @@ -61,7 +61,7 @@ def test_s3_client(s3_document_bucket_names, mock_aws_creds): Body=bytes(1024), ) # Test setup for cdn test bucket - os.environ["DOC_CACHE_BUCKET"] = "test_cdn_bucket" + os.environ["DOCUMENT_CACHE_BUCKET"] = "test_cdn_bucket" # Test setup for Pipeline os.environ["PIPELINE_BUCKET"] = "test_pipeline_bucket" @@ -104,7 +104,7 @@ def get_test_db_url() -> str: @pytest.fixture -def valid_token(): +def valid_token(monkeypatch): """Generate valid config token using TOKEN_SECRET_KEY. Need to generate the config token using the token secret key from @@ -113,11 +113,18 @@ def valid_token(): might be different (e.g., the one for staging). This fixture works around this. """ + + def mock_return(_, __, ___): + return True + corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0" subject = "CCLW" audience = "localhost" input_str = f"{corpora_ids};{subject};{audience}" - return create_configuration_token(input_str) + + af = AppTokenFactory() + monkeypatch.setattr(custom_app.AppTokenFactory, "validate", mock_return) + return af.create_configuration_token(input_str) @pytest.fixture diff --git a/tests/core/test_aws.py b/tests/non_search/core/test_aws.py similarity index 100% rename from tests/core/test_aws.py rename to tests/non_search/core/test_aws.py diff --git a/tests/core/test_browse.py b/tests/non_search/core/test_browse.py similarity index 100% rename from tests/core/test_browse.py rename to tests/non_search/core/test_browse.py diff --git a/tests/core/test_lookups.py b/tests/non_search/core/test_lookups.py similarity index 100% rename from tests/core/test_lookups.py rename to tests/non_search/core/test_lookups.py diff --git a/tests/search/vespa/test_range_and_keyword_filters_search.py b/tests/search/vespa/test_range_and_keyword_filters_search.py index ad090a7e..a50945b6 100644 --- a/tests/search/vespa/test_range_and_keyword_filters_search.py +++ b/tests/search/vespa/test_range_and_keyword_filters_search.py @@ -14,7 +14,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_country_filters__geography( mock_corpora_exist_in_db, @@ -57,7 +60,10 @@ def test_keyword_country_filters__geography( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_country_filters__geographies( mock_corpora_exist_in_db, @@ -102,7 +108,10 @@ def test_keyword_country_filters__geographies( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_region_filters( mock_corpora_exist_in_db, @@ -152,7 +161,10 @@ def test_keyword_region_filters( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_keyword_region_and_country_filters( mock_corpora_exist_in_db, @@ -185,7 +197,10 @@ def test_keyword_region_and_country_filters( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_invalid_keyword_filters( mock_corpora_exist_in_db, @@ -218,7 +233,10 @@ def test_invalid_keyword_filters( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize( "year_range", [(None, None), (1900, None), (None, 2020), (1900, 2020)] ) @@ -248,7 +266,10 @@ def test_year_range_filtered_in( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("year_range", [(None, 2010), (2024, None)]) def test_year_range_filtered_out( mock_corpora_exist_in_db, @@ -276,7 +297,10 @@ def test_year_range_filtered_out( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_multiple_filters( mock_corpora_exist_in_db, @@ -306,7 +330,10 @@ def test_multiple_filters( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_geo_filter_with_exact( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): diff --git a/tests/search/vespa/test_this_vespa_search_download.py b/tests/search/vespa/test_this_vespa_search_download.py index e4447e01..3c0d0c0d 100644 --- a/tests/search/vespa/test_this_vespa_search_download.py +++ b/tests/search/vespa/test_this_vespa_search_download.py @@ -1,8 +1,11 @@ import csv from io import StringIO +from typing import Any, Mapping from unittest.mock import patch +import jwt import pytest +from fastapi import status from app.api.api_v1.routers import search from tests.search.vespa.setup_search_tests import ( @@ -14,8 +17,29 @@ CSV_DOWNLOAD_ENDPOINT = "/api/v1/searches/download-csv" +def _make_download_request( + client, + token, + params: Mapping[str, Any], + expected_status_code: int = status.HTTP_200_OK, +): + headers = {"app-token": token} + + response = client.post( + CSV_DOWNLOAD_ENDPOINT, + json=params, + headers=headers, + ) + assert response is not None + assert response.status_code == expected_status_code, response.text + return response + + @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("exact_match", [True, False]) @pytest.mark.parametrize("query_string", ["", "local"]) def test_csv_content( @@ -39,14 +63,14 @@ def test_csv_content( families = body["families"] assert len(families) > 0 - csv_response = data_client.post( - CSV_DOWNLOAD_ENDPOINT, - json={ + csv_response = _make_download_request( + data_client, + valid_token, + params={ "exact_match": exact_match, "query_string": query_string, }, ) - assert csv_response.status_code == 200 csv_content = csv.DictReader(StringIO(csv_response.text)) for row, family in zip(csv_content, families): @@ -62,7 +86,10 @@ def test_csv_content( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) @pytest.mark.parametrize("limit", [100, 250, 500]) def test_csv_download_search_variable_limit( @@ -75,6 +102,7 @@ def test_csv_download_search_variable_limit( monkeypatch, data_client, mocker, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -88,11 +116,7 @@ def test_csv_download_search_variable_limit( "offset": 0, } - download_response = data_client.post( - CSV_DOWNLOAD_ENDPOINT, - json=params, - ) - assert download_response.status_code == 200 + _make_download_request(data_client, valid_token, params=params) actual_params = query_spy.call_args.kwargs["parameters"].model_dump() @@ -104,9 +128,12 @@ def test_csv_download_search_variable_limit( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_csv_download__ignore_extra_fields( - mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, mocker + mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -118,10 +145,72 @@ def test_csv_download__ignore_extra_fields( # Ensure extra, unspecified fields don't cause an error fields = [] with patch("app.core.search._CSV_SEARCH_RESPONSE_COLUMNS", fields): - download_response = data_client.post( - CSV_DOWNLOAD_ENDPOINT, - json=params, - ) - assert download_response.status_code == 200 + _make_download_request(data_client, valid_token, params=params) assert mock_corpora_exist_in_db.assert_called + + +@pytest.mark.search +@pytest.mark.parametrize( + "side_effect", + [ + jwt.exceptions.InvalidAudienceError, + jwt.exceptions.ExpiredSignatureError, + jwt.exceptions.InvalidTokenError, + ], +) +def test_csv_download_fails_when_decoding_token_raises_PyJWTError( + side_effect, data_client, data_db, valid_token, monkeypatch, test_vespa +): + """ + GIVEN a request to download the whole database + WHEN the decode() function call raises a PyJWTError + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + params = { + "query_string": "winter", + } + + with patch("app.core.custom_app.jwt.decode", side_effect=side_effect): + response = _make_download_request( + data_client, + valid_token, + params=params, + expected_status_code=status.HTTP_400_BAD_REQUEST, + ) + assert response.json()["detail"] == "Could not decode configuration token" + + +@pytest.mark.search +def test_csv_download_fails_when_corpus_ids_in_token_not_in_db( + data_client, data_db, monkeypatch, test_vespa +): + """ + GIVEN a list of corpora IDs decoded from an app config token + WHEN one or more of those corpora IDs are not in our database + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + params = { + "query_string": "winter", + } + + with patch( + "app.api.api_v1.routers.search.AppTokenFactory.decode", + return_value=True, + ), patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=False, + ): + response = _make_download_request( + data_client, + "some_token", + params=params, + expected_status_code=status.HTTP_400_BAD_REQUEST, + ) + assert response.json()["detail"] == "Error verifying corpora IDs." diff --git a/tests/search/vespa/test_vespa_corpus_filtering.py b/tests/search/vespa/test_vespa_corpus_filtering.py index 095b6844..6f40604f 100644 --- a/tests/search/vespa/test_vespa_corpus_filtering.py +++ b/tests/search/vespa/test_vespa_corpus_filtering.py @@ -32,6 +32,7 @@ def test_corpus_filtering( corpus_import_id: str, corpus_type_name: str, expected_hits: int, + valid_token, ): monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) @@ -43,14 +44,12 @@ def test_corpus_filtering( params["corpus_type_names"] = [corpus_type_name] with patch( - "app.api.api_v1.routers.search.decode_config_token", - return_value=["CCLW.corpus.1.0", "CCLW.corpus.2.0"], - ), patch( - "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, ): response = _make_search_request( data_client, - token="foo", + token=valid_token, params=params, ) @@ -65,7 +64,7 @@ def test_corpus_filtering( @pytest.mark.search def test_search_with_corpus_ids_in_token_not_in_db( - data_client, data_db, valid_token, monkeypatch, test_vespa + data_client, data_db, monkeypatch, test_vespa ): """ GIVEN a list of corpora IDs decoded from an app config token @@ -75,12 +74,12 @@ def test_search_with_corpus_ids_in_token_not_in_db( monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) - with patch( - "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=False + with patch("app.core.custom_app.AppTokenFactory.decode", return_value=True), patch( + "app.core.custom_app.AppTokenFactory.verify_corpora_in_db", return_value=False ): response = _make_search_request( data_client, - valid_token, + "test_token", params={"query_string": ""}, expected_status_code=status.HTTP_400_BAD_REQUEST, ) @@ -107,15 +106,13 @@ def test_search_decoding_token_raises_PyJWTError( ): """ GIVEN a request to the search endpoint - WHEN the decode_config_token() function call raises a PyJWTError + WHEN the decode() function call raises a PyJWTError THEN raise a 400 HTTP error """ monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) _populate_db_families(data_db) - with patch( - "app.api.api_v1.routers.search.decode_config_token", side_effect=side_effect - ): + with patch("jwt.decode", side_effect=side_effect): response = _make_search_request( data_client, valid_token, @@ -167,9 +164,8 @@ def test_search_with_invalid_corpus_id_in_search_request_params( _populate_db_families(data_db) with patch( - "app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", - return_value=True, - ), patch("app.api.api_v1.routers.search.validate_corpora_ids", return_value=False): + "app.core.custom_app.AppTokenFactory.validate_corpora_ids", return_value=False + ): response = _make_search_request( data_client, valid_token, diff --git a/tests/search/vespa/test_vespa_ids_search.py b/tests/search/vespa/test_vespa_ids_search.py index 1be534e9..5c23f695 100644 --- a/tests/search/vespa/test_vespa_ids_search.py +++ b/tests/search/vespa/test_vespa_ids_search.py @@ -42,7 +42,10 @@ def _fam_ids_from_response(test_db, response) -> list[str]: return family_ids -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize( "family_ids", [ @@ -76,7 +79,10 @@ def test_family_ids_search( assert mock_corpora_exist_in_db.assert_called -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize( "document_ids", [ @@ -113,7 +119,10 @@ def test_document_ids_search( assert mock_corpora_exist_in_db.assert_called -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.search def test_document_ids_and_family_ids_search( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token @@ -135,7 +144,10 @@ def test_document_ids_and_family_ids_search( assert mock_corpora_exist_in_db.assert_called -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.search def test_empty_ids_dont_limit_result( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token diff --git a/tests/search/vespa/test_vespa_metadata_filters_search.py b/tests/search/vespa/test_vespa_metadata_filters_search.py index e8b01901..625974e0 100644 --- a/tests/search/vespa/test_vespa_metadata_filters_search.py +++ b/tests/search/vespa/test_vespa_metadata_filters_search.py @@ -10,7 +10,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize( "label,query,metadata_filters", [ diff --git a/tests/search/vespa/test_vespa_query_ignores.py b/tests/search/vespa/test_vespa_query_ignores.py index a05c8f71..8808d827 100644 --- a/tests/search/vespa/test_vespa_query_ignores.py +++ b/tests/search/vespa/test_vespa_query_ignores.py @@ -11,7 +11,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_case_insensitivity( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): @@ -26,7 +29,10 @@ def test_case_insensitivity( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_punctuation_ignored( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): @@ -53,7 +59,10 @@ def test_punctuation_ignored( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_accents_ignored( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): diff --git a/tests/search/vespa/test_vespa_search.py b/tests/search/vespa/test_vespa_search.py index 368c3fe8..40ff67eb 100644 --- a/tests/search/vespa/test_vespa_search.py +++ b/tests/search/vespa/test_vespa_search.py @@ -17,7 +17,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_empty_search_term_performs_browse( mock_corpora_exist_in_db, test_vespa, @@ -45,7 +48,10 @@ def test_empty_search_term_performs_browse( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("exact_match", [True, False]) def test_search_body_valid( mock_corpora_exist_in_db, @@ -86,7 +92,10 @@ def test_search_body_valid( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_no_doc_if_in_postgres_but_not_vespa( mock_corpora_exist_in_db, test_vespa, data_client, data_db, monkeypatch, valid_token ): @@ -143,7 +152,10 @@ def test_no_doc_if_in_postgres_but_not_vespa( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label,query", [("search", "the"), ("browse", "")]) def test_benchmark_families_search( mock_corpora_exist_in_db, @@ -178,7 +190,10 @@ def test_benchmark_families_search( assert mock_corpora_exist_in_db.assert_called -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.search def test_specific_doc_returned( mock_corpora_exist_in_db, test_vespa, monkeypatch, data_client, data_db, valid_token @@ -202,7 +217,10 @@ def test_specific_doc_returned( assert mock_corpora_exist_in_db.assert_called -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize( ("extra_params", "invalid_field"), [ @@ -241,7 +259,10 @@ def test_search_params_backend_limits( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_search_with_deleted_docs( mock_corpora_exist_in_db, test_vespa, monkeypatch, data_client, data_db, valid_token ): diff --git a/tests/search/vespa/test_vespa_search_pagination.py b/tests/search/vespa/test_vespa_search_pagination.py index 708f1c9d..3dd549aa 100644 --- a/tests/search/vespa/test_vespa_search_pagination.py +++ b/tests/search/vespa/test_vespa_search_pagination.py @@ -11,7 +11,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_simple_pagination_families( mock_corpora_exist_in_db, test_vespa, data_client, data_db, monkeypatch, valid_token ): @@ -60,7 +63,10 @@ def test_simple_pagination_families( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_continuation_token__families( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): @@ -100,7 +106,10 @@ def test_continuation_token__families( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_continuation_token__passages( mock_corpora_exist_in_db, test_vespa, data_db, monkeypatch, data_client, valid_token ): diff --git a/tests/search/vespa/test_vespa_search_result_order.py b/tests/search/vespa/test_vespa_search_result_order.py index ee738804..79b527a6 100644 --- a/tests/search/vespa/test_vespa_search_result_order.py +++ b/tests/search/vespa/test_vespa_search_result_order.py @@ -11,7 +11,10 @@ @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_result_order_score( mock_corpora_exist_in_db, @@ -47,7 +50,10 @@ def test_result_order_score( @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) @pytest.mark.parametrize("label, query", [("search", "the"), ("browse", "")]) def test_result_order_title( mock_corpora_exist_in_db, diff --git a/tests/search/vespa/test_whole_database_download.py b/tests/search/vespa/test_whole_database_download.py index df2ca087..024d4d50 100644 --- a/tests/search/vespa/test_whole_database_download.py +++ b/tests/search/vespa/test_whole_database_download.py @@ -1,30 +1,139 @@ from unittest.mock import patch +import jwt import pytest +from fastapi import status +from app.api.api_v1.routers import search from tests.search.vespa.setup_search_tests import _populate_db_families ALL_DATA_DOWNLOAD_ENDPOINT = "/api/v1/searches/download-all-data" @pytest.mark.search -@patch("app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=True) +@pytest.mark.parametrize( + "side_effect", + [ + jwt.exceptions.InvalidAudienceError, + jwt.exceptions.ExpiredSignatureError, + jwt.exceptions.InvalidTokenError, + ], +) +def test_whole_database_download_fails_when_decoding_token_raises_PyJWTError( + side_effect, data_client, data_db, valid_token, monkeypatch, test_vespa +): + """ + GIVEN a request to download the whole database + WHEN the decode() function call raises a PyJWTError + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + with patch( + "app.core.custom_app.jwt.decode", + side_effect=side_effect, + ), patch( + "app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket" + ), patch( + "app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket" + ), patch( + "app.core.aws.S3Client.is_connected", return_value=True + ): + response = data_client.get( + ALL_DATA_DOWNLOAD_ENDPOINT, + headers={"app-token": valid_token}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["detail"] == "Could not decode configuration token" + + +@pytest.mark.search +def test_whole_database_download_fails_when_corpus_ids_in_token_not_in_db( + data_client, data_db, monkeypatch, test_vespa +): + """ + GIVEN a list of corpora IDs decoded from an app config token + WHEN one or more of those corpora IDs are not in our database + THEN raise a 400 HTTP error + """ + monkeypatch.setattr(search, "_VESPA_CONNECTION", test_vespa) + _populate_db_families(data_db) + + with patch( + "app.api.api_v1.routers.search.AppTokenFactory.decode", + return_value=True, + ), patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=False, + ), patch( + "app.core.aws.S3Client.is_connected", return_value=True + ), patch( + "app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket" + ), patch( + "app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket" + ): + response = data_client.get( + ALL_DATA_DOWNLOAD_ENDPOINT, + headers={"app-token": "some_token"}, + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json()["detail"] == "Error verifying corpora IDs." + + +@pytest.mark.search +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) def test_all_data_download(mock_corpora_exist_in_db, data_db, data_client, valid_token): _populate_db_families(data_db) + headers = {"app-token": valid_token} + with ( patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"), - patch("app.api.api_v1.routers.search.DOC_CACHE_BUCKET", "test_cdn_bucket"), + patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"), patch("app.core.aws.S3Client.is_connected", return_value=True), ): data_client.follow_redirects = False - download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT) + download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT, headers=headers) # Redirects to cdn assert download_response.status_code == 303 assert download_response.headers["location"] == ( "https://cdn.climatepolicyradar.org/" - "navigator/dumps/whole_data_dump-2024-03-22.zip" + "navigator/dumps/CCLW-whole_data_dump-2024-03-22.zip" ) assert mock_corpora_exist_in_db.assert_called + + +@pytest.mark.search +@patch( + "app.api.api_v1.routers.search.AppTokenFactory.verify_corpora_in_db", + return_value=True, +) +def test_all_data_download_fails_when_s3_upload_failed( + mock_corpora_exist_in_db, data_db, data_client, valid_token +): + _populate_db_families(data_db) + + headers = {"app-token": valid_token} + + with ( + patch("app.api.api_v1.routers.search.PIPELINE_BUCKET", "test_pipeline_bucket"), + patch("app.api.api_v1.routers.search.DOCUMENT_CACHE_BUCKET", "test_cdn_bucket"), + patch("app.core.aws.S3Client.is_connected", return_value=True), + patch( + "app.api.api_v1.routers.search._get_s3_doc_url_from_cdn", return_value=None + ), + ): + data_client.follow_redirects = False + download_response = data_client.get(ALL_DATA_DOWNLOAD_ENDPOINT, headers=headers) + + # Redirects to cdn + assert download_response.status_code == status.HTTP_404_NOT_FOUND + assert mock_corpora_exist_in_db.assert_called diff --git a/tests/unit/app/core/custom_app/conftest.py b/tests/unit/app/core/custom_app/conftest.py index 97ae1d37..e481be7c 100644 --- a/tests/unit/app/core/custom_app/conftest.py +++ b/tests/unit/app/core/custom_app/conftest.py @@ -4,7 +4,7 @@ import jwt import pytest -from app.core.custom_app import create_configuration_token +from app.core.custom_app import AppTokenFactory TOKEN_SECRET_KEY = os.environ["TOKEN_SECRET_KEY"] ALGORITHM = "HS256" @@ -42,4 +42,5 @@ def token_with_invalid_aud(request) -> str: @pytest.fixture def valid_token() -> str: - return create_configuration_token(f"mango,apple;subject;{VALID_AUDIENCE}") + af = AppTokenFactory() + return af.create_configuration_token(f"mango,apple;subject;{VALID_AUDIENCE}") diff --git a/tests/unit/app/core/custom_app/test_create_configuration_token.py b/tests/unit/app/core/custom_app/test_create_token.py similarity index 91% rename from tests/unit/app/core/custom_app/test_create_configuration_token.py rename to tests/unit/app/core/custom_app/test_create_token.py index 5817196c..558a8753 100644 --- a/tests/unit/app/core/custom_app/test_create_configuration_token.py +++ b/tests/unit/app/core/custom_app/test_create_token.py @@ -4,7 +4,7 @@ import pytest from dateutil.relativedelta import relativedelta -from app.core.custom_app import create_configuration_token +from app.core.custom_app import AppTokenFactory from tests.unit.app.core.custom_app.conftest import ALGORITHM, TOKEN_SECRET_KEY EXPIRE_AFTER_1_YEAR = 1 @@ -43,8 +43,9 @@ def has_expected_keys(keys: list[str]) -> bool: ], ) def test_create_configuration_token_incorrect_num_args_in_input(input_str: str): + af = AppTokenFactory() with pytest.raises(ValueError): - token = create_configuration_token(input_str) + token = af.create_configuration_token(input_str) assert token is None data = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[ALGORITHM]) @@ -59,8 +60,9 @@ def test_create_configuration_token_incorrect_num_args_in_input(input_str: str): ], ) def test_create_configuration_token_subject_contains_special_chars(input_str: str): + af = AppTokenFactory() with pytest.raises(ValueError): - token = create_configuration_token(input_str) + token = af.create_configuration_token(input_str) assert token is None data = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[ALGORITHM]) @@ -85,7 +87,8 @@ def test_create_configuration_token_default_expiry( expected_subject: str, expected_audience: str, ): - token = create_configuration_token(input_str) + af = AppTokenFactory() + token = af.create_configuration_token(input_str) assert token is not None assert isinstance(token, str) @@ -136,7 +139,8 @@ def test_create_configuration_token_specific_expiry( expected_subject: str, expected_audience: str, ): - token = create_configuration_token(input_str, expiry_years) + af = AppTokenFactory() + token = af.create_configuration_token(input_str, expiry_years) assert token is not None assert isinstance(token, str) diff --git a/tests/unit/app/core/custom_app/test_decode_configuration_token.py b/tests/unit/app/core/custom_app/test_decode_configuration_token.py deleted file mode 100644 index d5f0af71..00000000 --- a/tests/unit/app/core/custom_app/test_decode_configuration_token.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Optional - -import jwt -import pytest - -from app.core.custom_app import create_configuration_token, decode_config_token -from tests.unit.app.core.custom_app.conftest import VALID_AUDIENCE - - -def test_decoding_expired_token_raise_expired_signature_token_error(expired_token): - with pytest.raises(jwt.ExpiredSignatureError) as error: - decode_config_token(expired_token, VALID_AUDIENCE) - - assert str(error.value) == "Signature has expired" - - -@pytest.mark.skip("Re-implement this as part of PDCT-1509") -@pytest.mark.parametrize( - "input_str, aud, error_msg", - [ - ("mango,apple;subject;https://audience.com", None, "Invalid audience"), - ( - "mango,apple;subject;https://audience.com", - "https://audience.org", - "Audience doesn't match", - ), - ( - "mango,apple;subject;https://AUDIENCE.OrG", - "https://AUDIENCE.Com", - "Audience doesn't match", - ), - ], -) -def test_decoding_token_with_invalid_aud_raises_invalid_token_error( - input_str: str, aud: Optional[str], error_msg: str -): - token = create_configuration_token(input_str) - with pytest.raises(jwt.InvalidTokenError) as error: - decode_config_token(token, aud) - - assert str(error.value) == error_msg - - -@pytest.mark.parametrize( - "input_str, aud", - [ - ("mango,apple;subject;https://audience.com", None), - ("mango,apple;subject;https://audience.com", "https://audience.org"), - ("mango,apple;subject;https://AUDIENCE.OrG", "https://AUDIENCE.Com"), - ], -) -def test_decoding_token_with_invalid_aud_success_in_dev_mode( - input_str: str, aud: Optional[str] -): - token = create_configuration_token(input_str) - decoded_corpora_ids = decode_config_token(token, aud) - assert len(decoded_corpora_ids) > 0 - - expected_num_corpora = 2 - assert len(decoded_corpora_ids) == expected_num_corpora - - -def test_decode_configuration_token_success(valid_token): - decoded_corpora_ids = decode_config_token(valid_token, VALID_AUDIENCE) - assert len(decoded_corpora_ids) > 0 - - expected_num_corpora = 2 - assert len(decoded_corpora_ids) == expected_num_corpora diff --git a/tests/unit/app/core/custom_app/test_decode_token.py b/tests/unit/app/core/custom_app/test_decode_token.py new file mode 100644 index 00000000..e8afc4a7 --- /dev/null +++ b/tests/unit/app/core/custom_app/test_decode_token.py @@ -0,0 +1,82 @@ +import logging +from typing import Optional, cast +from unittest.mock import patch + +import jwt +import pytest +from fastapi import HTTPException, status + +from app.core.custom_app import AppTokenFactory +from tests.unit.app.core.custom_app.conftest import VALID_AUDIENCE + + +def test_decoding_expired_token_raise_expired_signature_token_error( + expired_token, caplog +): + af = AppTokenFactory() + with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError), caplog.at_level( + logging.DEBUG + ), pytest.raises(HTTPException): + response = cast(HTTPException, af.decode(expired_token, VALID_AUDIENCE)) + assert response.status_code == status.HTTP_400_BAD_REQUEST + data = response.json() # pyright: ignore + assert str(data["detail"]) == "Could not decode configuration token" + + +@pytest.mark.skip("Re-implement this as part of PDCT-1509") +@pytest.mark.parametrize( + "input_str, aud, error_msg", + [ + ("mango,apple;subject;https://audience.com", None, "Invalid audience"), + ( + "mango,apple;subject;https://audience.com", + "https://audience.org", + "Audience doesn't match", + ), + ( + "mango,apple;subject;https://AUDIENCE.OrG", + "https://AUDIENCE.Com", + "Audience doesn't match", + ), + ], +) +def test_decoding_token_with_invalid_aud_raises_invalid_token_error( + input_str: str, aud: Optional[str], error_msg: str, caplog +): + af = AppTokenFactory() + token = af.create_configuration_token(input_str) + with patch("jwt.decode", side_effect=jwt.InvalidTokenError), pytest.raises( + HTTPException + ), caplog.at_level(logging.ERROR) as error: + af.decode(token, aud) + + assert str(error.value) == error_msg + + +@pytest.mark.parametrize( + "input_str, aud", + [ + ("mango,apple;subject;https://audience.com", None), + ("mango,apple;subject;https://audience.com", "https://audience.org"), + ("mango,apple;subject;https://AUDIENCE.OrG", "https://AUDIENCE.Com"), + ], +) +def test_decoding_token_with_invalid_aud_success_in_dev_mode( + input_str: str, aud: Optional[str] +): + af = AppTokenFactory() + token = af.create_configuration_token(input_str) + token_content = af.decode(token, aud) + assert len(token_content) > 0 + + expected_num_keys = 6 + assert len(token_content) == expected_num_keys + + +def test_decode_configuration_token_success(valid_token): + af = AppTokenFactory() + token_content = af.decode(valid_token, VALID_AUDIENCE) + assert len(token_content) > 0 + + expected_num_keys = 6 + assert len(token_content) == expected_num_keys