Skip to content

Commit

Permalink
Refresh logic revision
Browse files Browse the repository at this point in the history
  • Loading branch information
na-stewart committed Jun 22, 2024
1 parent 628902e commit b217153
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 29 deletions.
6 changes: 2 additions & 4 deletions sanic_security/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ async def fulfill_second_factor(request: Request) -> AuthenticationSession:

async def authenticate(request: Request) -> AuthenticationSession:
"""
Validates client's authentication session and account. If auto refresh is enabled, session property is_refresh will
only be true for the first time the refreshed session is returned.
Validates client's authentication session and account.
Args:
request (Request): Sanic request parameter.
Expand All @@ -222,9 +221,8 @@ async def authenticate(request: Request) -> AuthenticationSession:
if not authentication_session.anonymous:
authentication_session.bearer.validate()
except ExpiredError as e:
if security_config.AUTHENTICATION_REFRESH_AUTO: #
if security_config.AUTHENTICATION_REFRESH_AUTO:
authentication_session = await authentication_session.refresh(request)
logger.debug("Authentication session has been auto-refreshed.")
else:
raise e
return authentication_session
Expand Down
2 changes: 1 addition & 1 deletion sanic_security/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from os import environ

from sanic.utils import str_to_bool

from sanic.log import logger

"""
Copyright (c) 2020-present Nicholas Aidan Stewart
Expand Down
13 changes: 7 additions & 6 deletions sanic_security/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,14 +524,14 @@ class AuthenticationSession(Session):
Used to authenticate and identify a client.
Attributes:
is_refresh (bool): Determines if current session was created during previous session refresh.
requires_second_factor (bool): Determines if session requires a second factor.
refresh_expiration_date (bool): Date and time the session can no longer be refreshed.
is_refresh (bool): Will only be true for the first time session is created during refresh.
"""

is_refresh: bool = False
requires_second_factor: bool = fields.BooleanField(default=False)
refresh_expiration_date: datetime.datetime = fields.DatetimeField(null=True)
is_refresh: bool

def validate(self) -> None:
"""
Expand All @@ -557,6 +557,9 @@ async def refresh(self, request: Request):
DeactivatedError
SecondFactorRequiredError
NotExpiredError
Returns:
session
"""
try:
self.validate()
Expand All @@ -568,9 +571,7 @@ async def refresh(self, request: Request):
):
self.active = False
await self.save(update_fields=["active"])
session = await self.new(request, self.bearer)
session.is_refresh = True
return session
return await self.new(request, self.bearer, is_refresh=True)
else:
raise e

Expand All @@ -583,7 +584,7 @@ async def new(cls, request: Request, account: Account = None, **kwargs):
expiration_date=get_expiration_date(
security_config.AUTHENTICATION_SESSION_EXPIRATION
),
refresh_date=get_expiration_date(
refresh_expiration_date=get_expiration_date(
security_config.AUTHENTICATION_REFRESH_EXPIRATION
),
)
Expand Down
30 changes: 16 additions & 14 deletions sanic_security/test/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import traceback

from argon2 import PasswordHasher
Expand Down Expand Up @@ -151,18 +152,6 @@ async def on_logout(request):
return response


@app.post("api/test/auth/refresh")
@requires_authentication
async def on_authentication_refresh(request):
"""
Refreshes current authentication session. Requires data persistence and date change to
expire previous session.
"""
authentication_session = await request.ctx.authentication_session.refresh(request)
response = json("Refresh successful!", authentication_session.json)
return response


@app.post("api/test/auth")
@requires_authentication
async def on_authenticate(request):
Expand All @@ -178,13 +167,26 @@ async def on_authenticate(request):
if not authentication_session.anonymous
else None
),
"auto-refreshed": authentication_session.is_refresh
"refresh": authentication_session.is_refresh
},
)
request.ctx.authentication_session.encode(response)
if authentication_session.is_refresh:
request.ctx.authentication_session.encode(response)
return response


@app.post("api/test/auth/expire")
@requires_authentication
async def on_authentication_expire(request):
"""
Expire client's session.
"""
authentication_session = request.ctx.authentication_session
authentication_session.expiration_date = datetime.datetime.now(datetime.UTC)
await authentication_session.save(update_fields=["expiration_date"])
return json("Authentication expired!", authentication_session.json)


@app.post("api/test/auth/associated")
@requires_authentication
async def on_get_associated_authentication_sessions(request):
Expand Down
9 changes: 6 additions & 3 deletions sanic_security/test/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,10 @@ def test_authentication_refresh(self):
auth=("[email protected]", "password"),
)
assert login_response.status_code == 200, login_response.text
refresh_response = self.client.post("http://127.0.0.1:8000/api/test/auth/refresh")
assert refresh_response.status_code == 200, refresh_response.text
# Authenticate and check is_refresh
expire_response = self.client.post("http://127.0.0.1:8000/api/test/auth/expire")
assert expire_response.status_code == 200, expire_response.text
authenticate_response = self.client.post(
"http://127.0.0.1:8000/api/test/auth",
)
assert authenticate_response.status_code == 200, authenticate_response.text

2 changes: 1 addition & 1 deletion sanic_security/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_code() -> str:
return "".join(random.choices(string.digits + string.ascii_uppercase, k=6))


def json(message: str, data, status_code: int = 200) -> HTTPResponse:
def json(message: str, data, status_code: int = 200) -> HTTPResponse: # May be causing fixture error bc of json property
"""
A preformatted Sanic json response.
Expand Down

0 comments on commit b217153

Please sign in to comment.