Skip to content

Commit

Permalink
Typing & error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
na-stewart committed Jan 7, 2025
1 parent 9ef81c8 commit 84bc7a6
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 39 deletions.
4 changes: 2 additions & 2 deletions sanic_security/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import warnings

from argon2.exceptions import VerifyMismatchError
from argon2.exceptions import VerificationError
from sanic import Sanic
from sanic.log import logger
from sanic.request import Request
Expand Down Expand Up @@ -128,7 +128,7 @@ async def login(
f"Client {get_ip(request)} has logged in with authentication session {authentication_session.id}."
)
return authentication_session
except VerifyMismatchError:
except VerificationError:
logger.warning(
f"Client {get_ip(request)} has failed to log into account {account.id}."
)
Expand Down
30 changes: 24 additions & 6 deletions sanic_security/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,15 +398,15 @@ async def new(
cls,
request: Request,
account: Account,
**kwargs: Union[int, str, bool, float, list, dict],
**kwargs: dict[str, Union[int, str, bool, float, list, dict]],
):
"""
Creates session with pre-set values.
Args:
request (Request): Sanic request parameter.
account (Account): Account being associated to the session.
**kwargs (dict[str, Union[int, str, bool, float, list, dict]]): Extra arguments applied during session creation.
**kwargs (Union[int, str, bool, float, list, dict]): Extra arguments applied during session creation.
Returns:
session
Expand Down Expand Up @@ -525,7 +525,12 @@ async def check_code(self, code: str) -> None:
await self.deactivate()

@classmethod
async def new(cls, request: Request, account: Account, **kwargs):
async def new(
cls,
request: Request,
account: Account,
**kwargs: Union[int, str, bool, float, list, dict],
):
raise NotImplementedError

class Meta:
Expand All @@ -536,7 +541,12 @@ class TwoStepSession(VerificationSession):
"""Validates client using a code sent via email or text."""

@classmethod
async def new(cls, request: Request, account: Account, **kwargs):
async def new(
cls,
request: Request,
account: Account,
**kwargs: Union[int, str, bool, float, list, dict],
):
return await cls.create(
**kwargs,
ip=get_ip(request),
Expand All @@ -553,7 +563,11 @@ class CaptchaSession(VerificationSession):
"""Validates client with a captcha challenge via image or audio."""

@classmethod
async def new(cls, request: Request, **kwargs):
async def new(
cls,
request: Request,
**kwargs: Union[int, str, bool, float, list, dict],
):
return await cls.create(
**kwargs,
ip=get_ip(request),
Expand Down Expand Up @@ -641,7 +655,11 @@ async def refresh(self, request: Request):

@classmethod
async def new(
cls, request: Request, account: Account = None, is_refresh=False, **kwargs
cls,
request: Request,
account: Account = None,
is_refresh=False,
**kwargs: Union[int, str, bool, float, list, dict],
):
authentication_session = await cls.create(
**kwargs,
Expand Down
85 changes: 56 additions & 29 deletions sanic_security/oauth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import time
from typing import Literal

import jwt
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.exceptions import GetIdEmailError
from httpx_oauth.oauth2 import BaseOAuth2, RefreshTokenError, GetAccessTokenError
from jwt import DecodeError
from sanic import Request, HTTPResponse, redirect, Sanic

from sanic_security.configuration import config
Expand All @@ -28,13 +31,26 @@
SOFTWARE.
"""

oauth_clients: dict[str, BaseOAuth2]


async def oauth(
client: BaseOAuth2, redirect_uri: str = config.OAUTH_REDIRECT, **kwargs
async def oauth_login(
client: BaseOAuth2,
redirect_uri: str = config.OAUTH_REDIRECT,
scope: list[str] = None,
state: str = None,
code_challenge: str = None,
code_challenge_method: Literal["plain", "S256"] = None,
**extra_params: str,
) -> HTTPResponse:
return redirect(await client.get_authorization_url(redirect_uri, **kwargs))
return redirect(
await client.get_authorization_url(
redirect_uri,
state,
scope,
code_challenge,
code_challenge_method,
extra_params,
)
)


async def oauth_callback(
Expand All @@ -43,14 +59,20 @@ async def oauth_callback(
redirect_uri: str = config.OAUTH_REDIRECT,
code_verifier: str = None,
) -> dict:
token_info = await client.get_access_token(
request.args.get("code"),
redirect_uri,
code_verifier,
)
if "expires_at" not in token_info:
token_info["expires_at"] = time.time() + token_info["expires_in"]
return token_info
try:
token_info = await client.get_access_token(
request.args.get("code"),
redirect_uri,
code_verifier,
)
if "expires_at" not in token_info:
token_info["expires_at"] = time.time() + token_info["expires_in"]
client.get_id_email()
return token_info
except GetIdEmailError:
pass
except GetAccessTokenError:
pass


def oauth_encode(response: HTTPResponse, token_info: dict) -> None:
Expand All @@ -71,21 +93,26 @@ def oauth_encode(response: HTTPResponse, token_info: dict) -> None:
)


async def get_oauth(request: Request, client: BaseOAuth2) -> dict:
token_info = jwt.decode(
request.cookies.get(
f"{config.SESSION_PREFIX}_oauth",
),
config.PUBLIC_SECRET or config.SECRET,
config.SESSION_ENCODING_ALGORITHM,
)
if time.time() > token_info["expires_at"]:
token_info = await client.refresh_token(token_info["refresh_token"])
token_info["is_refresh"] = True
if "expires_at" not in token_info:
token_info["expires_at"] = time.time() + token_info["expires_in"]
request.ctx.oauth = token_info
return token_info
async def oauth_decode(request: Request, client: BaseOAuth2) -> dict:
try:
token_info = jwt.decode(
request.cookies.get(
f"{config.SESSION_PREFIX}_oauth",
),
config.PUBLIC_SECRET or config.SECRET,
config.SESSION_ENCODING_ALGORITHM,
)
if time.time() > token_info["expires_at"]:
token_info = await client.refresh_token(token_info["refresh_token"])
token_info["is_refresh"] = True
if "expires_at" not in token_info:
token_info["expires_at"] = time.time() + token_info["expires_in"]
request.ctx.oauth = token_info
return token_info
except RefreshTokenError:
pass
except DecodeError:
pass


def initialize_oauth(app: Sanic) -> None:
Expand Down
4 changes: 2 additions & 2 deletions sanic_security/test/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sanic_security.configuration import config
from sanic_security.exceptions import SecurityError
from sanic_security.models import Account, CaptchaSession, AuthenticationSession
from sanic_security.oauth import oauth, oauth_callback, oauth_encode
from sanic_security.oauth import oauth_callback, oauth_encode, oauth_login
from sanic_security.utils import json
from sanic_security.verification import (
request_two_step_verification,
Expand Down Expand Up @@ -272,7 +272,7 @@ async def on_account_creation(request):

@app.get("api/test/oauth")
async def on_oauth_request(request):
return await oauth(
return await oauth_login(
fitbit_oauth, "http://localhost:8000/api/test/oauth/callback", scope=["profile"]
)

Expand Down

0 comments on commit 84bc7a6

Please sign in to comment.