Skip to content

Commit

Permalink
Implement ooniauth API v2 (#825)
Browse files Browse the repository at this point in the history
* Implement ooniauth API v2

* Add support for passing admin_emails via configuration

* Get rid of the account_id hashing

* Move clickhouse functionality into clickhouse_utils

* Fix broken errors and assertions

* Add tests for AHrefParser

* Mark the single line which is not covered as no cov
  • Loading branch information
hellais authored Mar 15, 2024
1 parent 0c9bdd9 commit ca2857c
Show file tree
Hide file tree
Showing 13 changed files with 603 additions and 205 deletions.
67 changes: 67 additions & 0 deletions ooniapi/common/src/common/clickhouse_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
from typing import Dict, List, Optional, Union
import clickhouse_driver
import clickhouse_driver.errors

from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.selectable import Select

log = logging.getLogger(__name__)

Query = Union[str, TextClause, Select]


def _run_query(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
):
    """Execute ``query`` against ClickHouse and return ``(colnames, rows)``.

    SQLAlchemy ``Select``/``TextClause`` objects are compiled to a SQL string
    using the PostgreSQL dialect before being sent; plain strings are passed
    through unchanged.

    Args:
        db: ClickHouse client used to run the query.
        query: SQL string or SQLAlchemy ``Select``/``TextClause``.
        query_params: parameters forwarded to ``db.execute``.
        query_prio: currently unused, because the priority/timeout settings
            are disabled (see the commented-out settings line in the body).

    Returns:
        A ``(colnames, rows)`` tuple. ``colnames`` is an empty tuple when the
        server returns no column metadata.

    Raises:
        Exception: "Database query error" when the server rejects the query.
            The original ``ServerException`` is chained as ``__cause__``.
    """
    # settings = {"priority": query_prio, "max_execution_time": 28}
    settings = {}
    if isinstance(query, (Select, TextClause)):
        query = str(query.compile(dialect=postgresql.dialect()))
    try:
        q = db.execute(query, query_params, with_column_types=True, settings=settings)
    except clickhouse_driver.errors.ServerException as e:
        log.info(e.message)
        # Keep the generic message for callers, but chain the cause so the
        # server-side error is not lost from the traceback.
        raise Exception("Database query error") from e

    rows, coldata = q  # type: ignore
    # zip(*[]) yields nothing, so unpacking it into two names would raise
    # ValueError when no column metadata is returned.
    if not coldata:
        return (), rows
    colnames, _coltypes = zip(*coldata)
    return colnames, rows


def query_click(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
) -> List[Dict]:
    """Run ``query`` and return every result row as a ``{column: value}`` dict.

    All arguments are forwarded unchanged to ``_run_query``.
    """
    names, records = _run_query(db, query, query_params, query_prio=query_prio)

    def as_dict(record):
        # Pair each column name with the value at the same position.
        return dict(zip(names, record))

    return [as_dict(record) for record in records]  # type: ignore


def query_click_one_row(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
) -> Optional[dict]:
    """Run ``query`` and return only its first row as a ``{column: value}`` dict.

    Returns ``None`` when the query produces no rows. All arguments are
    forwarded unchanged to ``_run_query``.
    """
    names, records = _run_query(db, query, query_params, query_prio=query_prio)
    try:
        first = next(iter(records))
    except StopIteration:
        # Empty result set.
        return None
    return dict(zip(names, first))  # type: ignore


def insert_click(db: clickhouse_driver.Client, query: Query, rows: list) -> int:
    """Execute an insert ``query`` with ``rows`` and return ``db.execute``'s result.

    ``rows`` must be a ``list`` (checked with an ``assert``). The statement is
    run with type checking enabled, priority 1 and a 300 second execution
    limit. The return value is annotated as ``int``; presumably the number of
    inserted rows (confirm against the clickhouse-driver documentation).
    """
    assert isinstance(rows, list)
    return db.execute(  # type: ignore
        query,
        rows,
        types_check=True,
        settings={"priority": 1, "max_execution_time": 300},
    )


def optimize_table(db: clickhouse_driver.Client, tblname: str) -> None:
    """Issue ``OPTIMIZE TABLE <tblname> FINAL`` with priority 1 and a 300s limit.

    NOTE: ``tblname`` is interpolated directly into the SQL text, so it must
    come from a trusted source.
    """
    exec_settings = {"priority": 1, "max_execution_time": 300}
    db.execute(f"OPTIMIZE TABLE {tblname} FINAL", {}, settings=exec_settings)


def raw_query(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=1
):
    """Execute ``query`` and return the unprocessed ``db.execute`` result.

    The query runs with ``with_column_types=True``, the given ``query_prio``
    as priority and a 300 second execution limit. Unlike ``_run_query``, the
    query is not compiled and server errors are not translated.
    """
    exec_settings = {"priority": query_prio, "max_execution_time": 300}
    return db.execute(
        query, query_params, with_column_types=True, settings=exec_settings
    )
6 changes: 5 additions & 1 deletion ooniapi/common/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@ class Settings(BaseSettings):
statsd_port: int = 8125
statsd_prefix: str = "ooniapi"
jwt_encryption_key: str = "CHANGEME"
account_id_hashing_key: str = "CHANGEME"
prometheus_metrics_password: str = "CHANGEME"
session_expiry_days: int = 10
login_expiry_days: int = 10

admin_emails: List[str] = [
"[email protected]",
"[email protected]",
]

aws_region: str = ""
aws_access_key_id: str = ""
aws_secret_access_key: str = ""
Expand Down
65 changes: 0 additions & 65 deletions ooniapi/common/src/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
from io import StringIO
import logging
from typing import Any, Dict, List, Optional, Union
from fastapi import HTTPException, Header
from fastapi.responses import JSONResponse

import jwt
import clickhouse_driver
import clickhouse_driver.errors

from sqlalchemy.dialects import postgresql
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.selectable import Select


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,64 +60,6 @@ def convert_to_csv(r) -> str:
return result


Query = Union[str, TextClause, Select]


def _run_query(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
):
    """Execute ``query`` against ClickHouse and return ``(colnames, rows)``.

    SQLAlchemy ``Select``/``TextClause`` objects are compiled to a SQL string
    using the PostgreSQL dialect before being sent; plain strings are passed
    through unchanged.

    Args:
        db: ClickHouse client used to run the query.
        query: SQL string or SQLAlchemy ``Select``/``TextClause``.
        query_params: parameters forwarded to ``db.execute``.
        query_prio: currently unused, because the priority/timeout settings
            are disabled (see the commented-out settings line in the body).

    Returns:
        A ``(colnames, rows)`` tuple. ``colnames`` is an empty tuple when the
        server returns no column metadata.

    Raises:
        Exception: "Database query error" when the server rejects the query.
            The original ``ServerException`` is chained as ``__cause__``.
    """
    # settings = {"priority": query_prio, "max_execution_time": 28}
    settings = {}
    if isinstance(query, (Select, TextClause)):
        query = str(query.compile(dialect=postgresql.dialect()))
    try:
        q = db.execute(query, query_params, with_column_types=True, settings=settings)
    except clickhouse_driver.errors.ServerException as e:
        log.info(e.message)
        # Keep the generic message for callers, but chain the cause so the
        # server-side error is not lost from the traceback.
        raise Exception("Database query error") from e

    rows, coldata = q  # type: ignore
    # zip(*[]) yields nothing, so unpacking it into two names would raise
    # ValueError when no column metadata is returned.
    if not coldata:
        return (), rows
    colnames, _coltypes = zip(*coldata)
    return colnames, rows


def query_click(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
) -> List[Dict]:
    """Run ``query`` and return every result row as a ``{column: value}`` dict.

    All arguments are forwarded unchanged to ``_run_query``.
    """
    names, records = _run_query(db, query, query_params, query_prio=query_prio)

    def as_dict(record):
        # Pair each column name with the value at the same position.
        return dict(zip(names, record))

    return [as_dict(record) for record in records]  # type: ignore


def query_click_one_row(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=3
) -> Optional[dict]:
    """Run ``query`` and return only its first row as a ``{column: value}`` dict.

    Returns ``None`` when the query produces no rows. All arguments are
    forwarded unchanged to ``_run_query``.
    """
    names, records = _run_query(db, query, query_params, query_prio=query_prio)
    try:
        first = next(iter(records))
    except StopIteration:
        # Empty result set.
        return None
    return dict(zip(names, first))  # type: ignore


def insert_click(db: clickhouse_driver.Client, query: Query, rows: list) -> int:
    """Execute an insert ``query`` with ``rows`` and return ``db.execute``'s result.

    ``rows`` must be a ``list`` (checked with an ``assert``). The statement is
    run with type checking enabled, priority 1 and a 300 second execution
    limit. The return value is annotated as ``int``; presumably the number of
    inserted rows (confirm against the clickhouse-driver documentation).
    """
    assert isinstance(rows, list)
    return db.execute(  # type: ignore
        query,
        rows,
        types_check=True,
        settings={"priority": 1, "max_execution_time": 300},
    )


def optimize_table(db: clickhouse_driver.Client, tblname: str) -> None:
    """Issue ``OPTIMIZE TABLE <tblname> FINAL`` with priority 1 and a 300s limit.

    NOTE: ``tblname`` is interpolated directly into the SQL text, so it must
    come from a trusted source.
    """
    exec_settings = {"priority": 1, "max_execution_time": 300}
    db.execute(f"OPTIMIZE TABLE {tblname} FINAL", {}, settings=exec_settings)


def raw_query(
    db: clickhouse_driver.Client, query: Query, query_params: dict, query_prio=1
):
    """Execute ``query`` and return the unprocessed ``db.execute`` result.

    The query runs with ``with_column_types=True``, the given ``query_prio``
    as priority and a 300 second execution limit. Unlike ``_run_query``, the
    query is not compiled and server errors are not translated.
    """
    exec_settings = {"priority": query_prio, "max_execution_time": 300}
    return db.execute(
        query, query_params, with_column_types=True, settings=exec_settings
    )


def decode_jwt(token: str, key: str, **kw) -> Dict[str, Any]:
    """Decode ``token`` with ``key`` using HS256 and return the decoded claims.

    Extra keyword arguments are passed straight through to ``jwt.decode``.
    """
    return jwt.decode(token, key, algorithms=["HS256"], **kw)
Expand Down
26 changes: 12 additions & 14 deletions ooniapi/services/ooniauth/README.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
# ooniauth

[![PyPI - Version](https://img.shields.io/pypi/v/ooniauth.svg)](https://pypi.org/project/ooniauth)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ooniauth.svg)](https://pypi.org/project/ooniauth)
The OONI Auth service is designed to allow users to authenticate with their
email address to OONI services.

-----
The basic workflow is:

**Table of Contents**
1. Perform a login request by providing your email address
2. Check your email to retrieve the login link
3. Click on the login link to generate a session token that's valid for the
duration of the session

- [Installation](#installation)
- [License](#license)
You may also want to periodically refresh the session token so that it does not
expire.

## Installation
The tokens which are part of the system are:

```console
pip install ooniauth
```

## License

`ooniauth` is distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.
- Login tokens, which are sent via email and are tied to an email address
- Session tokens, which are issued by the API and are tied to a login token
1 change: 0 additions & 1 deletion ooniapi/services/ooniauth/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ classifiers = [
]
dependencies = [
"fastapi ~= 0.108.0",
"clickhouse-driver ~= 0.2.6",
"sqlalchemy ~= 2.0.27",
"ujson ~= 5.9.0",
"python-dateutil ~= 2.8.2",
Expand Down
10 changes: 2 additions & 8 deletions ooniapi/services/ooniauth/src/ooniauth/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Annotated

from clickhouse_driver import Client as ClickhouseClient
import boto3

from fastapi import Depends
Expand All @@ -9,14 +8,9 @@
from .common.config import Settings


def get_clickhouse_client(
    settings: Annotated[Settings, Depends(get_settings)]
) -> ClickhouseClient:
    """Build a ClickHouse client from the ``clickhouse_url`` in ``settings``."""
    clickhouse_url = settings.clickhouse_url
    return ClickhouseClient.from_url(clickhouse_url)


def get_ses_client(settings: Annotated[Settings, Depends(get_settings)]):
return boto3.client(
# TODO(art): add support for running integration tests of boto
return boto3.client( # no cov
"ses",
region_name=settings.aws_region,
aws_access_key_id=settings.aws_access_key_id,
Expand Down
6 changes: 2 additions & 4 deletions ooniapi/services/ooniauth/src/ooniauth/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from prometheus_fastapi_instrumentator import Instrumentator

from .routers import v1
from .routers import v1, v2

from .common.config import Settings
from .common.dependencies import get_settings
Expand Down Expand Up @@ -49,6 +49,7 @@ async def lifespan(app: FastAPI):
)

app.include_router(v1.router, prefix="/api")
app.include_router(v2.router, prefix="/api")


@app.get("/version")
Expand Down Expand Up @@ -82,9 +83,6 @@ async def health(
if settings.aws_secret_access_key == "" or settings.aws_access_key_id == "":
errors.append("bad_aws_credentials")

if settings.account_id_hashing_key == "CHANGEME":
errors.append("bad_prometheus_password")

if len(errors) > 0:
log.error(f"Health check errors: {errors}")
raise HTTPException(status_code=542, detail=f"health check failed")
Expand Down
49 changes: 12 additions & 37 deletions ooniapi/services/ooniauth/src/ooniauth/routers/v1.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""
OONIRun link management
https://github.com/ooni/spec/blob/master/backends/bk-005-ooni-run-v2.md
"""

from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import urlparse, urlencode, urlunsplit
Expand All @@ -16,13 +10,14 @@
from pydantic import EmailStr
from typing_extensions import Annotated

from ..dependencies import get_clickhouse_client, get_ses_client
from ..dependencies import get_ses_client

from ..utils import (
create_session_token,
get_account_role,
hash_email_address,
send_login_email,
format_login_url,
VALID_REDIRECT_TO_FQDN,
)
from ..common.dependencies import get_settings, role_required
from ..common.config import Settings
Expand All @@ -38,10 +33,6 @@

router = APIRouter()

# @router.get("/api/v2/ooniauth/user-session")
# @router.post("/api/v2/ooniauth/user-session", response_model=SessionTokenCreate)
# redirect_to: ## Make this optional


class UserRegister(BaseModel):
email_address: EmailStr = Field(
Expand All @@ -56,26 +47,13 @@ def validate_redirect_to(cls, v):
u = urlparse(v)
if u.scheme != "https":
raise ValueError("Invalid URL")
valid_dnames = (
"explorer.ooni.org",
"explorer.test.ooni.org",
"run.ooni.io",
"run.test.ooni.org",
"test-lists.ooni.org",
"test-lists.test.ooni.org",
)
if u.netloc not in valid_dnames:

if u.netloc not in VALID_REDIRECT_TO_FQDN:
raise ValueError("Invalid URL", u.netloc)

return v


def format_login_url(redirect_to: str, registration_token: str) -> str:
    """Build the ``https://<host>/login?token=<token>`` URL for a login link.

    Only the network location of ``redirect_to`` is reused; its scheme, path,
    query and fragment are discarded. The token is URL-encoded into the
    ``token`` query parameter.
    """
    host = urlparse(redirect_to).netloc
    query_string = urlencode({"token": registration_token})
    return urlunsplit(("https", host, "/login", query_string, ""))


# Response model with a single string field, `msg`.
class UserRegistrationResponse(BaseModel):
msg: str

Expand All @@ -89,17 +67,13 @@ async def user_register(
"""Auth Services: start email-based user registration"""
email_address = user_register.email_address.lower()

account_id = hash_email_address(
email_address=email_address, key=settings.account_id_hashing_key
)
now = datetime.now(timezone.utc)
expiration = now + timedelta(days=1)
# On the backend side the registration is stateless
payload = {
"nbf": now,
"exp": expiration,
"aud": "register",
"account_id": account_id,
"email_address": email_address,
"redirect_to": user_register.redirect_to,
}
Expand Down Expand Up @@ -138,7 +112,6 @@ async def user_login(
Query(alias="k", description="JWT token with aud=register"),
],
settings: Settings = Depends(get_settings),
db: Settings = Depends(get_clickhouse_client),
):
"""Auth Services: login using a registration/login link"""
try:
Expand All @@ -157,21 +130,23 @@ async def user_login(
log.info("user login successful")

# Store account role in token to prevent frequent DB lookups
role = get_account_role(db=db, account_id=dec["account_id"]) or "user"
email_address = dec["email_address"]
role = get_account_role(
admin_emails=settings.admin_emails, email_address=email_address
)
redirect_to = dec.get("redirect_to", "")
email = dec["email_address"]

token = create_session_token(
key=settings.jwt_encryption_key,
account_id=dec["account_id"],
email_address=email_address,
role=role,
session_expiry_days=settings.session_expiry_days,
login_expiry_days=settings.login_expiry_days,
)
return SessionTokenCreate(
bearer=token,
redirect_to=redirect_to,
email_address=email,
email_address=email_address,
)


Expand All @@ -198,7 +173,7 @@ async def user_refresh_token(

newtoken = create_session_token(
key=settings.jwt_encryption_key,
account_id=tok["account_id"],
email_address=tok["email_address"],
role=tok["role"],
session_expiry_days=settings.session_expiry_days,
login_expiry_days=settings.login_expiry_days,
Expand Down
Loading

0 comments on commit ca2857c

Please sign in to comment.