Skip to content

Commit

Permalink
Use host instead of FQDN for audience (#349)
Browse files Browse the repository at this point in the history
* Use host header for audience comparison

* Update custom app documentation to use hostname

* Make audience a string instead of URL

* Use host for audience & add additional corpora ID validation

* Bump to 1.16.4

* Make audiences match
  • Loading branch information
katybaulch authored Sep 23, 2024
1 parent be9bd8d commit 9f844ca
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 71 deletions.
16 changes: 11 additions & 5 deletions app/api/api_v1/routers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
process_result_into_csv,
process_vespa_search_response,
)
from app.db.crud.helpers.validate import validate_corpora_ids
from app.db.crud.helpers.validate import verify_any_corpora_ids_in_db
from app.db.session import get_db

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -113,6 +113,7 @@ def search_documents(
}
),
],
host: Annotated[str, Header()],
app_token: Annotated[str, Header()],
db=Depends(get_db),
) -> SearchResponse:
Expand All @@ -137,7 +138,7 @@ def search_documents(
results from the search database. See the request schema for more details.
"""
try:
allowed_corpora_ids = decode_config_token(app_token, PUBLIC_APP_URL)
allowed_corpora_ids = decode_config_token(app_token, host)
except PyJWTError as e:
_LOGGER.error(e)
raise HTTPException(
Expand All @@ -156,15 +157,20 @@ def search_documents(
},
)

if not validate_corpora_ids(db, allowed_corpora_ids):
msg = "Error validating corpora IDs."
# First corpora validation is app token against DB. At least one of the app token
# corpora IDs must be present in the DB to continue the search request.
if not verify_any_corpora_ids_in_db(db, allowed_corpora_ids):
msg = "Error verifying corpora IDs."
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=msg,
headers={"WWW-Authenticate": "Bearer"},
)

_LOGGER.info("Starting search...")
_LOGGER.info(
"Starting search...",
extra={"props": {"search_request": search_body.model_dump()}},
)
return _search_request(db=db, search_body=search_body)


Expand Down
4 changes: 2 additions & 2 deletions app/api/api_v1/schemas/custom_app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from datetime import datetime

from pydantic import BaseModel, HttpUrl
from pydantic import BaseModel


class CustomAppConfigDTO(BaseModel):
"""A JSON representation of custom app configurable options."""

allowed_corpora_ids: list[str]
subject: str
audience: HttpUrl
audience: str
issuer: str
expiry: datetime
issued_at: int
20 changes: 3 additions & 17 deletions app/core/custom_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
import os
from datetime import datetime
from typing import Optional, cast
from typing import Optional

import jwt
from dateutil.relativedelta import relativedelta
from pydantic import HttpUrl

from app.api.api_v1.schemas.custom_app import CustomAppConfigDTO
from app.core import security
Expand Down Expand Up @@ -69,7 +68,7 @@ def create_configuration_token(input: str, years: Optional[int] = None) -> str:
allowed_corpora_ids=_parse_and_sort_corpora_ids(corpora_ids),
subject=subject,
issuer=ISSUER,
audience=cast(HttpUrl, add_trailing_slash_to_url(audience)),
audience=audience,
expiry=expire,
issued_at=int(
datetime.timestamp(issued_at.replace(microsecond=0))
Expand Down Expand Up @@ -111,21 +110,8 @@ def decode_config_token(token: str, audience: Optional[str]) -> list[str]:
TOKEN_SECRET_KEY,
algorithms=[security.ALGORITHM],
issuer=ISSUER,
audience=(
add_trailing_slash_to_url(audience) if audience is not None else None
),
audience=audience,
)
corpora_ids: list = decoded_token.get("allowed_corpora_ids")

return corpora_ids


def add_trailing_slash_to_url(app_url: str) -> str:
"""Add trailing slash to end of a URL string.
:param str app_url : A URL in string format.
:return str: The URL with a trailing '/' added if it wasn't present.
"""
if not app_url.endswith("/"):
app_url += "/"
return app_url
30 changes: 16 additions & 14 deletions app/db/crud/helpers/validate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, cast
from typing import cast

from db_client.models.dfce.family import Corpus
from sqlalchemy import distinct, select
Expand All @@ -8,23 +8,25 @@
_LOGGER = logging.getLogger(__name__)


def validate_corpora_ids(
db: Session, corpora_ids: list[str], allowed_corpora_ids: Optional[list[str]] = None
) -> bool:
"""Validate all given corpus IDs against a list of allowed corpora.
def verify_any_corpora_ids_in_db(db: Session, corpora_ids: list[str]) -> bool:
"""Validate given corpus IDs against the existing corpora in DB.
:param Session db: The DB session to connect to.
:param list[str] corpora_ids: The corpus import IDs we want to
validate.
:param Optional[list[str]] allowed_corpora_ids: The corpus import
IDs we want to validate against.
validate against the DB values.
:return bool: Return whether or not all the corpora are valid.
"""
if allowed_corpora_ids is None:
allowed_corpora_ids = cast(
list, db.scalars(select(distinct(Corpus.import_id))).all()
)
_LOGGER.info(allowed_corpora_ids) # TODO remove in part 2.
corpora_ids_from_db = cast(
list, db.scalars(select(distinct(Corpus.import_id))).all()
)

validate_success = any(corpus in corpora_ids_from_db for corpus in corpora_ids)
if validate_success:
not_in_db = set(corpora_ids).difference(corpora_ids_from_db)
if not_in_db != set():
_LOGGER.warning(
f"Some corpora in app token {not_in_db} "
"not available for searching against."
)

validate_success = all(corpus in allowed_corpora_ids for corpus in corpora_ids)
return validate_success
8 changes: 4 additions & 4 deletions docs/custom_apps.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ where:
custom app
- `THEME` is the name of the theme or organisation - it must not contain any
special characters or spaces
- `APP_DOMAIN` is the domain name of the custom app (NOTE: this is specific to
the AWS environment)
- `AUDIENCE` is the host name of the custom app, without HTTP scheme. This must
not end with a trailing slash OR start with a trailing slash.

## Example usage

Expand All @@ -29,8 +29,8 @@ want to generate the token for.
export TOKEN_SECRET_KEY=secret_key
export CORPORA_IDS=CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000
export THEME=CPR
export APP_DOMAIN=https://app.climatepolicyradar.org/
python -c "from app.core import config; from app.core.custom_app import create_configuration_token; print(create_configuration_token('$CORPORA_IDS;$THEME;$APP_DOMAIN'))"
export AUDIENCE=app.dev.climatepolicyradar.org
python -c "from app.core import config; from app.core.custom_app import create_configuration_token; print(create_configuration_token('$CORPORA_IDS;$THEME;$AUDIENCE'))"
```

### Decoding
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "navigator_backend"
version = "1.16.3"
version = "1.16.4"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
19 changes: 18 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,24 @@ def valid_token():
"""
corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000"
subject = "CCLW"
audience = "http://localhost:3000"
audience = "localhost:8888"
input_str = f"{corpora_ids};{subject};{audience}"
return create_configuration_token(input_str)


@pytest.fixture
def unfccc_token():
"""Generate valid config token using TOKEN_SECRET_KEY.
Need to generate the config token using the token secret key from
your local env file. For tests in CI, this will be the secret key in
the .env.example file, but for local development this secret key
might be different (e.g., the one for staging). This fixture works
around this.
"""
corpora_ids = "CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000"
subject = "CCLW"
audience = "localhost:8888"
input_str = f"{corpora_ids};{subject};{audience}"
return create_configuration_token(input_str)

Expand Down
9 changes: 7 additions & 2 deletions tests/search/vespa/setup_search_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sqlalchemy.orm import Session

SEARCH_ENDPOINT = "/api/v1/searches"
TEST_HOST = "localhost:8888"


def _make_search_request(
Expand All @@ -37,7 +38,11 @@ def _make_search_request(
params: Mapping[str, Any],
expected_status_code: int = status.HTTP_200_OK,
):
response = client.post(SEARCH_ENDPOINT, json=params, headers={"app-token": token})
response = client.post(
SEARCH_ENDPOINT,
json=params,
headers={"app-token": token, "host": TEST_HOST},
)
assert response.status_code == expected_status_code, response.text
return response.json()

Expand Down Expand Up @@ -160,7 +165,7 @@ def _create_family_event(db: Session, family: VespaFixture):


def _generate_synthetic_metadata(
taxonomy: Mapping[str, dict]
taxonomy: Mapping[str, dict],
) -> Mapping[str, Sequence[str]]:
meta_value = {}
for k in taxonomy:
Expand Down
4 changes: 2 additions & 2 deletions tests/search/vespa/test_search_raises_on_token_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_search_with_invalid_corpus_id_in_token(
_populate_db_families(data_db)

with patch(
"app.api.api_v1.routers.search.validate_corpora_ids", return_value=False
"app.api.api_v1.routers.search.verify_any_corpora_ids_in_db", return_value=False
):
response = _make_search_request(
data_client,
Expand All @@ -33,7 +33,7 @@ def test_search_with_invalid_corpus_id_in_token(
expected_status_code=status.HTTP_400_BAD_REQUEST,
)

assert response["detail"] == "Error validating corpora IDs."
assert response["detail"] == "Error verifying corpora IDs."


@pytest.mark.search
Expand Down
36 changes: 13 additions & 23 deletions tests/unit/app/core/custom_app/test_create_configuration_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,32 +67,16 @@ def test_create_configuration_token_subject_contains_special_chars(input_str: st
assert has_expected_keys(data)


@pytest.mark.parametrize(
"input_str",
[
("kiwi;subject;not_a_url"),
("pear;subject;url_with_no_scheme.com"),
],
)
def test_create_configuration_token_audience_not_a_http_url(input_str: str):
with pytest.raises(ValueError):
token = create_configuration_token(input_str)
assert token is None

data = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[ALGORITHM])
assert has_expected_keys(data)


@pytest.mark.parametrize(
"input_str,expected_allowed_corpora,expected_subject,expected_audience",
[
(
"apple,banana,carrot;ORG1;http://ORG1.org",
"apple,banana,carrot;ORG1;ORG1.org",
["apple", "banana", "carrot"],
"ORG1",
"http://org1.org/",
"ORG1.org",
),
("cucumber;ORG2;https://ORG2.com/", ["cucumber"], "ORG2", "https://org2.com/"),
("cucumber;ORG2;ORG2.com", ["cucumber"], "ORG2", "ORG2.com"),
],
)
def test_create_configuration_token_default_expiry(
Expand Down Expand Up @@ -122,23 +106,26 @@ def test_create_configuration_token_default_expiry(
EXPIRE_AFTER_DEFAULT_YEARS, datetime.fromtimestamp(data["exp"])
) == datetime.fromtimestamp(data["iat"])

assert not data["aud"].endswith("/")
assert not data["aud"].startswith("http")


@pytest.mark.parametrize(
"input_str,expected_allowed_corpora,expiry_years,expected_subject,expected_audience",
[
(
"raspberry,strawberry,orange;ORG1;http://ORG1.org",
"raspberry,strawberry,orange;ORG1;ORG1.org",
["orange", "raspberry", "strawberry"],
EXPIRE_AFTER_1_YEAR,
"ORG1",
"http://org1.org/",
"ORG1.org",
),
(
"grapefruit;ORG2;https://ORG2.com/",
"grapefruit;ORG2;ORG2.com",
["grapefruit"],
EXPIRE_AFTER_5_YEARS,
"ORG2",
"https://org2.com/",
"ORG2.com",
),
],
)
Expand Down Expand Up @@ -168,3 +155,6 @@ def test_create_configuration_token_specific_expiry(
assert timedelta_years(
expiry_years, datetime.fromtimestamp(data["exp"])
) == datetime.fromtimestamp(data["iat"])

assert not data["aud"].endswith("/")
assert not data["aud"].startswith("http")

0 comments on commit 9f844ca

Please sign in to comment.