diff --git a/backend/api/fixtures/admins.yaml b/backend/api/fixtures/admins.yaml deleted file mode 100644 index a0ad3113..00000000 --- a/backend/api/fixtures/admins.yaml +++ /dev/null @@ -1,3 +0,0 @@ -- model: api.admin - pk: '2' - fields: {} diff --git a/backend/api/serializers/student_serializer.py b/backend/api/serializers/student_serializer.py index fddbd8d2..9cd1f245 100644 --- a/backend/api/serializers/student_serializer.py +++ b/backend/api/serializers/student_serializer.py @@ -1,5 +1,5 @@ from rest_framework import serializers -from ..models.student import Student +from api.models.student import Student class StudentSerializer(serializers.ModelSerializer): @@ -19,14 +19,4 @@ class StudentSerializer(serializers.ModelSerializer): class Meta: model = Student - fields = [ - "id", - "first_name", - "last_name", - "email", - "faculties", - "last_enrolled", - "create_time", - "courses", - "groups", - ] + fields = '__all__' diff --git a/backend/authentication/permissions.py b/backend/authentication/permissions.py index d852d767..b9ff5906 100644 --- a/backend/authentication/permissions.py +++ b/backend/authentication/permissions.py @@ -1,11 +1,9 @@ from rest_framework.permissions import BasePermission from rest_framework.request import Request from rest_framework.viewsets import ViewSet +from ypovoli import settings -class CASPermission(BasePermission): - def has_permission(self, request: Request, view: ViewSet): - """Check whether a user has permission in the CAS flow context.""" - return request.user.is_authenticated or view.action not in [ - 'logout', 'whoami' - ] +class IsDebug(BasePermission): + def has_permission(self, request: Request, view: ViewSet) -> bool: + return settings.DEBUG diff --git a/backend/authentication/serializers.py b/backend/authentication/serializers.py index 1ec9b9a0..60771277 100644 --- a/backend/authentication/serializers.py +++ b/backend/authentication/serializers.py @@ -1,4 +1,6 @@ +from typing import Tuple from django.contrib.auth.models import update_last_login +from django.contrib.auth import login from rest_framework.serializers import ( CharField, EmailField, @@ -11,7 +13,7 @@ from rest_framework_simplejwt.tokens import RefreshToken, AccessToken from rest_framework_simplejwt.settings import api_settings from authentication.signals import user_created, user_login -from authentication.models import User +from authentication.models import User, Faculty from authentication.cas.client import client @@ -20,20 +22,53 @@ class CASTokenObtainSerializer(Serializer): This serializer takes the CAS ticket and tries to validate it. Upon successful validation, create a new user if it doesn't exist. """ - - token = RefreshToken ticket = CharField(required=True, min_length=49, max_length=49) def validate(self, data): """Validate a ticket using CAS client""" - response = client.perform_service_validate(ticket=data["ticket"]) + # Validate the ticket and get CAS attributes. + attributes = self._validate_ticket(data["ticket"]) + + # Fetch a user model from the CAS attributes. + user, created = self._fetch_user_from_cas(attributes) + + # Update the user's last login. + if api_settings.UPDATE_LAST_LOGIN: + update_last_login(self, user) + + # Login and send authentication signals. + if "request" in self.context: + login(self.context["request"], user) + + user_login.send( + sender=self, user=user + ) + + if created: + user_created.send( + sender=self, attributes=attributes, user=user + ) + + # Return access tokens for the now logged-in user. + return { + "access": str( + AccessToken.for_user(user) + ), + "refresh": str( + RefreshToken.for_user(user) + ), + } + + def _validate_ticket(self, ticket: str) -> dict: + """Validate a CAS ticket using the CAS client""" + response = client.perform_service_validate(ticket=ticket) if response.error: raise ValidationError(response.error) - # Validation success: create user if it doesn't exist yet. - attributes = response.data.get("attributes", dict) + return response.data.get("attributes", dict) + def _fetch_user_from_cas(self, attributes: dict) -> Tuple[User, bool]: if attributes.get("lastenrolled"): attributes["lastenrolled"] = int(attributes.get("lastenrolled").split()[0]) @@ -51,22 +86,7 @@ def validate(self, data): if not user.is_valid(): raise ValidationError(user.errors) - user, created = user.get_or_create(user.validated_data) - - # Update the user's last login. - if api_settings.UPDATE_LAST_LOGIN: - update_last_login(self, user) - - user_login.send(sender=self, user=user) - - # Send signal upon creation. - if created: - user_created.send(sender=self, attributes=attributes, user=user) - - return { - "access": str(AccessToken.for_user(user)), - "refresh": str(RefreshToken.for_user(user)), - } + return user.get_or_create(user.validated_data) class UserSerializer(ModelSerializer): @@ -91,6 +111,6 @@ class Meta: model = User fields = "__all__" - def get_or_create(self, validated_data: dict) -> User: + def get_or_create(self, validated_data: dict) -> Tuple[User, bool]: """Create or fetch the user based on the validated data.""" return User.objects.get_or_create(**validated_data) diff --git a/backend/authentication/services/users.py b/backend/authentication/services/users.py deleted file mode 100644 index e69de29b..00000000 diff --git a/backend/authentication/tests/test_authentication_views.py b/backend/authentication/tests/test_authentication_views.py index 960e689d..8e7a4155 100644 --- a/backend/authentication/tests/test_authentication_views.py +++ b/backend/authentication/tests/test_authentication_views.py @@ -88,12 +88,3 @@ def test_login_view_returns_login_url(self): server_url=settings.CAS_ENDPOINT, service_url=settings.CAS_RESPONSE ) self.assertEqual(response["Location"], url) - - -class TestTokenEchoView(APITestCase): - def test_token_echo_echoes_token(self): - """TokenEchoView should echo the User's current token""" - ticket = "This is a ticket." - response = self.client.get(reverse("cas-echo"), data={"ticket": ticket}) - content = response.rendered_content.decode("utf-8").strip('"') - self.assertEqual(content, ticket) diff --git a/backend/authentication/views.py b/backend/authentication/views.py index c5f85e29..f029defd 100644 --- a/backend/authentication/views.py +++ b/backend/authentication/views.py @@ -1,10 +1,13 @@ from django.shortcuts import redirect +from django.contrib.auth import logout from rest_framework.decorators import action from rest_framework.viewsets import ViewSet from rest_framework.request import Request from rest_framework.response import Response +from rest_framework.exceptions import AuthenticationFailed from rest_framework.permissions import AllowAny, IsAuthenticated -from authentication.serializers import UserSerializer +from authentication.permissions import IsDebug +from authentication.serializers import UserSerializer, CASTokenObtainSerializer from authentication.cas.client import client from ypovoli import settings @@ -20,12 +23,16 @@ def login(self, _: Request) -> Response: return redirect(client.get_login_url()) @action(detail=False, methods=['GET']) - def logout(self, _: Request) -> Response: + def logout(self, request: Request) -> Response: """Attempt to log out. Redirect to our single CAS endpoint. Normally would only allow POST requests to a logout endpoint. Since the CAS logout location handles the actual logout, we should accept GET requests. """ - return redirect(client.get_logout_url(service_url=settings.API_ENDPOINT)) + logout(request) + + return redirect( + client.get_logout_url(service_url=settings.API_ENDPOINT) + ) @action(detail=False, methods=['GET'], url_path='whoami', url_name='whoami') def who_am_i(self, request: Request) -> Response: @@ -41,7 +48,14 @@ def who_am_i(self, request: Request) -> Response: user_serializer.data ) - @action(detail=False, methods=['GET'], permission_classes=[AllowAny]) + @action(detail=False, methods=['GET'], permission_classes=[IsDebug]) def echo(self, request: Request) -> Response: """Echo the obtained CAS token for development and testing.""" - return Response(request.query_params.get('ticket')) + token_serializer = CASTokenObtainSerializer(data=request.query_params, context={ + 'request': request + }) + + if token_serializer.is_valid(): + return Response(token_serializer.validated_data) + + raise AuthenticationFailed(token_serializer.errors) diff --git a/backend/ypovoli/settings.py b/backend/ypovoli/settings.py index e5f197c1..32355200 100644 --- a/backend/ypovoli/settings.py +++ b/backend/ypovoli/settings.py @@ -64,7 +64,8 @@ "rest_framework.renderers.JSONRenderer", ], "DEFAULT_AUTHENTICATION_CLASSES": [ - "rest_framework_simplejwt.authentication.JWTAuthentication" + "rest_framework_simplejwt.authentication.JWTAuthentication", + "rest_framework.authentication.SessionAuthentication" ], 'DEFAULT_PERMISSION_CLASSES': [ 'rest_framework.permissions.IsAuthenticated' @@ -79,9 +80,7 @@ } AUTH_USER_MODEL = "authentication.User" - ROOT_URLCONF = "ypovoli.urls" - WSGI_APPLICATION = "ypovoli.wsgi.application" # Application endpoints @@ -110,7 +109,7 @@ TIME_ZONE = "UTC" USE_I18N = True USE_L10N = False -USE_TZ = False +USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/4.0/howto/static-files/ diff --git a/backend/ypovoli/urls.py b/backend/ypovoli/urls.py index 4dc43b32..25e30a72 100644 --- a/backend/ypovoli/urls.py +++ b/backend/ypovoli/urls.py @@ -25,7 +25,7 @@ default_version="v1", ), public=True, - permission_classes=(permissions.AllowAny,), + permission_classes=[permissions.AllowAny,], ) @@ -36,12 +36,6 @@ path("auth/", include("authentication.urls")), path("notifications/", include("notifications.urls")), # Swagger documentation. - path( - "swagger/", - schema_view.with_ui("swagger", cache_timeout=0), - name="schema-swagger-ui", - ), - path( - "swagger/", schema_view.without_ui(cache_timeout=0), name="schema-json" - ), + path("swagger/", schema_view.with_ui("swagger", cache_timeout=0), name="schema-swagger-ui"), + path("swagger/", schema_view.without_ui(cache_timeout=0), name="schema-json"), ]