Skip to content

Commit

Permalink
Add tests for just OAuthCallbackHandler without actually starting the…
Browse files Browse the repository at this point in the history
… server
  • Loading branch information
ashovlin committed Oct 16, 2024
1 parent 35876bb commit 6cbf403
Showing 1 changed file with 55 additions and 2 deletions.
57 changes: 55 additions & 2 deletions tests/unit/customizations/sso/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from botocore.exceptions import PendingAuthorizationExpiredError
from botocore.session import Session

from awscli.compat import StringIO
from awscli.compat import BytesIO, StringIO
from awscli.customizations.sso.utils import OpenBrowserHandler
from awscli.customizations.sso.utils import PrintOnlyHandler
from awscli.customizations.sso.utils import do_sso_login
from awscli.customizations.sso.utils import open_browser_with_original_ld_path
from awscli.customizations.sso.utils import (
parse_sso_registration_scopes, AuthCodeFetcher
parse_sso_registration_scopes, AuthCodeFetcher, OAuthCallbackHandler
)
from awscli.testutils import mock
from awscli.testutils import unittest
Expand Down Expand Up @@ -209,6 +209,59 @@ def test_can_patch_env(self):
self.assertIsNone(captured_env.get('LD_LIBRARY_PATH'))


class MockRequest(object):
def __init__(self, request):
self._request = request

def makefile(self, *args, **kwargs):
return BytesIO(self._request)

def sendall(self, data):
pass


class TestOAuthCallbackHandler:
"""Tests for OAuthCallbackHandler, which handles
individual requests that we receive at the callback uri
"""
def test_expected_query_params(self):
fetcher = mock.Mock(AuthCodeFetcher)

OAuthCallbackHandler(
fetcher,
MockRequest(b'GET /?state=123&code=456'),
mock.MagicMock(),
mock.MagicMock(),
)
fetcher.set_auth_code_and_state.assert_called_once_with('456', '123')

def test_error(self):
fetcher = mock.Mock(AuthCodeFetcher)

OAuthCallbackHandler(
fetcher,
MockRequest(b'GET /?error=Error%20message'),
mock.MagicMock(),
mock.MagicMock(),
)

fetcher.set_auth_code_and_state.assert_called_once_with(None, None)

def test_missing_expected_query_params(self):
fetcher = mock.Mock(AuthCodeFetcher)

# We generally don't expect to be missing the expected query params,
# but if we do we expect the server to keep waiting for a valid callback
OAuthCallbackHandler(
fetcher,
MockRequest(b'GET /'),
mock.MagicMock(),
mock.MagicMock(),
)

fetcher.set_auth_code_and_state.assert_not_called()


class TestAuthCodeFetcher:
"""Tests for the AuthCodeFetcher class, which is the local
web server we use to handle the OAuth 2.0 callback
Expand Down

0 comments on commit 6cbf403

Please sign in to comment.