Skip to content

Commit

Permalink
fix: improved user creation process
Browse files Browse the repository at this point in the history
  • Loading branch information
EwoutV committed Mar 29, 2024
1 parent bea24c0 commit 39321ab
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 20 deletions.
2 changes: 1 addition & 1 deletion backend/api/models/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def clone(self, clone_assistants=True) -> Self:
)

if clone_assistants:
# Add all the assistants of the current course to the follow up course
# Add all the assistants of the current course to the follow-up course
for assistant in self.assistants.all():
course.assistants.add(assistant)

Expand Down
49 changes: 30 additions & 19 deletions backend/authentication/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class CASTokenObtainSerializer(Serializer):
This serializer takes the CAS ticket and tries to validate it.
Upon successful validation, create a new user if it doesn't exist.
"""

ticket = CharField(required=True, min_length=49, max_length=49)

def validate(self, data):
Expand Down Expand Up @@ -64,24 +63,41 @@ def _validate_ticket(self, ticket: str) -> dict:
return response.data.get("attributes", dict)

def _fetch_user_from_cas(self, attributes: dict) -> Tuple[User, bool]:
# Convert the lastenrolled attribute
if attributes.get("lastenrolled"):
attributes["lastenrolled"] = int(attributes.get("lastenrolled").split()[0])

user = UserSerializer(
data={
"id": attributes.get("ugentID"),
"username": attributes.get("uid"),
"email": attributes.get("mail"),
"first_name": attributes.get("givenname"),
"last_name": attributes.get("surname"),
"last_enrolled": attributes.get("lastenrolled"),
}
)
# Map the CAS data onto the user data
data = {
"id": attributes.get("ugentID"),
"username": attributes.get("uid"),
"email": attributes.get("mail"),
"first_name": attributes.get("givenname"),
"last_name": attributes.get("surname"),
"last_enrolled": attributes.get("lastenrolled"),
}

try:
# Fetch the user if it already exists
user = UserSerializer(User.objects.get(id=data["id"]), data=data)

# Validate the serializer
if not user.is_valid():
raise ValidationError(user.errors)

if not user.is_valid():
raise ValidationError(user.errors)
# Save the new user
return user.save(), False
except User.DoesNotExist:
# Create a new user
user = UserSerializer(data=data)

# Validate the serializer
if not user.is_valid():
raise ValidationError(user.errors)

# Save the new user
return user.save(), True

return user.get_or_create(user.validated_data)


class UserSerializer(ModelSerializer):
Expand All @@ -103,14 +119,9 @@ def get_roles(self, user: User):
"""Get the roles for the user"""
return user.roles

def get_or_create(self, validated_data: dict) -> Tuple[User, bool]:
"""Create or fetch the user based on the validated data."""
return User.objects.get_or_create(**validated_data)

class Meta:
model = User
fields = "__all__"
read_only_fields = ["id", "username", "email"]


class UserIDSerializer(Serializer):
Expand Down

0 comments on commit 39321ab

Please sign in to comment.