Skip to content

Commit

Permalink
Feature/pdct 1533 documents should only show documents in allowed cor…
Browse files Browse the repository at this point in the history
…pora (#415)

* Make documents router dependent on app token

* Driveby: Add CORS tests for MCF

* Update slug lookup query to respect allowed corpora

* Include actual CCLW corpus ID in test token

* Bump to 1.19.11

* Refactor _get_query_template

* Refactor doc and fam lookup tests

* Add integration tests for doc/fam lookup when corpora mismatch

* Add alternative corpora token

* Refactor download code
  • Loading branch information
katybaulch authored Nov 5, 2024
1 parent 9fc8060 commit e4e9b9d
Show file tree
Hide file tree
Showing 15 changed files with 435 additions and 279 deletions.
20 changes: 12 additions & 8 deletions app/api/api_v1/routers/documents.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
from http.client import NOT_FOUND
from typing import Union
from typing import Annotated, Union

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, Header, HTTPException, Request

from app.clients.db.session import get_db
from app.models.document import (
Expand All @@ -14,6 +14,7 @@
get_family_document_and_context,
get_slugged_objects,
)
from app.service.custom_app import AppTokenFactory

_LOGGER = logging.getLogger(__file__)

Expand All @@ -28,20 +29,23 @@
],
)
async def family_or_document_detail(
slug: str,
db=Depends(get_db),
slug: str, request: Request, app_token: Annotated[str, Header()], db=Depends(get_db)
):
"""Get details of the family or document associated with the slug."""
_LOGGER.info(
f"Getting detailed information for family or document '{slug}'",
extra={
"props": {
"import_id_or_slug": slug,
},
"props": {"import_id_or_slug": slug, "app_token": str(app_token)},
},
)

family_document_import_id, family_import_id = get_slugged_objects(db, slug)
# Decode the app token and validate it.
token = AppTokenFactory()
token.decode_and_validate(db, request, app_token)

family_document_import_id, family_import_id = get_slugged_objects(
db, slug, token.allowed_corpora_ids
)
if family_document_import_id is None and family_import_id is None:
raise HTTPException(status_code=NOT_FOUND, detail=f"Nothing found for {slug}")

Expand Down
2 changes: 0 additions & 2 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,6 @@ def download_all_search_documents(
request: Request, app_token: Annotated[str, Header()], db=Depends(get_db)
) -> RedirectResponse:
"""Download a CSV containing details of all the documents in the corpus."""
token = AppTokenFactory()

_LOGGER.info(
"Whole data download request",
extra={
Expand Down
57 changes: 49 additions & 8 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
import os
from datetime import datetime
from typing import Optional, Sequence, cast

Expand Down Expand Up @@ -34,29 +35,69 @@
LinkableFamily,
)
from app.repository.geography import get_geo_subquery
from app.repository.helpers import get_query_template
from app.repository.lookups import doc_type_from_family_document_metadata
from app.service.util import to_cdn_url

_LOGGER = logging.getLogger(__file__)


def get_slugged_objects(db: Session, slug: str) -> tuple[Optional[str], Optional[str]]:
def get_slugged_object_from_allowed_corpora_query(
    template_query: str, slug_name: str, allowed_corpora_ids: list[str]
) -> str:
    """Build the slug lookup query, restricted to the allowed corpora.

    Fills the ``{slug_name}`` and ``{allowed_corpora_ids}`` placeholders
    of the template. The slug originates from the request path, so every
    substituted value has its single quotes doubled before being placed
    inside a SQL string literal.

    NOTE(review): quote doubling is only sufficient when the database
    treats backslashes in string literals as ordinary characters (the
    PostgreSQL default, ``standard_conforming_strings = on``) -- confirm.
    Bound parameters would be the more robust long-term fix, but that
    needs the SQL template and the caller to change together.

    :param str template_query: SQL text containing the placeholders;
        the template is expected to supply the quotes around
        ``{slug_name}`` itself.
    :param str slug_name: The slug to look up.
    :param list[str] allowed_corpora_ids: The corpora IDs the slugged
        object is allowed to belong to.
    :return str: The SQL query to perform on the database session.
    """
    import re

    def _escape_literal(value: str) -> str:
        # Double single quotes so the value cannot close the literal.
        return value.replace("'", "''")

    # Keep the historic shape for an empty list: a single empty literal
    # (``''``), which matches nothing, rather than an invalid ``IN ()``.
    corpora_ids = (
        "'" + "','".join(_escape_literal(cid) for cid in allowed_corpora_ids) + "'"
    )
    substitutions = {
        "slug_name": _escape_literal(slug_name),
        "allowed_corpora_ids": corpora_ids,
    }

    # Substitute in a single pass so that placeholder-looking text inside
    # a substituted value (e.g. a slug of "{allowed_corpora_ids}") is
    # never expanded a second time. A callable replacement also avoids
    # backslash processing of the replacement text.
    return re.sub(
        r"\{(slug_name|allowed_corpora_ids)\}",
        lambda match: substitutions[match.group(1)],
        template_query,
    )


def get_slugged_objects(
db: Session, slug: str, allowed_corpora: Optional[list[str]] = None
) -> tuple[Optional[str], Optional[str]]:
"""Match the slug name to a FamilyDocument or Family import ID.
This function also contains logic to only get the import ID for the
family or document if the slug given is associated with a family
that belongs to the list of allowed corpora.
:param Session db: connection to db
:param str slug: slug name to match
:return tuple[Optional[str], Optional[str]]: the FamilyDocument import id or
the Family import_id
:param Optional[list[str]] allowed_corpora: The corpora IDs to look
for the slugged object in.
:return tuple[Optional[str], Optional[str]]: the FamilyDocument
import id or the Family import_id.
"""
result = (
db.query(Slug.family_document_import_id, Slug.family_import_id).filter(
if allowed_corpora is not None:
query_template = get_query_template(
os.path.join("app", "repository", "sql", "slug_lookup.sql")
)
query = get_slugged_object_from_allowed_corpora_query(
query_template, slug, allowed_corpora
)
query = db.execute(query)
else:
query = db.query(Slug.family_document_import_id, Slug.family_import_id).filter(
Slug.name == slug
)
).one_or_none()

result = query.one_or_none()
if result is None:
return (None, None)
return result

DOC_INDEX = 0
doc_id = cast(str, result[DOC_INDEX]) if result[DOC_INDEX] is not None else None

FAM_INDEX = 1
fam_id = cast(str, result[FAM_INDEX]) if result[FAM_INDEX] is not None else None

return doc_id, fam_id


def get_family_document_and_context(
Expand Down
33 changes: 26 additions & 7 deletions app/repository/download.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
"""Functions to support browsing the RDS document structure"""

import os
from functools import lru_cache
from logging import getLogger

import pandas as pd
from fastapi import Depends

from app.clients.db.session import get_db
from app.repository.helpers import get_query_template

_LOGGER = getLogger(__name__)


@lru_cache()
def _get_query_template():
with open(os.path.join("app", "repository", "sql", "download.sql"), "r") as file:
return file.read()
def create_query(
    template_query, ingest_cycle_start: str, allowed_corpora_ids: list[str]
) -> str:
    """Fill in the whole-database download query template.

    :param template_query: SQL text containing the
        ``{ingest_cycle_start}`` and ``{allowed_corpora_ids}``
        placeholders.
    :param str ingest_cycle_start: The current ingest cycle date.
    :param list[str] allowed_corpora_ids: The corpora from which we
        should allow the data to be dumped.
    :return str: The SQL query to perform on the database session.
    """
    # Render the IDs as a comma separated list of quoted SQL literals.
    quoted_ids = "'{}'".format("','".join(allowed_corpora_ids))

    with_cycle_start = template_query.replace(
        "{ingest_cycle_start}", ingest_cycle_start
    )
    return with_cycle_start.replace("{allowed_corpora_ids}", quoted_ids)


def get_whole_database_dump(
ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db)
):
query_template = get_query_template(
os.path.join("app", "repository", "sql", "download.sql")
)
query = create_query(query_template, ingest_cycle_start, allowed_corpora_ids)


def get_whole_database_dump(query, db=Depends(get_db)):
with db.connection() as conn:
df = pd.read_sql(query, conn.connection)
return df
14 changes: 14 additions & 0 deletions app/repository/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Functions to support the documents endpoints
old functions (non DFC) are moved to the deprecated_documents.py file.
"""

from functools import lru_cache


@lru_cache()
def get_query_template(filepath: str) -> str:
    """Read a SQL query template from disk, caching it per file path.

    The result is memoised by ``lru_cache``, so each distinct path is
    read at most once per process while it stays in the cache; later
    changes to the file are not picked up.

    :param str filepath: Path of the SQL file to read. A relative path
        is resolved against the current working directory.
    :return str: The full text of the file.
    :raises OSError: If the file cannot be opened or read.
    """
    # Decode explicitly as UTF-8 so the query text does not depend on
    # the platform's default locale encoding.
    with open(filepath, "r", encoding="utf-8") as file:
        return file.read()
11 changes: 2 additions & 9 deletions app/repository/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
import os
from datetime import datetime, timezone
from functools import lru_cache
from typing import Sequence, cast

import pandas as pd
Expand All @@ -11,19 +10,13 @@

from app.clients.db.session import get_db
from app.models.document import DocumentParserInput
from app.repository.helpers import get_query_template

_LOGGER = logging.getLogger(__name__)

MetadataType = dict[str, list[str]]


@lru_cache()
def generate_pipeline_ingest_input_query():
"""Read query for non-deleted docs and their associated data."""
with open(os.path.join("app", "repository", "sql", "pipeline.sql"), "r") as file:
return file.read()


def get_pipeline_data(db=Depends(get_db)) -> pd.DataFrame:
"""Get non-deleted docs and their associated data from the db.
Expand All @@ -39,7 +32,7 @@ def get_pipeline_data(db=Depends(get_db)) -> pd.DataFrame:
in database.
"""
_LOGGER.info("Running pipeline query")
query = generate_pipeline_ingest_input_query()
query = get_query_template(os.path.join("app", "repository", "sql", "pipeline.sql"))
df = pd.read_sql(query, db.connection().connection)
return df

Expand Down
20 changes: 20 additions & 0 deletions app/repository/sql/slug_lookup.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Look up the document and family import IDs recorded for a slug, but only
-- when the slug belongs to a family in one of the allowed corpora.
-- The slug name and the corpora list are placeholders that the calling code
-- substitutes as text before the statement is executed.

-- Branch 1 - the slug points directly at a family.
-- The corpus filter in WHERE discards rows whose joins found no corpus,
-- so the LEFT JOINs behave like inner joins here.
SELECT
slug.family_document_import_id, slug.family_import_id
FROM slug
LEFT JOIN family ON family.import_id = slug.family_import_id
LEFT JOIN family_corpus ON family_corpus.family_import_id = family.import_id
LEFT JOIN corpus ON corpus.import_id = family_corpus.corpus_import_id
WHERE slug.name = '{slug_name}'
AND corpus.import_id IN ({allowed_corpora_ids})

-- UNION (not UNION ALL) removes duplicate rows across the two branches.
UNION

-- Branch 2 - the slug points at a family document, whose owning family is
-- resolved first and then checked against the allowed corpora.
SELECT
slug.family_document_import_id, slug.family_import_id
FROM slug
LEFT JOIN family_document ON family_document.import_id = slug.family_document_import_id
LEFT JOIN family ON family.import_id = family_document.family_import_id
LEFT JOIN family_corpus ON family_corpus.family_import_id = family.import_id
LEFT JOIN corpus ON corpus.import_id = family_corpus.corpus_import_id
WHERE slug.name = '{slug_name}'
AND corpus.import_id IN ({allowed_corpora_ids});
31 changes: 7 additions & 24 deletions app/service/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,11 @@
from fastapi import Depends

from app.clients.db.session import get_db
from app.repository.download import _get_query_template, get_whole_database_dump
from app.repository.download import get_whole_database_dump

_LOGGER = getLogger(__name__)


def create_query(
template_query, ingest_cycle_start: str, allowed_corpora_ids: list[str]
) -> str:
"""Create download whole database query, replacing variables.
:param str ingest_cycle_start: The current ingest cycle date.
:param list[str] allowed_corpora_ids: The corpora from which we
should allow the data to be dumped.
:return str: The SQL query to perform on the database session.
"""
corpora_ids = "'" + "','".join(allowed_corpora_ids) + "'"
return template_query.replace( # type: ignore
"{ingest_cycle_start}", ingest_cycle_start
).replace(
"{allowed_corpora_ids}", corpora_ids
) # type: ignore


def replace_slug_with_qualified_url(
df: pd.DataFrame,
public_app_url: str,
Expand Down Expand Up @@ -61,8 +43,10 @@ def convert_dump_to_csv(df: pd.DataFrame):
return csv_buffer


def generate_data_dump_as_csv(query, db=Depends(get_db)):
df = get_whole_database_dump(query, db)
def generate_data_dump_as_csv(
ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db)
):
df = get_whole_database_dump(ingest_cycle_start, allowed_corpora_ids, db)
csv = convert_dump_to_csv(df)
csv.seek(0)
return csv
Expand Down Expand Up @@ -90,9 +74,8 @@ def create_data_download_zip_archive(
ingest_cycle_start: str, allowed_corpora_ids: list[str], db=Depends(get_db)
):
readme_buffer = generate_data_dump_readme(ingest_cycle_start)
query_template = _get_query_template()
query = create_query(query_template, ingest_cycle_start, allowed_corpora_ids)
csv_buffer = generate_data_dump_as_csv(query, db)

csv_buffer = generate_data_dump_as_csv(ingest_cycle_start, allowed_corpora_ids, db)

zip_buffer = BytesIO()
with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "navigator_backend"
version = "1.19.10"
version = "1.19.11"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
26 changes: 25 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def valid_token(monkeypatch):
def mock_return(_, __, ___):
return True

corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0"
corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0,CCLW.corpus.i00000001.n0000"
subject = "CCLW"
audience = "localhost"
input_str = f"{corpora_ids};{subject};{audience}"
Expand All @@ -127,6 +127,30 @@ def mock_return(_, __, ___):
return af.create_configuration_token(input_str)


@pytest.fixture
def alternative_token(monkeypatch):
    """Generate a valid config token for a different set of corpora.

    The token is created for the UNFCCC corpus with subject ``CPR`` and
    audience ``localhost``. It is built with whatever TOKEN_SECRET_KEY
    your local env file provides (in CI this is the key from the
    .env.example file; locally it may differ, e.g. the staging key), and
    ``AppTokenFactory.validate`` is replaced with a stub that always
    returns True so the token is accepted either way.
    """
    allowed_corpora = "UNFCCC.corpus.i00000001.n0000"
    token_subject = "CPR"
    token_audience = "localhost"
    token_input = f"{allowed_corpora};{token_subject};{token_audience}"

    token_factory = AppTokenFactory()
    # Patch validation on the class so every factory instance accepts
    # the token for the duration of the test.
    monkeypatch.setattr(
        custom_app.AppTokenFactory, "validate", lambda _, __, ___: True
    )
    return token_factory.create_configuration_token(token_input)


@pytest.fixture
def create_test_db():
"""Create a test database and use it for the whole test session."""
Expand Down
24 changes: 24 additions & 0 deletions tests/non_search/routers/documents/setup_doc_fam_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Optional

from fastapi import status

DOCUMENTS_ENDPOINT = "/api/v1/documents"
TEST_HOST = "http://localhost:3000/"


def _make_doc_fam_lookup_request(
    client,
    token,
    slug: str,
    expected_status_code: int = status.HTTP_200_OK,
    origin: Optional[str] = TEST_HOST,
):
    """GET the documents endpoint for a slug and return the JSON body.

    :param client: Test client exposing ``get(url, headers=...)``.
    :param token: Value sent in the ``app-token`` header.
    :param str slug: Slug appended to the documents endpoint path.
    :param int expected_status_code: Status code the response must have;
        the assertion message is the response text when it differs.
    :param Optional[str] origin: Value for the ``origin`` header, or
        None to omit the header entirely.
    :return: The decoded JSON body of the response.
    """
    headers = {"app-token": token}
    if origin is not None:
        headers["origin"] = origin

    url = f"{DOCUMENTS_ENDPOINT}/{slug}"
    response = client.get(url, headers=headers)
    assert response.status_code == expected_status_code, response.text
    return response.json()
Loading

0 comments on commit e4e9b9d

Please sign in to comment.