Skip to content

Commit

Permalink
Config endpoint returns only orgs that correspond to allowed corpora (#…
Browse files Browse the repository at this point in the history
…426)

* Config endpoint returns only orgs that correspond to allowed corpora

* Bump patch version

* Bump patch version again

* Add missing param

* Fix unit test

* Undo

* Return all orgs if allowed corpora list is empty

* Remove commented out code
  • Loading branch information
annaCPR authored Dec 4, 2024
1 parent 5fa106f commit 85d86de
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 20 deletions.
14 changes: 11 additions & 3 deletions app/api/api_v1/routers/lookups/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from fastapi import Depends, Request
from typing import Annotated

from fastapi import Depends, Header, Request

from app.api.api_v1.routers.lookups.router import lookups_router
from app.clients.db.session import get_db
from app.models.metadata import ApplicationConfig
from app.repository.lookups import get_config
from app.service.custom_app import AppTokenFactory


@lookups_router.get("/config", response_model=ApplicationConfig)
def lookup_config(request: Request, db=Depends(get_db)):
def lookup_config(
request: Request, app_token: Annotated[str, Header()], db=Depends(get_db)
):
"""Get the config for the metadata."""
return get_config(db)
token = AppTokenFactory()
token.decode_and_validate(db, request, app_token)

return get_config(db, token.allowed_corpora_ids)
6 changes: 3 additions & 3 deletions app/repository/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@
from sqlalchemy.orm import Session

from app.models.metadata import ApplicationConfig
from app.repository.organisation import get_all_organisations, get_organisation_config
from app.repository.organisation import get_organisation_config, get_organisations
from app.service.pipeline import IMPORT_ID_MATCHER
from app.service.util import tree_table_to_json

_LOGGER = logging.getLogger(__name__)


def get_config(db: Session) -> ApplicationConfig:
def get_config(db: Session, allowed_corpora: list[str]) -> ApplicationConfig:
# First get the CCLW stats
return ApplicationConfig(
geographies=tree_table_to_json(table=Geography, db=db),
organisations={
cast(str, org.name): get_organisation_config(db, org)
for org in get_all_organisations(db)
for org in get_organisations(db, allowed_corpora)
},
languages={lang.language_code: lang.name for lang in db.query(Language).all()},
document_variants=[
Expand Down
9 changes: 7 additions & 2 deletions app/repository/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,10 @@ def get_organisation_config(db: Session, org: Organisation) -> OrganisationConfi
)


def get_all_organisations(db: Session) -> list[Organisation]:
return db.query(Organisation).all()
def get_organisations(db: Session, allowed_corpora: list[str]) -> list[Organisation]:
query = db.query(Organisation).join(
Corpus, Corpus.organisation_id == Organisation.id
)
if allowed_corpora != []:
query = query.filter(Corpus.import_id.in_(allowed_corpora))
return query.all()
1 change: 1 addition & 0 deletions app/service/custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,5 @@ def decode_and_validate(

# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
any_exist = False if not self.allowed_corpora_ids else True
self.validate(db, any_exist)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "navigator_backend"
version = "1.19.17"
version = "1.19.18"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,32 @@ def mock_return(_, __, ___):
return af.create_configuration_token(input_str)


@pytest.fixture
def app_token_factory(monkeypatch):
"""Generate a valid config token using TOKEN_SECRET_KEY and given corpora ids.
Need to generate the config token using the token secret key from
your local env file. For tests in CI, this will be the secret key in
the .env.example file, but for local development this secret key
might be different (e.g., the one for staging). This fixture works
around this.
"""

def mock_return(_, __, ___):
return True

def _app_token(allowed_corpora_ids):
subject = "CCLW"
audience = "localhost"
input_str = f"{allowed_corpora_ids};{subject};{audience}"

af = AppTokenFactory()
monkeypatch.setattr(custom_app.AppTokenFactory, "validate", mock_return)
return af.create_configuration_token(input_str)

return _app_token


@pytest.fixture
def create_test_db():
"""Create a test database and use it for the whole test session."""
Expand Down
155 changes: 146 additions & 9 deletions tests/non_search/routers/lookups/test_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import os
from datetime import datetime
from http.client import OK
from typing import Any
from unittest.mock import MagicMock

import jwt
import pytest
from dateutil.relativedelta import relativedelta
from db_client.models.dfce.family import (
Family,
FamilyCategory,
FamilyCorpus,
FamilyGeography,
)
from db_client.models.organisation import Corpus, Organisation
from db_client.models.organisation import Corpus, CorpusType, Organisation

from app.clients.db.session import SessionLocal
from app.service import security
from app.service.util import tree_table_to_json

LEN_ORG_CONFIG = 3
Expand Down Expand Up @@ -71,15 +76,13 @@ def _add_family(test_db, import_id: str, cat: FamilyCategory, corpus_import_id):
)


def test_config_endpoint_content(data_client, data_db):
def test_config_endpoint_content(data_client, data_db, valid_token):
"""Tests whether we get the expected content when the /config endpoint is called."""
# TODO: this test is fragile, we should look into validation according to the
# supporting data, rather than counts & fixed lists
url_under_test = "/api/v1/config"

response = data_client.get(
url_under_test,
)
response = data_client.get(url_under_test, headers={"app-token": valid_token})

response_json = response.json()

Expand Down Expand Up @@ -165,7 +168,7 @@ def test_config_endpoint_content(data_client, data_db):
)


def test_config_endpoint_cclw_stats(data_client, data_db):
def test_config_endpoint_cclw_stats(data_client, data_db, valid_token):
url_under_test = "/api/v1/config"

cclw = (
Expand All @@ -191,9 +194,7 @@ def test_config_endpoint_cclw_stats(data_client, data_db):
_add_family(data_db, "T.0.0.6", FamilyCategory.UNFCCC, unfccc.import_id)
data_db.flush()

response = data_client.get(
url_under_test,
)
response = data_client.get(url_under_test, headers={"app-token": valid_token})

response_json = response.json()

Expand All @@ -211,6 +212,142 @@ def test_config_endpoint_cclw_stats(data_client, data_db):
assert org_config["total"] == laws + policies + unfccc


@pytest.mark.parametrize(
"allowed_corpora_ids, expected_organisation, other_organisation",
[
("UNFCCC.corpus.i00000001.n0000", "UNFCCC", "CCLW"),
("CCLW.corpus.i00000001.n0000", "CCLW", "UNFCCC"),
],
)
def test_config_endpoint_returns_stats_for_allowed_corpora_only(
allowed_corpora_ids,
expected_organisation,
other_organisation,
app_token_factory,
data_client,
data_db,
):
app_token = app_token_factory(allowed_corpora_ids)
url_under_test = "/api/v1/config"

other_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == other_organisation)
.one()
)
expected_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == expected_organisation)
.one()
)
expected_corpus_type = (
data_db.query(CorpusType)
.join(Corpus, Corpus.corpus_type_name == CorpusType.name)
.filter(CorpusType.name == expected_corpus.corpus_type_name)
.one()
)

_add_family(data_db, "T.0.0.1", FamilyCategory.EXECUTIVE, other_corpus.import_id)
_add_family(
data_db, "T.0.0.2", FamilyCategory.LEGISLATIVE, expected_corpus.import_id
)
data_db.flush()

response = data_client.get(url_under_test, headers={"app-token": app_token})

response_json = response.json()

org_config = response_json["organisations"]
expected_org_config = {
expected_organisation: {
"corpora": [
{
"corpus_import_id": expected_corpus.import_id,
"title": expected_corpus.title,
"description": expected_corpus.description,
"corpus_type": expected_corpus.corpus_type_name,
"corpus_type_description": expected_corpus_type.description,
"taxonomy": expected_corpus_type.valid_metadata,
"text": expected_corpus.corpus_text,
"image_url": (
f"https://cdn.climatepolicyradar.org/{expected_corpus.corpus_image_url}"
if expected_corpus.corpus_image_url
else ""
),
}
],
"total": 1,
"count_by_category": {
"Executive": 0,
"Legislative": 1,
"MCF": 0,
"UNFCCC": 0,
},
}
}
assert org_config == expected_org_config


def test_config_endpoint_returns_stats_for_all_orgs_if_no_allowed_corpora_in_app_token(
data_client,
data_db,
):
issued_at = datetime.utcnow()
to_encode = {
"allowed_corpora_ids": [],
"exp": issued_at + relativedelta(years=10),
"iat": int(datetime.timestamp(issued_at.replace(microsecond=0))),
"iss": "Climate Policy Radar",
"sub": "CPR",
"aud": "localhost",
}
app_token = jwt.encode(
to_encode, os.environ["TOKEN_SECRET_KEY"], algorithm=security.ALGORITHM
)
url_under_test = "/api/v1/config"

cclw_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == "CCLW")
.one()
)

unfccc_corpus = (
data_db.query(Corpus)
.join(Organisation, Organisation.id == Corpus.organisation_id)
.filter(Organisation.name == "UNFCCC")
.one()
)

_add_family(data_db, "T.0.0.1", FamilyCategory.EXECUTIVE, cclw_corpus.import_id)
_add_family(data_db, "T.0.0.2", FamilyCategory.LEGISLATIVE, unfccc_corpus.import_id)
data_db.flush()

response = data_client.get(url_under_test, headers={"app-token": app_token})

response_json = response.json()
org_config = response_json["organisations"]

assert list(org_config.keys()) == ["CCLW", "UNFCCC"]
assert org_config["CCLW"]["total"] == 1
assert org_config["UNFCCC"]["total"] == 1
assert org_config["UNFCCC"]["count_by_category"] == {
"Executive": 0,
"Legislative": 1,
"MCF": 0,
"UNFCCC": 0,
}
assert org_config["CCLW"]["count_by_category"] == {
"Executive": 1,
"Legislative": 0,
"MCF": 0,
"UNFCCC": 0,
}


class _MockColumn:
def __init__(self, name):
self.name = name
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/app/core/test_organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from sqlalchemy.orm import Session

from app.repository.organisation import get_all_organisations, get_corpora_for_org
from app.repository.organisation import get_corpora_for_org, get_organisations
from tests.non_search.setup_helpers import setup_new_corpus, setup_with_docs

CCLW_EXPECTED_NUM_CORPORA = 1
Expand Down Expand Up @@ -34,7 +34,9 @@


def test_expected_organisations_present(data_db: Session):
orgs = get_all_organisations(data_db)
orgs = get_organisations(
data_db, ["UNFCCC.corpus.i00000001.n0000", "CCLW.corpus.i00000001.n0000"]
)
assert len(orgs) == EXPECTED_NUM_ORGS

org_names = set([cast(str, org.name) for org in orgs])
Expand Down

0 comments on commit 85d86de

Please sign in to comment.