Skip to content

Commit

Permalink
OAuth encoding revision
Browse files Browse the repository at this point in the history
  • Loading branch information
na-stewart committed Jan 9, 2025
1 parent adef7ff commit e0e5c5c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def on_oauth_callback(request):
"Authorization successful.",
{"token_info": token_info, "auth_session": authentication_session.json},
)
oauth_encode(response, discord_oauth, token_info)
oauth_encode(response, token_info)
authentication_session.encode(response)
return response
```
Expand Down
18 changes: 6 additions & 12 deletions sanic_security/oauth.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import functools
import time
from typing import Literal, Union
from typing import Literal

import jwt
from httpx_oauth.oauth2 import BaseOAuth2, RefreshTokenError
from jwt import DecodeError
from sanic import Request, HTTPResponse, redirect, Sanic
from sanic import Request, HTTPResponse, Sanic
from sanic.log import logger
from tortoise.exceptions import IntegrityError, DoesNotExist

Expand Down Expand Up @@ -129,18 +129,17 @@ async def on_oauth_callback(


def oauth_encode(
response: HTTPResponse, client: Union[BaseOAuth2, str], token_info: dict
response: HTTPResponse, token_info: dict
) -> None:
"""
Transforms OAuth access token into JWT and then is stored in a cookie.
Args:
response (HTTPResponse): Sanic response used to store JWT into a cookie on the client.
client (Union[BaseOAuth2, str]): OAuth provider.
token_info (dict): OAuth access token.
"""
response.cookies.add_cookie(
f"{config.SESSION_PREFIX}_{(client if isinstance(client, str) else client.__class__.__name__)[:7].lower()}",
f"{config.SESSION_PREFIX}_oauth",
str(
jwt.encode(
token_info,
Expand Down Expand Up @@ -176,15 +175,14 @@ async def oauth_decode(request: Request, client: BaseOAuth2, refresh=False) -> d
try:
token_info = jwt.decode(
request.cookies.get(
f"{config.SESSION_PREFIX}_{client.__class__.__name__[:7].lower()}",
f"{config.SESSION_PREFIX}_oauth",
),
config.PUBLIC_SECRET or config.SECRET,
config.SESSION_ENCODING_ALGORITHM,
)
if time.time() > token_info["expires_at"] or refresh:
token_info = await client.refresh_token(token_info["refresh_token"])
token_info["is_refresh"] = True
token_info["client"] = client.__class__.__name__
if "expires_at" not in token_info:
token_info["expires_at"] = time.time() + token_info["expires_in"]
request.ctx.oauth = token_info
Expand Down Expand Up @@ -240,8 +238,4 @@ async def refresh_encoder_middleware(request, response):
if hasattr(request.ctx, "oauth") and getattr(
request.ctx.oauth, "is_refresh", False
):
oauth_encode(
response,
request.ctx.oauth["client"],
request.ctx.oauth,
)
oauth_encode(response, request.ctx.oauth)
2 changes: 1 addition & 1 deletion sanic_security/test/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ async def on_oauth_callback(request):
"OAuth successful.",
{"token_info": token_info, "auth_session": authentication_session.json},
)
oauth_encode(response, discord_oauth, token_info)
oauth_encode(response, token_info)
authentication_session.encode(response)
return response

Expand Down

0 comments on commit e0e5c5c

Please sign in to comment.