Skip to content

Commit

Permalink
Merge pull request #52 from SELab-2/session-auth
Browse files Browse the repository at this point in the history
feat: session authentication for smooth browser-specific experence
  • Loading branch information
francisvaut authored Mar 5, 2024
2 parents 76b5fde + a4a5a1e commit 5060acc
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 71 deletions.
3 changes: 0 additions & 3 deletions backend/api/fixtures/admins.yaml

This file was deleted.

14 changes: 2 additions & 12 deletions backend/api/serializers/student_serializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from rest_framework import serializers
from ..models.student import Student
from api.models.student import Student


class StudentSerializer(serializers.ModelSerializer):
Expand All @@ -19,14 +19,4 @@ class StudentSerializer(serializers.ModelSerializer):

class Meta:
model = Student
fields = [
"id",
"first_name",
"last_name",
"email",
"faculties",
"last_enrolled",
"create_time",
"courses",
"groups",
]
fields = '__all__'
10 changes: 4 additions & 6 deletions backend/authentication/permissions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from rest_framework.permissions import BasePermission
from rest_framework.request import Request
from rest_framework.viewsets import ViewSet
from ypovoli import settings


class CASPermission(BasePermission):
def has_permission(self, request: Request, view: ViewSet):
"""Check whether a user has permission in the CAS flow context."""
return request.user.is_authenticated or view.action not in [
'logout', 'whoami'
]
class IsDebug(BasePermission):
def has_permission(self, request: Request, view: ViewSet) -> bool:
return settings.DEBUG
66 changes: 43 additions & 23 deletions backend/authentication/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Tuple
from django.contrib.auth.models import update_last_login
from django.contrib.auth import login
from rest_framework.serializers import (
CharField,
EmailField,
Expand All @@ -11,7 +13,7 @@
from rest_framework_simplejwt.tokens import RefreshToken, AccessToken
from rest_framework_simplejwt.settings import api_settings
from authentication.signals import user_created, user_login
from authentication.models import User
from authentication.models import User, Faculty
from authentication.cas.client import client


Expand All @@ -20,20 +22,53 @@ class CASTokenObtainSerializer(Serializer):
This serializer takes the CAS ticket and tries to validate it.
Upon successful validation, create a new user if it doesn't exist.
"""

token = RefreshToken
ticket = CharField(required=True, min_length=49, max_length=49)

def validate(self, data):
"""Validate a ticket using CAS client"""
response = client.perform_service_validate(ticket=data["ticket"])
# Validate the ticket and get CAS attributes.
attributes = self._validate_ticket(data["ticket"])

# Fetch a user model from the CAS attributes.
user, created = self._fetch_user_from_cas(attributes)

# Update the user's last login.
if api_settings.UPDATE_LAST_LOGIN:
update_last_login(self, user)

# Login and send authentication signals.
if "request" in self.context:
login(self.context["request"], user)

user_login.send(
sender=self, user=user
)

if created:
user_created.send(
sender=self, attributes=attributes, user=user
)

# Return access tokens for the now logged-in user.
return {
"access": str(
AccessToken.for_user(user)
),
"refresh": str(
RefreshToken.for_user(user)
),
}

def _validate_ticket(self, ticket: str) -> dict:
"""Validate a CAS ticket using the CAS client"""
response = client.perform_service_validate(ticket=ticket)

if response.error:
raise ValidationError(response.error)

# Validation success: create user if it doesn't exist yet.
attributes = response.data.get("attributes", dict)
return response.data.get("attributes", dict)

def _fetch_user_from_cas(self, attributes: dict) -> Tuple[User, bool]:
if attributes.get("lastenrolled"):
attributes["lastenrolled"] = int(attributes.get("lastenrolled").split()[0])

Expand All @@ -51,22 +86,7 @@ def validate(self, data):
if not user.is_valid():
raise ValidationError(user.errors)

user, created = user.get_or_create(user.validated_data)

# Update the user's last login.
if api_settings.UPDATE_LAST_LOGIN:
update_last_login(self, user)

user_login.send(sender=self, user=user)

# Send signal upon creation.
if created:
user_created.send(sender=self, attributes=attributes, user=user)

return {
"access": str(AccessToken.for_user(user)),
"refresh": str(RefreshToken.for_user(user)),
}
return user.get_or_create(user.validated_data)


class UserSerializer(ModelSerializer):
Expand All @@ -91,6 +111,6 @@ class Meta:
model = User
fields = "__all__"

def get_or_create(self, validated_data: dict) -> User:
def get_or_create(self, validated_data: dict) -> Tuple[User, bool]:
"""Create or fetch the user based on the validated data."""
return User.objects.get_or_create(**validated_data)
Empty file.
9 changes: 0 additions & 9 deletions backend/authentication/tests/test_authentication_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,3 @@ def test_login_view_returns_login_url(self):
server_url=settings.CAS_ENDPOINT, service_url=settings.CAS_RESPONSE
)
self.assertEqual(response["Location"], url)


class TestTokenEchoView(APITestCase):
def test_token_echo_echoes_token(self):
"""TokenEchoView should echo the User's current token"""
ticket = "This is a ticket."
response = self.client.get(reverse("cas-echo"), data={"ticket": ticket})
content = response.rendered_content.decode("utf-8").strip('"')
self.assertEqual(content, ticket)
24 changes: 19 additions & 5 deletions backend/authentication/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from django.shortcuts import redirect
from django.contrib.auth import logout
from rest_framework.decorators import action
from rest_framework.viewsets import ViewSet
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.exceptions import AuthenticationFailed
from rest_framework.permissions import AllowAny, IsAuthenticated
from authentication.serializers import UserSerializer
from authentication.permissions import IsDebug
from authentication.serializers import UserSerializer, CASTokenObtainSerializer
from authentication.cas.client import client
from ypovoli import settings

Expand All @@ -20,12 +23,16 @@ def login(self, _: Request) -> Response:
return redirect(client.get_login_url())

@action(detail=False, methods=['GET'])
def logout(self, _: Request) -> Response:
def logout(self, request: Request) -> Response:
"""Attempt to log out. Redirect to our single CAS endpoint.
Normally would only allow POST requests to a logout endpoint.
Since the CAS logout location handles the actual logout, we should accept GET requests.
"""
return redirect(client.get_logout_url(service_url=settings.API_ENDPOINT))
logout(request)

return redirect(
client.get_logout_url(service_url=settings.API_ENDPOINT)
)

@action(detail=False, methods=['GET'], url_path='whoami', url_name='whoami')
def who_am_i(self, request: Request) -> Response:
Expand All @@ -41,7 +48,14 @@ def who_am_i(self, request: Request) -> Response:
user_serializer.data
)

@action(detail=False, methods=['GET'], permission_classes=[AllowAny])
@action(detail=False, methods=['GET'], permission_classes=[IsDebug])
def echo(self, request: Request) -> Response:
"""Echo the obtained CAS token for development and testing."""
return Response(request.query_params.get('ticket'))
token_serializer = CASTokenObtainSerializer(data=request.query_params, context={
'request': request
})

if token_serializer.is_valid():
return Response(token_serializer.validated_data)

raise AuthenticationFailed(token_serializer.errors)
7 changes: 3 additions & 4 deletions backend/ypovoli/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
"rest_framework.renderers.JSONRenderer",
],
"DEFAULT_AUTHENTICATION_CLASSES": [
"rest_framework_simplejwt.authentication.JWTAuthentication"
"rest_framework_simplejwt.authentication.JWTAuthentication",
"rest_framework.authentication.SessionAuthentication"
],
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated'
Expand All @@ -79,9 +80,7 @@
}

AUTH_USER_MODEL = "authentication.User"

ROOT_URLCONF = "ypovoli.urls"

WSGI_APPLICATION = "ypovoli.wsgi.application"

# Application endpoints
Expand Down Expand Up @@ -110,7 +109,7 @@
TIME_ZONE = "UTC"
USE_I18N = True
USE_L10N = False
USE_TZ = False
USE_TZ = True

# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/4.0/howto/static-files/
Expand Down
12 changes: 3 additions & 9 deletions backend/ypovoli/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
default_version="v1",
),
public=True,
permission_classes=(permissions.AllowAny,),
permission_classes=[permissions.AllowAny,],
)


Expand All @@ -36,12 +36,6 @@
path("auth/", include("authentication.urls")),
path("notifications/", include("notifications.urls")),
# Swagger documentation.
path(
"swagger/",
schema_view.with_ui("swagger", cache_timeout=0),
name="schema-swagger-ui",
),
path(
"swagger<format>/", schema_view.without_ui(cache_timeout=0), name="schema-json"
),
path("swagger/", schema_view.with_ui("swagger", cache_timeout=0), name="schema-swagger-ui"),
path("swagger<format>/", schema_view.without_ui(cache_timeout=0), name="schema-json"),
]

0 comments on commit 5060acc

Please sign in to comment.