diff --git a/sanic_security/authentication.py b/sanic_security/authentication.py index a6dbfac..1f1cb6c 100644 --- a/sanic_security/authentication.py +++ b/sanic_security/authentication.py @@ -2,7 +2,7 @@ import re import warnings -from argon2.exceptions import VerifyMismatchError +from argon2.exceptions import VerificationError from sanic import Sanic from sanic.log import logger from sanic.request import Request @@ -128,7 +128,7 @@ async def login( f"Client {get_ip(request)} has logged in with authentication session {authentication_session.id}." ) return authentication_session - except VerifyMismatchError: + except VerificationError: logger.warning( f"Client {get_ip(request)} has failed to log into account {account.id}." ) diff --git a/sanic_security/models.py b/sanic_security/models.py index 99fd654..26b9911 100644 --- a/sanic_security/models.py +++ b/sanic_security/models.py @@ -398,7 +398,7 @@ async def new( cls, request: Request, account: Account, - **kwargs: Union[int, str, bool, float, list, dict], + **kwargs: dict[str, Union[int, str, bool, float, list, dict]], ): """ Creates session with pre-set values. @@ -406,7 +406,7 @@ async def new( Args: request (Request): Sanic request parameter. account (Account): Account being associated to the session. - **kwargs (dict[str, Union[int, str, bool, float, list, dict]]): Extra arguments applied during session creation. + **kwargs (Union[int, str, bool, float, list, dict]): Extra arguments applied during session creation. Returns: session @@ -525,7 +525,12 @@ async def check_code(self, code: str) -> None: await self.deactivate() @classmethod - async def new(cls, request: Request, account: Account, **kwargs): + async def new( + cls, + request: Request, + account: Account, + **kwargs: Union[int, str, bool, float, list, dict], + ): raise NotImplementedError class Meta: @@ -536,7 +541,12 @@ class TwoStepSession(VerificationSession): """Validates client using a code sent via email or text.""" @classmethod - async def new(cls, request: Request, account: Account, **kwargs): + async def new( + cls, + request: Request, + account: Account, + **kwargs: Union[int, str, bool, float, list, dict], + ): return await cls.create( **kwargs, ip=get_ip(request), @@ -553,7 +563,11 @@ class CaptchaSession(VerificationSession): """Validates client with a captcha challenge via image or audio.""" @classmethod - async def new(cls, request: Request, **kwargs): + async def new( + cls, + request: Request, + **kwargs: Union[int, str, bool, float, list, dict], + ): return await cls.create( **kwargs, ip=get_ip(request), @@ -641,7 +655,11 @@ async def refresh(self, request: Request): @classmethod async def new( - cls, request: Request, account: Account = None, is_refresh=False, **kwargs + cls, + request: Request, + account: Account = None, + is_refresh=False, + **kwargs: Union[int, str, bool, float, list, dict], ): authentication_session = await cls.create( **kwargs, diff --git a/sanic_security/oauth.py b/sanic_security/oauth.py index 646ed8d..0facf9d 100644 --- a/sanic_security/oauth.py +++ b/sanic_security/oauth.py @@ -1,7 +1,10 @@ import time +from typing import Literal import jwt -from httpx_oauth.oauth2 import BaseOAuth2 +from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.oauth2 import BaseOAuth2, RefreshTokenError, GetAccessTokenError +from jwt import DecodeError from sanic import Request, HTTPResponse, redirect, Sanic from sanic_security.configuration import config @@ -28,13 +31,26 @@ SOFTWARE. """ -oauth_clients: dict[str, BaseOAuth2] - -async def oauth( - client: BaseOAuth2, redirect_uri: str = config.OAUTH_REDIRECT, **kwargs +async def oauth_login( + client: BaseOAuth2, + redirect_uri: str = config.OAUTH_REDIRECT, + scope: list[str] = None, + state: str = None, + code_challenge: str = None, + code_challenge_method: Literal["plain", "S256"] = None, + **extra_params: str, ) -> HTTPResponse: - return redirect(await client.get_authorization_url(redirect_uri, **kwargs)) + return redirect( + await client.get_authorization_url( + redirect_uri, + state, + scope, + code_challenge, + code_challenge_method, + extra_params, + ) + ) async def oauth_callback( @@ -43,14 +59,20 @@ async def oauth_callback( redirect_uri: str = config.OAUTH_REDIRECT, code_verifier: str = None, ) -> dict: - token_info = await client.get_access_token( - request.args.get("code"), - redirect_uri, - code_verifier, - ) - if "expires_at" not in token_info: - token_info["expires_at"] = time.time() + token_info["expires_in"] - return token_info + try: + token_info = await client.get_access_token( + request.args.get("code"), + redirect_uri, + code_verifier, + ) + if "expires_at" not in token_info: + token_info["expires_at"] = time.time() + token_info["expires_in"] + client.get_id_email() + return token_info + except GetIdEmailError: + pass + except GetAccessTokenError: + pass def oauth_encode(response: HTTPResponse, token_info: dict) -> None: @@ -71,21 +93,26 @@ def oauth_encode(response: HTTPResponse, token_info: dict) -> None: ) -async def get_oauth(request: Request, client: BaseOAuth2) -> dict: - token_info = jwt.decode( - request.cookies.get( - f"{config.SESSION_PREFIX}_oauth", - ), - config.PUBLIC_SECRET or config.SECRET, - config.SESSION_ENCODING_ALGORITHM, - ) - if time.time() > token_info["expires_at"]: - token_info = await client.refresh_token(token_info["refresh_token"]) - token_info["is_refresh"] = True - if "expires_at" not in token_info: - token_info["expires_at"] = time.time() + token_info["expires_in"] - request.ctx.oauth = token_info - return token_info +async def oauth_decode(request: Request, client: BaseOAuth2) -> dict: + try: + token_info = jwt.decode( + request.cookies.get( + f"{config.SESSION_PREFIX}_oauth", + ), + config.PUBLIC_SECRET or config.SECRET, + config.SESSION_ENCODING_ALGORITHM, + ) + if time.time() > token_info["expires_at"]: + token_info = await client.refresh_token(token_info["refresh_token"]) + token_info["is_refresh"] = True + if "expires_at" not in token_info: + token_info["expires_at"] = time.time() + token_info["expires_in"] + request.ctx.oauth = token_info + return token_info + except RefreshTokenError: + pass + except DecodeError: + pass def initialize_oauth(app: Sanic) -> None: diff --git a/sanic_security/test/server.py b/sanic_security/test/server.py index b959bd0..33e88da 100644 --- a/sanic_security/test/server.py +++ b/sanic_security/test/server.py @@ -22,7 +22,7 @@ from sanic_security.configuration import config from sanic_security.exceptions import SecurityError from sanic_security.models import Account, CaptchaSession, AuthenticationSession -from sanic_security.oauth import oauth, oauth_callback, oauth_encode +from sanic_security.oauth import oauth_callback, oauth_encode, oauth_login from sanic_security.utils import json from sanic_security.verification import ( request_two_step_verification, @@ -272,7 +272,7 @@ async def on_account_creation(request): @app.get("api/test/oauth") async def on_oauth_request(request): - return await oauth( + return await oauth_login( fitbit_oauth, "http://localhost:8000/api/test/oauth/callback", scope=["profile"] )