Feature/pdct 1559 download all data should only show data from allowed corpora (#371)

* Move app token factory into class

* Update filename based on theme from app token

* Make TODO a FIXME so we remember to remove it

* Update create token tests to use class

* Update decode token tests to use class

* Remove unused db argument

* Update conftests to use app token factory class

* Rename files & raise HTTPException from PyJwtError

* Consolidate core tests under non_search folder

* Remove self.db attribute

* Mock token validate so it always passes in valid_token

* Mock verify_corpora_in_db as True

* Update existing search tests

* Fix mocking paths based on new AppFactory

* Pull verification of dump upload into separate function

* Add test coverage for whole database download

* Add test coverage for the search download

* Add docstrings

* Remove fixme

* Make default corpora an empty list

* Only dump data from list of allowed corpora

* Update docstring

* Rename DOC_CACHE_BUCKET to DOCUMENT_CACHE_BUCKET

* Remove duplicate property

* Update cache to use functools lru_cache

* Bump to 1.17.5

* Remove fixme
katybaulch authored Oct 9, 2024
1 parent 4f7b4ad commit 299cfab
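
The AppTokenFactory class itself lives in app/core/custom_app.py, which is not among the hunks rendered below. As rough orientation, a minimal sketch of the pattern the commit messages describe - decode the JWT, turn PyJWTError into an HTTPException, default the allowed corpora to an empty list, and check that requested corpora are a subset of the token's - might look like the following. The method names mirror the call sites visible in the search.py diff; the bodies, the secret handling, and the DB check are assumptions, not the project's actual code.

import logging
from typing import List, Optional

import jwt  # PyJWT
from fastapi import HTTPException, Request, status
from jwt import PyJWTError
from sqlalchemy.orm import Session

_LOGGER = logging.getLogger(__name__)

TOKEN_SECRET = "change-me"  # assumption: the real key/algorithm come from app settings


class AppTokenFactory:
    """Illustrative sketch of a factory-style app token validator."""

    def __init__(self) -> None:
        self.allowed_corpora_ids: List[str] = []  # default corpora is an empty list
        self.sub: Optional[str] = None

    def decode_and_validate(self, db: Session, request: Request, token: str) -> None:
        # The real implementation also considers the request Origin header when
        # decoding (see the removed decode_config_token call in the diff).
        try:
            payload = jwt.decode(token, TOKEN_SECRET, algorithms=["HS256"])
        except PyJWTError as e:
            _LOGGER.error(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Could not decode configuration token",
                headers={"WWW-Authenticate": "Bearer"},
            )
        self.sub = payload.get("sub")
        self.allowed_corpora_ids = payload.get("allowed_corpora_ids", [])
        # At least one of the token's corpora must exist in the DB.
        if not self._verify_any_corpora_in_db(db):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Error verifying corpora IDs.",
                headers={"WWW-Authenticate": "Bearer"},
            )

    def validate_subset(self, requested: set, allowed: set) -> None:
        # Requested corpora must be a subset of the token's allowed corpora.
        if not requested.issubset(allowed):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Error validating corpora IDs.",
                headers={"WWW-Authenticate": "Bearer"},
            )

    def _verify_any_corpora_in_db(self, db: Session) -> bool:
        # Placeholder: the real check queries the corpus table for any of
        # self.allowed_corpora_ids; here we only require a non-empty list.
        return bool(self.allowed_corpora_ids)
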
Showing 25 changed files with 842 additions and 333 deletions.
1 change: 1 addition & 0 deletions .env.example
@@ -42,6 +42,7 @@ DOCUMENT_BUCKET=cpr-document-queue
# CDN URL
CDN_URL=https://cdn.climatepolicyradar.org
CDN_DOMAIN=cdn.climatepolicyradar.org
DOCUMENT_CACHE_BUCKET=some_cdn_bucket

# Pipeline settings
PIPELINE_BUCKET="test_pipeline_bucket"
144 changes: 86 additions & 58 deletions app/api/api_v1/routers/search.py
@@ -8,40 +8,34 @@

import logging
from io import BytesIO
from typing import Annotated, Sequence, cast
from typing import Annotated, Optional, Sequence, cast

from cpr_sdk.exceptions import QueryError
from cpr_sdk.search_adaptors import VespaSearchAdapter
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from jwt import PyJWTError
from pydantic_core import Url
from sqlalchemy.orm import Session
from starlette.responses import RedirectResponse

from app.api.api_v1.schemas.search import SearchRequestBody, SearchResponse
from app.core.aws import S3Document, get_s3_client
from app.core.aws import S3Client, S3Document, get_s3_client
from app.core.config import (
AWS_REGION,
CDN_DOMAIN,
DOC_CACHE_BUCKET,
DOCUMENT_CACHE_BUCKET,
INGEST_TRIGGER_ROOT,
PIPELINE_BUCKET,
PUBLIC_APP_URL,
VESPA_SECRETS_LOCATION,
VESPA_URL,
)
from app.core.custom_app import decode_config_token
from app.core.custom_app import AppTokenFactory
from app.core.download import create_data_download_zip_archive
from app.core.search import (
create_vespa_search_params,
process_result_into_csv,
process_vespa_search_response,
)
from app.db.crud.helpers.validate import (
validate_corpora_ids,
verify_any_corpora_ids_in_db,
)
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)
@@ -51,6 +45,7 @@
cert_directory=VESPA_SECRETS_LOCATION,
)


search_router = APIRouter()


@@ -140,57 +135,33 @@ def search_documents(
the search database. The continuation token can be used to get the next set of
results from the search database. See the request schema for more details.
"""
origin = request.headers.get("origin")
if origin is not None:
origin = Url(origin).host

_LOGGER.info(
"Search request",
extra={
"props": {
"search_request": search_body.model_dump(),
"origin": origin,
"app_token": str(app_token),
}
},
)

try:
allowed_corpora_ids = decode_config_token(app_token, origin)
except PyJWTError as e:
_LOGGER.error(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Could not decode configuration token",
headers={"WWW-Authenticate": "Bearer"},
)

# Decode the app token and validate it.
#
# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
if not verify_any_corpora_ids_in_db(db, allowed_corpora_ids):
msg = "Error verifying corpora IDs."
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=msg,
headers={"WWW-Authenticate": "Bearer"},
)
token = AppTokenFactory()
token.decode_and_validate(db, request, app_token)

# If the search request IDs are null, we want to search using the app token corpora.
if search_body.corpus_import_ids is None:
search_body.corpus_import_ids = cast(Sequence, allowed_corpora_ids)
search_body.corpus_import_ids = cast(Sequence, token.allowed_corpora_ids)

# For the second validation, search request corpora Ids are validated against the
# app token corpora IDs if the search request param 'corpus_import_ids' is not None.
# corpus_import_ids must be a subset of app token IDs.
if not validate_corpora_ids(
set(search_body.corpus_import_ids), cast(set, allowed_corpora_ids)
):
msg = "Error validating corpora IDs."
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=msg,
headers={"WWW-Authenticate": "Bearer"},
)
token.validate_subset(
set(search_body.corpus_import_ids), cast(set, token.allowed_corpora_ids)
)

_LOGGER.info(
"Starting search...",
@@ -203,22 +174,45 @@
def download_search_documents(
request: Request,
search_body: SearchRequestBody,
app_token: Annotated[str, Header()],
db=Depends(get_db),
) -> StreamingResponse:
"""Download a CSV containing details of documents matching the search criteria."""
token = AppTokenFactory()

_LOGGER.info(
"Search download request",
extra={
"props": {
"search_request": search_body.model_dump(),
"app_token": str(app_token),
}
},
)

# Decode the app token and validate it.
#
# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
token = AppTokenFactory()
token.decode_and_validate(db, request, app_token)

# If the search request IDs are null, we want to search using the app token corpora.
if search_body.corpus_import_ids is None:
search_body.corpus_import_ids = cast(Sequence, token.allowed_corpora_ids)

# For the second validation, search request corpora Ids are validated against the
# app token corpora IDs if the search request param 'corpus_import_ids' is not None.
# corpus_import_ids must be a subset of app token IDs.
token.validate_subset(
set(search_body.corpus_import_ids), cast(set, token.allowed_corpora_ids)
)

is_browse = not bool(search_body.query_string)

_LOGGER.info(
"Starting search...",
extra={"props": {"search_request": search_body.model_dump()}},
)
search_response = _search_request(
db=db,
Expand All @@ -236,18 +230,47 @@ def download_search_documents(
)


def _get_s3_doc_url_from_cdn(
s3_client: S3Client, s3_document: S3Document, data_dump_s3_key: str
) -> Optional[str]:
redirect_url = None
if s3_client.document_exists(s3_document):
_LOGGER.info("Redirecting to CDN data dump location...")
redirect_url = f"https://{CDN_DOMAIN}/{data_dump_s3_key}"
return redirect_url


@search_router.get("/searches/download-all-data", include_in_schema=False)
def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse:
def download_all_search_documents(
request: Request, app_token: Annotated[str, Header()], db=Depends(get_db)
) -> RedirectResponse:
"""Download a CSV containing details of all the documents in the corpus."""
_LOGGER.info("Whole data download request")
token = AppTokenFactory()

if PIPELINE_BUCKET is None or PUBLIC_APP_URL is None or DOC_CACHE_BUCKET is None:
_LOGGER.info(
"Whole data download request",
extra={
"props": {
"app_token": str(app_token),
}
},
)

# Decode the app token and validate it.
token = AppTokenFactory()
token.decode_and_validate(db, request, app_token)

if (
PIPELINE_BUCKET is None
or PUBLIC_APP_URL is None
or DOCUMENT_CACHE_BUCKET is None
):
if PIPELINE_BUCKET is None:
_LOGGER.error("{PIPELINE_BUCKET} is not set")
if PUBLIC_APP_URL is None:
_LOGGER.error("{PUBLIC_APP_URL} is not set")
if DOC_CACHE_BUCKET is None:
_LOGGER.error("{DOC_CACHE_BUCKET} is not set")
if DOCUMENT_CACHE_BUCKET is None:
_LOGGER.error("{DOCUMENT_CACHE_BUCKET} is not set")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Missing required environment variables",
@@ -259,7 +282,9 @@ def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse:
)

s3_prefix = "navigator/dumps"
data_dump_s3_key = f"{s3_prefix}/whole_data_dump-{latest_ingest_start}.zip"
data_dump_s3_key = (
f"{s3_prefix}/{token.sub}-whole_data_dump-{latest_ingest_start}.zip"
)

valid_credentials = s3_client.is_connected()
if not valid_credentials:
@@ -268,42 +293,45 @@ def download_all_search_documents(db=Depends(get_db)) -> RedirectResponse:
status_code=status.HTTP_401_UNAUTHORIZED, detail="Error connecting to AWS"
)

s3_document = S3Document(DOC_CACHE_BUCKET, AWS_REGION, data_dump_s3_key)
s3_document = S3Document(DOCUMENT_CACHE_BUCKET, AWS_REGION, data_dump_s3_key)
if valid_credentials is True and (not s3_client.document_exists(s3_document)):
aws_env = "production" if "dev" not in PUBLIC_APP_URL else "staging"
_LOGGER.info(
f"Generating {aws_env} dump for ingest cycle w/c {latest_ingest_start}..."
f"Generating {token.sub} {aws_env} dump for ingest cycle w/c {latest_ingest_start}..."
)

# After writing to a file buffer the position stays at the end whereas when you
# upload a buffer, it starts from the position it is currently in. We need to
# add the seek(0) to reset the buffer position to the beginning before writing
# to S3 to avoid creating an empty file.
zip_buffer = create_data_download_zip_archive(latest_ingest_start, db)
zip_buffer = create_data_download_zip_archive(
latest_ingest_start, token.allowed_corpora_ids, db
)
zip_buffer.seek(0)

try:
response = s3_client.upload_fileobj(
bucket=DOC_CACHE_BUCKET,
bucket=DOCUMENT_CACHE_BUCKET,
key=data_dump_s3_key,
content_type="application/zip",
fileobj=zip_buffer,
)
if response is False:
_LOGGER.error("Failed to upload archive to s3: %s", response)
else:
_LOGGER.info(f"Finished uploading data archive to {DOC_CACHE_BUCKET}")
_LOGGER.info(
f"Finished uploading data archive to {DOCUMENT_CACHE_BUCKET}"
)

except Exception as e:
_LOGGER.error(e)

s3_document = S3Document(DOC_CACHE_BUCKET, AWS_REGION, data_dump_s3_key)
if s3_client.document_exists(s3_document):
_LOGGER.info("Redirecting to CDN data dump location...")
redirect_url = f"https://{CDN_DOMAIN}/{data_dump_s3_key}"
s3_document = S3Document(DOCUMENT_CACHE_BUCKET, AWS_REGION, data_dump_s3_key)
redirect_url = _get_s3_doc_url_from_cdn(s3_client, s3_document, data_dump_s3_key)
if redirect_url is not None:
return RedirectResponse(redirect_url, status_code=status.HTTP_303_SEE_OTHER)

_LOGGER.info(
f"Can't find data dump for {latest_ingest_start} in {DOC_CACHE_BUCKET}"
f"Can't find data dump for {latest_ingest_start} in {DOCUMENT_CACHE_BUCKET}"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
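
Both download endpoints above now take the app token as a request header rather than relying on the database alone. A hedged example of how a client might call the whole-database download and follow the 303 redirect to the CDN; the header name app-token follows FastAPI's default underscore-to-hyphen conversion for Header() parameters, and the base URL and token value are placeholders, not documented values.

import requests

API_ROOT = "https://app.example.org/api/v1"  # placeholder base URL
APP_TOKEN = "<signed app token covering your allowed corpora>"  # placeholder

resp = requests.get(
    f"{API_ROOT}/searches/download-all-data",
    headers={"app-token": APP_TOKEN},
    allow_redirects=False,  # the endpoint answers 303 See Other with a CDN URL
)

if resp.status_code == 303:
    # The Location header points at the per-token zip archive on the CDN.
    archive = requests.get(resp.headers["Location"])
    with open("whole_data_dump.zip", "wb") as fh:
        fh.write(archive.content)
else:
    resp.raise_for_status()
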
2 changes: 1 addition & 1 deletion app/core/config.py
@@ -17,7 +17,7 @@

# Whole database dump
INGEST_CYCLE_START = os.getenv("INGEST_CYCLE_START")
DOC_CACHE_BUCKET = os.getenv("DOCUMENT_CACHE_BUCKET")
DOCUMENT_CACHE_BUCKET = os.getenv("DOCUMENT_CACHE_BUCKET")
PIPELINE_BUCKET: str = os.getenv("PIPELINE_BUCKET", "not_set")
INGEST_TRIGGER_ROOT: Final = "input"
DEVELOPMENT_MODE: bool = os.getenv("DEVELOPMENT_MODE", "False").lower() == "true"
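
The "Update cache to use functools lru_cache" commit lands in a file that is not expanded in the hunks shown here, so the exact target is not visible; the general pattern it refers to - memoising a small lookup with functools.lru_cache instead of a hand-rolled cache - looks roughly like this (illustrative only, hypothetical helper name):

import os
from functools import lru_cache


@lru_cache(maxsize=1)
def get_document_cache_bucket() -> str:
    """Return the configured cache bucket, computed once and memoised."""
    # Hypothetical helper; the commit may apply lru_cache to a different function.
    return os.getenv("DOCUMENT_CACHE_BUCKET", "not_set")
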