Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issues/523 saml auth create email address objects #544

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions openwisp_radius/saml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,16 @@ def get_url_or_path(url):
if parsed_url.netloc:
return f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}'
return parsed_url.path


def get_email_from_ava(ava):
email_keys = (
'email',
'mail',
'uid',
)
for key in email_keys:
email = ava.get(key, None)
if email is not None:
return email[0]
return None
35 changes: 34 additions & 1 deletion openwisp_radius/saml/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from urllib.parse import parse_qs, urlparse

import swapper
from allauth.account.models import EmailAddress
from allauth.utils import ValidationError
from django.conf import settings
from django.contrib.auth import logout
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied
Expand All @@ -16,7 +18,7 @@
from .. import settings as app_settings
from ..api.views import RadiusTokenMixin
from ..utils import get_organization_radius_settings, load_model
from .utils import get_url_or_path
from .utils import get_email_from_ava, get_url_or_path

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,6 +68,37 @@ def post_login_hook(self, request, user, session_info):
try:
user.registered_user
except ObjectDoesNotExist:
email = None
uid_is_email = 'email' in getattr(
settings, 'SAML_ATTRIBUTE_MAPPING', {}
).get('uid', ())
if uid_is_email:
email = session_info['name_id'].text
if email is None:
email = get_email_from_ava(session_info['ava'])
if email:
user.email = email
try:
user.full_clean()
user.save()
EmailAddress.objects.create(
user=user, email=email, verified=True, primary=True
)
except ValidationError:
assertion_email = get_email_from_ava(session_info['ava'])
if assertion_email and assertion_email != email:
user.email = assertion_email
try:
user.full_clean()
user.save()
EmailAddress.objects.create(
user=user,
email=assertion_email,
verified=True,
primary=True,
)
except ValidationError:
raise ValidationError('Email Verification Failed')
registered_user = RegisteredUser(
user=user, method='saml', is_verified=app_settings.SAML_IS_VERIFIED
)
Expand Down
26 changes: 24 additions & 2 deletions openwisp_radius/tests/test_saml/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from urllib.parse import parse_qs, urlparse

import swapper
from allauth.account.models import EmailAddress
from django.contrib.auth import SESSION_KEY, get_user_model
from django.core.validators import ValidationError
from django.test import TestCase, override_settings
from django.urls import reverse
from djangosaml2.tests import auth_response, conf
Expand Down Expand Up @@ -58,17 +60,19 @@ class TestAssertionConsumerServiceView(TestSamlMixin, TestCase):
def _get_relay_state(self, redirect_url, org_slug):
return f'{redirect_url}?org={org_slug}'

def _get_saml_response_for_acs_view(self, relay_state):
def _get_saml_response_for_acs_view(self, relay_state, uid='[email protected]'):
response = self.client.get(self.login_url, {'RelayState': relay_state})
saml2_req = saml2_from_httpredirect_request(response.url)
session_id = get_session_id_from_saml2(saml2_req)
self.add_outstanding_query(session_id, relay_state)
return auth_response(session_id, '[email protected]'), relay_state
return auth_response(session_id, uid), relay_state

def _post_successful_auth_assertions(self, query_params, org_slug):
self.assertEqual(User.objects.count(), 1)
user_id = self.client.session[SESSION_KEY]
user = User.objects.get(id=user_id)
email = EmailAddress.objects.filter(user=user)
self.assertEqual(email.count(), 1)
self.assertEqual(user.username, '[email protected]')
self.assertEqual(OrganizationUser.objects.count(), 1)
org_user = OrganizationUser.objects.get(user_id=user_id)
Expand Down Expand Up @@ -100,6 +104,24 @@ def test_organization_slug_present(self):
query_params = parse_qs(urlparse(response.url).query)
self._post_successful_auth_assertions(query_params, org_slug)

@capture_any_output()
def test_invalid_email_raise_validation_error(self):
invalid_email = 'invalid_email@example'
relay_state = self._get_relay_state(
redirect_url='https://captive-portal.example.com', org_slug='default'
)
saml_response, relay_state = self._get_saml_response_for_acs_view(
relay_state, uid=invalid_email
)
with self.assertRaises(ValidationError):
self.client.post(
reverse('radius:saml2_acs'),
{
'SAMLResponse': self.b64_for_post(saml_response),
'RelayState': relay_state,
},
)

@capture_any_output()
def test_relay_state_relative_path(self):
expected_redirect_path = '/captive/portal/page'
Expand Down