From a2c2ba48e5e78e96fadf91a11fc4aaa8cd95914b Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Thu, 26 Sep 2024 10:08:10 -0400 Subject: [PATCH 01/11] Add the OAuth Authorization Code Flow with PKCE This adds support for the OAuth2.0 authorization code flow with PKCE to the aws sso login command. It is the new default behavior, but users can fall back to the device code flow using the new --use-device-code option. --- .changes/next-release/feature-sso-81096.json | 5 + awscli/botocore/exceptions.py | 12 + awscli/botocore/utils.py | 333 ++++++++++++++++--- awscli/customizations/sso/index.html | 176 ++++++++++ awscli/customizations/sso/login.py | 1 + awscli/customizations/sso/utils.py | 139 +++++++- tests/functional/sso/__init__.py | 27 +- tests/functional/sso/test_login.py | 290 ++++++++++++++-- tests/unit/customizations/sso/test_utils.py | 57 +++- 9 files changed, 952 insertions(+), 88 deletions(-) create mode 100644 .changes/next-release/feature-sso-81096.json create mode 100644 awscli/customizations/sso/index.html diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json new file mode 100644 index 000000000000..e4419d6bbd19 --- /dev/null +++ b/.changes/next-release/feature-sso-81096.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "sso", + "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for aws sso login." +} diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py index 115d2442a572..b92fa6bac0b2 100644 --- a/awscli/botocore/exceptions.py +++ b/awscli/botocore/exceptions.py @@ -682,6 +682,10 @@ class SSOTokenLoadError(SSOError): fmt = "Error loading SSO Token: {error_msg}" +class AuthorizationCodeLoadError(SSOError): + fmt = "Error loading authorization code: {error_msg}" + + class UnauthorizedSSOTokenError(SSOError): fmt = ( "The SSO session associated with this profile has expired or is " @@ -690,6 +694,14 @@ class UnauthorizedSSOTokenError(SSOError): ) +class AuthCodeFetcherError(SSOError): + fmt = ( + "Unable to initialize the OAuth 2.0 authorization callback handler: " + "{error_msg} \n You may use --use-device-code to fall back to the " + "device code flow which does not require the callback handler." + ) + + class CapacityNotAvailableError(BotoCoreError): fmt = ( 'Insufficient request capacity available.' diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 5a7614211a4d..cf7c90630a37 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -23,7 +23,10 @@ import random import re import socket +import string import time +import urllib +import uuid import warnings import weakref from datetime import datetime as _DatetimeClass @@ -48,6 +51,7 @@ zip_longest, ) from botocore.exceptions import ( + AuthorizationCodeLoadError, ClientError, ConfigNotFound, ConnectionClosedError, @@ -3057,20 +3061,16 @@ def __call__(self): return token_file.read() -class SSOTokenFetcher(object): - # The device flow RFC defines the slow down delay to be an additional - # 5 seconds: - # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 - _SLOW_DOWN_DELAY = 5 - # The default interval of 5 is also defined in the RFC (see above link) - _DEFAULT_INTERVAL = 5 +class BaseSSOTokenFetcher(object): + """Base class for SSO token fetchers, for functionality + shared between the device and authorization code grant flows. + """ _EXPIRY_WINDOW = 15 * 60 _CLIENT_REGISTRATION_TYPE = 'public' - _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' def __init__( self, sso_region, client_creator, cache=None, - on_pending_authorization=None, time_fetcher=None, sleep=None, + on_pending_authorization=None, time_fetcher=None ): self._sso_region = sso_region self._client_creator = client_creator @@ -3080,10 +3080,6 @@ def __init__( time_fetcher = self._utc_now self._time_fetcher = time_fetcher - if sleep is None: - sleep = time.sleep - self._sleep = sleep - if cache is None: cache = {} self._cache = cache @@ -3101,6 +3097,15 @@ def _is_expired(self, response): seconds = total_seconds(end_time - self._time_fetcher()) return seconds < self._EXPIRY_WINDOW + def _is_registration_for_auth_code(self, registration): + if ('grantTypes' in registration and + 'authorization_code' in registration['grantTypes']): + return True + + # Else assume that it's device flow, + # since the CLI didn't cache grantTypes previously + return False + @CachedProperty def _client(self): config = botocore.config.Config( @@ -3109,13 +3114,61 @@ def _client(self): ) return self._client_creator('sso-oidc', config=config) - def _register_client(self, session_name, scopes): + def _generate_client_name(self, session_name): if session_name is None: # Use a timestamp for the session name for legacy configuration timestamp = datetime2timestamp(self._time_fetcher()) session_name = int(timestamp) + return f'botocore-client-{str(session_name)}' + + def _registration_cache_key(self, start_url, session_name, scopes): + # Registration is unique based on the following properties to ensure + # modifications to the registration do not affect the permissions of + # tokens derived for other start URLs. + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self._sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + + def _token_cache_key(self, start_url, session_name): + input_str = start_url + if session_name is not None: + input_str = session_name + return hashlib.sha1(input_str.encode('utf-8')).hexdigest() + + +class SSOTokenFetcher(BaseSSOTokenFetcher): + """Performs the device grant OAuth2.0 flow""" + # The device flow RFC defines the slow-down delay to be an additional + # 5 seconds: + # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-15#section-3.5 + _SLOW_DOWN_DELAY = 5 + # The default interval of 5 is also defined in the RFC (see above link) + _DEFAULT_INTERVAL = 5 + _EXPIRY_WINDOW = 15 * 60 + _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + + def __init__( + self, sso_region, client_creator, cache=None, + on_pending_authorization=None, time_fetcher=None, sleep=None + ): + super().__init__( + sso_region, client_creator, cache, on_pending_authorization, + time_fetcher + ) + + if sleep is None: + sleep = time.sleep + self._sleep = sleep + + def _register_client(self, session_name, scopes): register_kwargs = { - 'clientName': f'botocore-client-{session_name}', + 'clientName': self._generate_client_name(session_name), 'clientType': self._CLIENT_REGISTRATION_TYPE, } if scopes: @@ -3132,20 +3185,6 @@ def _register_client(self, session_name, scopes): registration['scopes'] = scopes return registration - def _registration_cache_key(self, start_url, session_name, scopes): - # Registration is unique based on the following properties to ensure - # modifications to the registration do not affect the permissions of - # tokens derived for other start URLs. - args = { - 'tool': 'botocore', - 'startUrl': start_url, - 'region': self._sso_region, - 'scopes': scopes, - 'session_name': session_name, - } - cache_args = json.dumps(args, sort_keys=True).encode('utf-8') - return hashlib.sha1(cache_args).hexdigest() - def _registration( self, start_url, @@ -3160,7 +3199,8 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if not self._is_expired(registration): + if (not self._is_expired(registration) and + not self._is_registration_for_auth_code(registration)): return registration registration = self._register_client( @@ -3172,7 +3212,7 @@ def _registration( def _authorize_client(self, start_url, registration): # NOTE: The authorization response is not cached. These responses are - # short lived (currently only 10 minutes) and can only be exchanged for + # short-lived (currently only 10 minutes) and can only be exchanged for # a token once. Having multiple clients share this is problematic. response = self._client.start_device_authorization( clientId=registration['clientId'], @@ -3261,12 +3301,6 @@ def _create_token_attempt( raise PendingAuthorizationExpiredError() return interval, None - def _token_cache_key(self, start_url, session_name): - input_str = start_url - if session_name is not None: - input_str = session_name - return hashlib.sha1(input_str.encode('utf-8')).hexdigest() - def _token( self, start_url, @@ -3305,6 +3339,231 @@ def fetch_token( ) +class SSOTokenFetcherAuth(BaseSSOTokenFetcher): + """Performs the authorization code grant with PKCE OAuth2.0 flow""" + _AUTH_GRANT_TYPES = ['authorization_code', 'refresh_token'] + _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' + + def __init__( + self, sso_region, client_creator, auth_code_fetcher, cache=None, + on_pending_authorization=None, time_fetcher=None + ): + super().__init__(sso_region, client_creator, cache, + on_pending_authorization, time_fetcher + ) + + self._auth_code_fetcher = auth_code_fetcher + + # Generate the PKCE pair + self.code_verifier = ''.join( + random.SystemRandom().choice( + string.ascii_letters + string.digits + '-._~' + ) + for _ in range(64) + ) + self.code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(self.code_verifier.encode()).digest()).decode() + + def _register_client(self, session_name, scopes, redirect_uri, issuer_url): + register_kwargs = { + 'clientName': self._generate_client_name(session_name), + 'clientType': self._CLIENT_REGISTRATION_TYPE, + 'grantTypes': self._AUTH_GRANT_TYPES, + 'redirectUris': [redirect_uri], + 'issuerUrl': issuer_url + } + + if scopes: + register_kwargs['scopes'] = scopes + else: + register_kwargs['scopes'] = [self._AUTH_GRANT_DEFAULT_SCOPE] + + response = self._client.register_client(**register_kwargs) + + expires_at = response['clientSecretExpiresAt'] + expires_at = datetime.datetime.fromtimestamp(expires_at, tzutc()) + registration = { + 'clientId': response['clientId'], + 'clientSecret': response['clientSecret'], + 'expiresAt': expires_at, + 'scopes': register_kwargs['scopes'], + 'grantTypes': register_kwargs['grantTypes'] + } + + return registration + + def _registration( + self, + start_url, + session_name, + scopes, + force_refresh=False, + ): + cache_key = self._registration_cache_key( + start_url, + session_name, + scopes, + ) + if not force_refresh and cache_key in self._cache: + registration = self._cache[cache_key] + if (not self._is_expired(registration) and + self._is_registration_for_auth_code(registration)): + return registration + + registration = self._register_client( + session_name, + scopes, + self._auth_code_fetcher.redirect_uri_without_port(), + start_url + ) + self._cache[cache_key] = registration + return registration + + def _extract_resolved_endpoint(self, params, **kwargs): + """Event handler for before-call that will extract the resolved endpoint + for a given request without actually running it + """ + # This will contain any path and query params specific to + # the operation/input, so extract just the scheme and hostname + if params['url']: + parsed = urlparse(params['url']) + self._base_endpoint = f'{parsed.scheme}://{parsed.netloc}' + + # Return a tuple containing the "response" to short-circuit the request + return botocore.awsrequest.AWSResponse(None, 200, {}, None), {} + + def _get_base_authorization_uri(self): + """Simulates an SSO-OIDC request so that we can extract the "base" + endpoint for the current client to use for the un-modeled Authorize + operation + """ + self._client.meta.events.register('before-call', + self._extract_resolved_endpoint) + self._client.register_client( + clientName='temp', + clientType='public' + ) + self._client.meta.events.unregister('before-call', + self._extract_resolved_endpoint) + + return self._base_endpoint + + def _get_authorization_uri( + self, + client_id, + registration_scopes, + expected_state): + + query_params = { + 'response_type': 'code', + 'client_id': client_id, + 'redirect_uri': self._auth_code_fetcher.redirect_uri_with_port(), + 'state': expected_state, + 'code_challenge_method': 'S256' + # Don't want to encode code_challenge again, so we append below + } + + # For the query param, scopes must be space separated before encoding + if registration_scopes: + query_params['scope'] = " ".join(registration_scopes) + + return ( + f'{self._get_base_authorization_uri()}/authorize?' + f'{urllib.parse.urlencode(query_params)}' + f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' + ) + + def _get_new_token(self, start_url, session_name, registration_scopes): + registration = self._registration( + start_url, + session_name, + registration_scopes, + ) + + expected_state = uuid.uuid4() + + authorization_uri = self._get_authorization_uri( + registration['clientId'], + registration_scopes, + expected_state) + + # Even though there's just one uri, this matches the inputs + # for the device code flow so that we can reuse the browser handlers + authorization_args = { + 'verificationUri': authorization_uri, + 'verificationUriComplete': authorization_uri, + 'userCode': None + } + + # Open/display the link, then block until the redirect uri is hit and + # the auth code is retrieved + self._on_pending_authorization(**authorization_args) + auth_code, state = self._auth_code_fetcher.get_auth_code_and_state() + + if auth_code is None: + raise AuthorizationCodeLoadError( + error_msg='Failed to retrieve an authorization code.' + ) + + if state != expected_state: + raise AuthorizationCodeLoadError( + error_msg='State parameter does not match expected value.' + ) + + return self._create_token_(start_url, registration, auth_code) + + def _create_token_(self, start_url, registration, auth_code): + try: + response = self._client.create_token( + grantType='authorization_code', + clientId=registration['clientId'], + clientSecret=registration['clientSecret'], + redirectUri=self._auth_code_fetcher.redirect_uri_with_port(), + codeVerifier=self.code_verifier, + code=auth_code + ) + expires_in = datetime.timedelta(seconds=response['expiresIn']) + token = { + 'startUrl': start_url, + 'region': self._sso_region, + 'accessToken': response['accessToken'], + 'expiresAt': self._time_fetcher() + expires_in, + # Cache the registration alongside the token + 'clientId': registration['clientId'], + 'clientSecret': registration['clientSecret'], + 'registrationExpiresAt': registration['expiresAt'], + } + if 'refreshToken' in response: + token['refreshToken'] = response['refreshToken'] + return token + except self._client.exceptions.InvalidGrantException as error: + print(error) + except self._client.exceptions.ExpiredTokenException: + raise PendingAuthorizationExpiredError() + + def fetch_token( + self, + start_url, + force_refresh=False, + registration_scopes=None, + session_name=None, + ): + cache_key = self._token_cache_key(start_url, session_name) + # Only obey the token cache if we are not forcing a refresh. + if not force_refresh and cache_key in self._cache: + token = self._cache[cache_key] + if not self._is_expired(token): + return token + + token = self._get_new_token( + start_url, + session_name, + registration_scopes + ) + self._cache[cache_key] = token + return token + + class SSOTokenLoader(object): def __init__(self, cache=None): if cache is None: diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html new file mode 100644 index 000000000000..f1572f4b04f9 --- /dev/null +++ b/awscli/customizations/sso/index.html @@ -0,0 +1,176 @@ + + + + + AWS Authentication + + + + + +
+
+ + + + + +
+
+ +
+
+ +
+

Request approved

+

+
+
+

+
+ + + +
+
+ + + diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index 4d7790b62c76..acd58e287dd4 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -52,6 +52,7 @@ def _run_main(self, parsed_args, parsed_globals): force_refresh=True, session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), + fallback_to_device_flow=parsed_args.use_device_code ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index ae9a83e9e8b8..d7d813bf11cc 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -11,19 +11,22 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import datetime -import os -import logging import json +import logging +import os import webbrowser +from functools import partial +from http.server import HTTPServer, BaseHTTPRequestHandler +from urllib.parse import urlparse, parse_qs -from botocore.utils import SSOTokenFetcher -from botocore.utils import original_ld_library_path from botocore.credentials import JSONFileCache +from botocore.exceptions import AuthCodeFetcherError +from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth +from botocore.utils import original_ld_library_path from awscli.customizations.commands import BasicCommand -from awscli.customizations.utils import uni_print -from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR from awscli.customizations.exceptions import ConfigurationError +from awscli.customizations.utils import uni_print LOG = logging.getLogger(__name__) @@ -37,9 +40,18 @@ 'action': 'store_true', 'default': False, 'help_text': ( - 'Disables automatically opening the verfication URL in the ' + 'Disables automatically opening the verification URL in the ' 'default browser.' ) + }, + { + 'name': 'use-device-code', + 'action': 'store_true', + 'default': False, + 'help_text': ( + 'Uses the Device Code authorization grant and login flow ' + 'instead of the Authorization Code flow.' + ) } ] @@ -56,24 +68,38 @@ def _sso_json_dumps(obj): def do_sso_login(session, sso_region, start_url, token_cache=None, on_pending_authorization=None, force_refresh=False, - registration_scopes=None, session_name=None): + registration_scopes=None, session_name=None, + fallback_to_device_flow=False): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) if on_pending_authorization is None: on_pending_authorization = OpenBrowserHandler( open_browser=open_browser_with_original_ld_path ) - token_fetcher = SSOTokenFetcher( - sso_region=sso_region, - client_creator=session.create_client, - cache=token_cache, - on_pending_authorization=on_pending_authorization - ) + + # For the auth flow, we need a non-legacy sso-session and check that the + # user hasn't opted into falling back to the device code flow + if session_name and not fallback_to_device_flow: + token_fetcher = SSOTokenFetcherAuth( + sso_region=sso_region, + client_creator=session.create_client, + auth_code_fetcher=AuthCodeFetcher(), + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + else: + token_fetcher = SSOTokenFetcher( + sso_region=sso_region, + client_creator=session.create_client, + cache=token_cache, + on_pending_authorization=on_pending_authorization + ) + return token_fetcher.fetch_token( start_url=start_url, session_name=session_name, force_refresh=force_refresh, - registration_scopes=registration_scopes, + registration_scopes=registration_scopes ) @@ -110,13 +136,19 @@ def __call__( f'Browser will not be automatically opened.\n' f'Please visit the following URL:\n' f'\n{verificationUri}\n' + + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' f'\nAlternatively, you may visit the following URL which will ' f'autofill the code upon loading:' - f'\n{verificationUriComplete}\n' - ) + f'\n{verificationUriComplete}\n') + uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) class OpenBrowserHandler(BaseAuthorizationhandler): @@ -135,10 +167,16 @@ def __call__( f'to use a different device to authorize this request, open the ' f'following URL:\n' f'\n{verificationUri}\n' + ) + + user_code_msg = ( f'\nThen enter the code:\n' f'\n{userCode}\n' ) uni_print(opening_msg, self._outfile) + if userCode: + uni_print(user_code_msg, self._outfile) + if self._open_browser: try: return self._open_browser(verificationUriComplete) @@ -146,6 +184,73 @@ def __call__( LOG.debug('Failed to open browser:', exc_info=True) +class AuthCodeFetcher: + """Manages the local web server that will be used + to retrieve the authorization code from the OAuth callback + """ + def __init__(self): + self._auth_code = None + self._state = None + self._is_done = False + + # We do this so that the request handler can have a reference to this + # AuthCodeFetcher so that it can pass back the state and auth code + try: + handler = partial(OAuthCallbackHandler, self) + self.http_server = HTTPServer(('', 0), handler) + except Exception as e: + raise AuthCodeFetcherError(error_msg=e) + + def redirect_uri_without_port(self): + return 'http://127.0.0.1/oauth/callback' + + def redirect_uri_with_port(self): + return f'http://127.0.0.1:{self.http_server.server_port}/oauth/callback' + + def get_auth_code_and_state(self): + """Blocks until the expected redirect request with either the + authorization code/state or and error is handled + """ + while not self._is_done: + self.http_server.handle_request() + self.http_server.server_close() + + return self._auth_code, self._state + + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb') as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher._is_done = True + return + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher._is_done = True + self._auth_code_fetcher._auth_code = query_params['code'][0] + self._auth_code_fetcher._state = query_params['state'][0] + + class InvalidSSOConfigError(ConfigurationError): pass diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index 98036f9173b8..ace3e5e63b03 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -13,7 +13,7 @@ import time from awscli.clidriver import AWSCLIEntryPoint -from awscli.customizations.sso.utils import OpenBrowserHandler +from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher from awscli.testutils import create_clidriver from awscli.testutils import FileCreator from awscli.testutils import BaseAWSCommandParamsTest @@ -45,6 +45,29 @@ def setUp(self): self.open_browser_mock, ) self.open_browser_patch.start() + + self.fetcher_mock = mock.Mock(spec=AuthCodeFetcher) + self.fetcher_mock.return_value.redirect_uri_without_port.return_value = ( + 'http://127.0.0.1/oauth/callback' + ) + self.fetcher_mock.return_value.redirect_uri_with_port.return_value = ( + 'http://127.0.0.1:55555/oauth/callback' + ) + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", "00000000-0000-0000-0000-000000000000" + ) + self.auth_code_fetcher_patch = mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher', + self.fetcher_mock, + ) + self.auth_code_fetcher_patch.start() + + self.uuid_mock = mock.Mock( + return_value="00000000-0000-0000-0000-000000000000" + ) + self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock) + self.uuid_patch.start() + self.expires_in = 28800 self.expiration_time = time.time() + 1000 @@ -52,6 +75,8 @@ def tearDown(self): super(BaseSSOTest, self).tearDown() self.files.remove_all() self.open_browser_patch.stop() + self.auth_code_fetcher_patch.stop() + self.uuid_patch.stop() self.token_cache_dir_patch.stop() def assert_used_expected_sso_region(self, expected_region): diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index e24c47591716..f8684908db21 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -23,8 +23,8 @@ class TestLoginCommand(BaseSSOTest): r'\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ' ) - def add_oidc_workflow_responses(self, access_token, - include_register_response=True): + def add_oidc_device_responses(self, access_token, + include_register_response=True): responses = [ # StartDeviceAuthorization response { @@ -53,12 +53,50 @@ def add_oidc_workflow_responses(self, access_token, 0, { 'clientSecretExpiresAt': self.expiration_time, - 'clientId': 'foo-client-id', - 'clientSecret': 'foo-client-secret', + 'clientId': 'device-client-id', + 'clientSecret': 'device-client-secret', } ) self.parsed_responses = responses + def add_oidc_auth_code_responses(self, access_token, + include_register_response=True): + responses = [ + # CreateToken responses + { + 'expiresIn': self.expires_in, + 'tokenType': 'Bearer', + 'accessToken': access_token, + } + ] + if include_register_response: + responses.insert( + 0, + { + 'clientSecretExpiresAt': self.expiration_time, + 'clientId': 'auth-client-id', + 'clientSecret': 'auth-client-secret', + } + ) + self.parsed_responses = responses + + def assert_cache_contains_registration( + self, + start_url, + session_name, + scopes, + expected_client_id): + cached_files = os.listdir(self.token_cache_dir) + + cached_registration_filename = self._get_cached_registration_filename( + start_url, session_name, scopes) + + self.assertIn(cached_registration_filename, cached_files) + self.assertEqual( + self._get_cached_response(cached_registration_filename)['clientId'], + expected_client_id + ) + def assert_cache_contains_token( self, start_url, @@ -78,6 +116,17 @@ def assert_cache_contains_token( expected_token ) + def _get_cached_registration_filename(self, start_url, session_name, scopes): + args = { + 'tool': 'botocore', + 'startUrl': start_url, + 'region': self.sso_region, + 'scopes': scopes, + 'session_name': session_name, + } + cache_args = json.dumps(args, sort_keys=True).encode('utf-8') + return hashlib.sha1(cache_args).hexdigest() + '.json' + def _get_cached_token_filename(self, start_url, session_name): to_hash = start_url if session_name: @@ -101,8 +150,19 @@ def assert_cache_token_expiration_time_format_is_correct(self): ) ) - def test_login(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_explicit_device(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token + ) + + def test_login_implicit_device(self): + # This is a legacy profile via setUp, so we expect + # it to fall back to device flow automatically + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -110,9 +170,9 @@ def test_login(self): expected_token=self.access_token ) - def test_login_no_browser(self): - self.add_oidc_workflow_responses(self.access_token) - stdout, _, _ = self.run_cmd('sso login --no-browser') + def test_login_device_no_browser(self): + self.add_oidc_device_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --use-device-code --no-browser') self.assertIn('Browser will not be automatically opened.', stdout) self.open_browser_mock.assert_not_called() self.assert_used_expected_sso_region(expected_region=self.sso_region) @@ -121,20 +181,86 @@ def test_login_no_browser(self): expected_token=self.access_token ) - def test_login_forces_refresh(self): - self.add_oidc_workflow_responses(self.access_token) + def test_login_auth_no_browser(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + stdout, _, _ = self.run_cmd('sso login --no-browser') + self.assertIn('Browser will not be automatically opened.', stdout) + self.open_browser_mock.assert_not_called() + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token=self.access_token, + session_name='test-session' + ) + + def test_login_device_forces_refresh(self): + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_device_responses( + 'new.token', include_register_response=False) + self.run_cmd('sso login --use-device-code') + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + ) + + def test_login_auth_forces_refresh(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') # The register response from the first login should have been # cached. - self.add_oidc_workflow_responses( + self.add_oidc_auth_code_responses( 'new.token', include_register_response=False) self.run_cmd('sso login') self.assert_cache_contains_token( start_url=self.start_url, - expected_token='new.token' + expected_token='new.token', + session_name='test-session' + ) + + def test_login_auth_after_device_forces_refresh(self): + self.set_config_file_content( + content=self.get_sso_session_config('test-session')) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + # The register response from the first login should have been + # cached. + self.add_oidc_auth_code_responses('new.token') + self.run_cmd('sso login') + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + expected_token='new.token', + session_name='test-session' + ) + + def test_login_device_no_sso_configuration(self): + self.set_config_file_content(content='') + _, stderr, _ = self.run_cmd('sso login --use-device-code', + expected_rc=253) + self.assertIn( + 'Missing the following required SSO configuration', + stderr ) - def test_login_no_sso_configuration(self): + def test_login_auth_no_sso_configuration(self): self.set_config_file_content(content='') _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn( @@ -142,14 +268,14 @@ def test_login_no_sso_configuration(self): stderr ) - def test_login_minimal_sso_configuration(self): + def test_login_device_minimal_sso_configuration(self): content = ( '[default]\n' 'sso_start_url={start_url}\n' 'sso_region={sso_region}\n' ).format(start_url=self.start_url, sso_region=self.sso_region) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( @@ -157,13 +283,15 @@ def test_login_minimal_sso_configuration(self): expected_token=self.access_token ) - def test_login_partially_missing_sso_configuration(self): + def test_login_device_partially_missing_sso_configuration(self): content = ( '[default]\n' 'sso_start_url=%s\n' % self.start_url ) self.set_config_file_content(content=content) - _, stderr, _ = self.run_cmd('sso login', expected_rc=253) + _, stderr, _ = self.run_cmd( + 'sso login --use-device-code', expected_rc=253 + ) self.assertIn( 'Missing the following required SSO configuration', stderr @@ -174,8 +302,8 @@ def test_login_partially_missing_sso_configuration(self): self.assertNotIn('sso_role_name', stderr) def test_token_cache_datetime_format(self): - self.add_oidc_workflow_responses(self.access_token) - self.run_cmd('sso login') + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) self.assert_cache_contains_token( start_url=self.start_url, @@ -183,38 +311,115 @@ def test_token_cache_datetime_format(self): ) self.assert_cache_token_expiration_time_format_is_correct() - def test_login_sso_session(self): + def test_login_device_sso_session(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_auth_sso_session(self): content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_with_explicit_sso_session_arg(self): + content = self.get_sso_session_config( + 'test-session', include_profile=False) + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --sso-session test-session --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) - def test_login_sso_with_explicit_sso_session_arg(self): + def test_login_auth_sso_with_explicit_sso_session_arg(self): content = self.get_sso_session_config( 'test-session', include_profile=False) self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login --sso-session test-session') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) + self.assert_cache_contains_token( + start_url=self.start_url, + session_name='test-session', + expected_token=self.access_token, + ) + + def test_login_device_sso_session_with_scopes(self): + self.registration_scopes = ['sso:foo', 'sso:bar'] + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_device_responses(self.access_token) + self.run_cmd('sso login --use-device-code') + self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='device-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', expected_token=self.access_token, ) + operation, params = self.operations_called[0] + self.assertEqual(operation.name, 'RegisterClient') + self.assertEqual(params.get('scopes'), self.registration_scopes) - def test_login_sso_session_with_scopes(self): + def test_login_auth_sso_session_with_scopes(self): self.registration_scopes = ['sso:foo', 'sso:bar'] content = self.get_sso_session_config('test-session') self.set_config_file_content(content=content) - self.add_oidc_workflow_responses(self.access_token) + self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_cache_contains_registration( + start_url=self.start_url, + session_name='test-session', + scopes=self.registration_scopes, + expected_client_id='auth-client-id' + ) self.assert_cache_contains_token( start_url=self.start_url, session_name='test-session', @@ -245,3 +450,36 @@ def test_login_sso_session_missing(self): self.set_config_file_content(content=content) _, stderr, _ = self.run_cmd('sso login', expected_rc=253) self.assertIn('sso-session does not exist: "test"', stderr) + + def test_login_auth_sso_no_authorization_code_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + None, None + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'Failed to retrieve an authorization code.', + stderr + ) + + def test_login_auth_sso_state_mismatch_throws_error(self): + self.fetcher_mock.return_value.get_auth_code_and_state.return_value = ( + "abc", '00000000-0000-0000-0000-000000000001' + ) + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + + _, stderr, _ = self.run_cmd( + 'sso login', expected_rc=255 + ) + self.assertIn( + 'State parameter does not match expected value.', + stderr + ) + diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 47f5762b3d8c..7a24acee3baf 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -11,22 +11,23 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import os +import threading import webbrowser import pytest - -from awscli.testutils import mock -from awscli.testutils import unittest - +import urllib3 from botocore.session import Session -from botocore.exceptions import ClientError from awscli.compat import StringIO -from awscli.customizations.sso.utils import parse_sso_registration_scopes -from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler +from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path +from awscli.customizations.sso.utils import ( + parse_sso_registration_scopes, AuthCodeFetcher +) +from awscli.testutils import mock +from awscli.testutils import unittest @pytest.mark.parametrize( @@ -205,3 +206,45 @@ def test_can_patch_env(self): os.environ) open_browser_with_original_ld_path('http://example.com') self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) + + +class TestAuthCodeFetcher(unittest.TestCase): + """Tests for the AuthCodeFetcher class, which is the local + web server we use to handle the OAuth 2.0 callback + """ + + def setUp(self): + self.fetcher = AuthCodeFetcher() + self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/' + + # Start the server on a background thread so that + # the test thread can make the request + self.server_thread = threading.Thread( + target=self.fetcher.get_auth_code_and_state + ) + self.server_thread.setDaemon(True) + self.server_thread.start() + + def test_expected_auth_code(self): + expected_code = '1234' + expected_state = '4567' + url = self.url + f'?code={expected_code}&state={expected_state}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, expected_code) + self.assertEqual(self.fetcher._state, expected_state) + + def test_error(self): + expected_code = 'Failed' + url = self.url + f'?error={expected_code}' + + http = urllib3.PoolManager() + response = http.request("GET", url) + + self.assertEqual(response.status, 200) + self.assertEqual(self.fetcher._auth_code, None) + self.assertEqual(self.fetcher._state, None) + From 6a63b68dff82f4938796f17b21d708bbefe5ff3f Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Thu, 26 Sep 2024 10:39:37 -0400 Subject: [PATCH 02/11] Add back import that PyCharm thought was unused --- awscli/customizations/sso/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index d7d813bf11cc..1f999d241022 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -24,6 +24,7 @@ from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth from botocore.utils import original_ld_library_path +from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR from awscli.customizations.commands import BasicCommand from awscli.customizations.exceptions import ConfigurationError from awscli.customizations.utils import uni_print From 9f93faf203005fa9e710d47d8e31b49bf7d7bc6d Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Thu, 26 Sep 2024 11:04:43 -0400 Subject: [PATCH 03/11] Fix setDaemon warning on 3.10+ --- tests/unit/customizations/sso/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 7a24acee3baf..44ec2a7628ac 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -222,7 +222,7 @@ def setUp(self): self.server_thread = threading.Thread( target=self.fetcher.get_auth_code_and_state ) - self.server_thread.setDaemon(True) + self.server_thread.daemon = True self.server_thread.start() def test_expected_auth_code(self): From 494d0f72778d05ac8b51c59d8303c359dda8ef07 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Fri, 4 Oct 2024 10:28:20 -0400 Subject: [PATCH 04/11] Apply style/formatting suggestions from code review Co-authored-by: Nate Prewitt --- awscli/botocore/utils.py | 92 +++++++++++++++++++----------- awscli/customizations/sso/utils.py | 8 ++- 2 files changed, 65 insertions(+), 35 deletions(-) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index cf7c90630a37..ca5b2598ab74 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -3098,8 +3098,10 @@ def _is_expired(self, response): return seconds < self._EXPIRY_WINDOW def _is_registration_for_auth_code(self, registration): - if ('grantTypes' in registration and - 'authorization_code' in registration['grantTypes']): + if ( + 'grantTypes' in registration + and 'authorization_code' in registration['grantTypes'] + ): return True # Else assume that it's device flow, @@ -3154,12 +3156,20 @@ class SSOTokenFetcher(BaseSSOTokenFetcher): _GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' def __init__( - self, sso_region, client_creator, cache=None, - on_pending_authorization=None, time_fetcher=None, sleep=None + self, + sso_region, + client_creator, + cache=None, + on_pending_authorization=None, + time_fetcher=None, + sleep=None, ): super().__init__( - sso_region, client_creator, cache, on_pending_authorization, - time_fetcher + sso_region, + client_creator, + cache, + on_pending_authorization, + time_fetcher, ) if sleep is None: @@ -3199,8 +3209,10 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if (not self._is_expired(registration) and - not self._is_registration_for_auth_code(registration)): + if ( + not self._is_expired(registration) + and not self._is_registration_for_auth_code(registration) + ): return registration registration = self._register_client( @@ -3341,16 +3353,25 @@ def fetch_token( class SSOTokenFetcherAuth(BaseSSOTokenFetcher): """Performs the authorization code grant with PKCE OAuth2.0 flow""" - _AUTH_GRANT_TYPES = ['authorization_code', 'refresh_token'] + _AUTH_GRANT_TYPES = ('authorization_code', 'refresh_token') _AUTH_GRANT_DEFAULT_SCOPE = 'sso:account:access' def __init__( - self, sso_region, client_creator, auth_code_fetcher, cache=None, - on_pending_authorization=None, time_fetcher=None + self, + sso_region, + client_creator, + auth_code_fetcher, + cache=None, + on_pending_authorization=None, + time_fetcher=None, ): - super().__init__(sso_region, client_creator, cache, - on_pending_authorization, time_fetcher - ) + super().__init__( + sso_region, + client_creator, + cache, + on_pending_authorization, + time_fetcher, + ) self._auth_code_fetcher = auth_code_fetcher @@ -3362,7 +3383,8 @@ def __init__( for _ in range(64) ) self.code_challenge = base64.urlsafe_b64encode( - hashlib.sha256(self.code_verifier.encode()).digest()).decode() + hashlib.sha256(self.code_verifier.encode()).digest() + ).decode() def _register_client(self, session_name, scopes, redirect_uri, issuer_url): register_kwargs = { @@ -3393,11 +3415,11 @@ def _register_client(self, session_name, scopes, redirect_uri, issuer_url): return registration def _registration( - self, - start_url, - session_name, - scopes, - force_refresh=False, + self, + start_url, + session_name, + scopes, + force_refresh=False, ): cache_key = self._registration_cache_key( start_url, @@ -3406,8 +3428,10 @@ def _registration( ) if not force_refresh and cache_key in self._cache: registration = self._cache[cache_key] - if (not self._is_expired(registration) and - self._is_registration_for_auth_code(registration)): + if ( + not self._is_expired(registration) + and self._is_registration_for_auth_code(registration) + ): return registration registration = self._register_client( @@ -3437,22 +3461,25 @@ def _get_base_authorization_uri(self): endpoint for the current client to use for the un-modeled Authorize operation """ - self._client.meta.events.register('before-call', - self._extract_resolved_endpoint) + self._client.meta.events.register( + 'before-call', self._extract_resolved_endpoint + ) self._client.register_client( clientName='temp', clientType='public' ) - self._client.meta.events.unregister('before-call', - self._extract_resolved_endpoint) + self._client.meta.events.unregister( + 'before-call', self._extract_resolved_endpoint + ) return self._base_endpoint def _get_authorization_uri( - self, - client_id, - registration_scopes, - expected_state): + self, + client_id, + registration_scopes, + expected_state + ): query_params = { 'response_type': 'code', @@ -3485,9 +3512,10 @@ def _get_new_token(self, start_url, session_name, registration_scopes): authorization_uri = self._get_authorization_uri( registration['clientId'], registration_scopes, - expected_state) + expected_state + ) - # Even though there's just one uri, this matches the inputs + # Even though there's just one URI, this matches the inputs # for the device code flow so that we can reuse the browser handlers authorization_args = { 'verificationUri': authorization_uri, diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index 1f999d241022..de7e918866c2 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -145,7 +145,8 @@ def __call__( f'\n{userCode}\n' f'\nAlternatively, you may visit the following URL which will ' f'autofill the code upon loading:' - f'\n{verificationUriComplete}\n') + f'\n{verificationUriComplete}\n' + ) uni_print(opening_msg, self._outfile) if userCode: @@ -237,8 +238,9 @@ def do_GET(self): self.send_response(200) self.end_headers() with open( - os.path.join(os.path.dirname(__file__), 'index.html'), - 'rb') as file: + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb', + ) as file: self.wfile.write(file.read()) query_params = parse_qs(urlparse(self.path).query) From 241750a3b49ad83632540a7fcfc48b50bdf93a40 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Tue, 15 Oct 2024 11:48:17 -0400 Subject: [PATCH 05/11] PR feedback: more styling, preferring botocore.compat imports, improve the callback handling codepath, fix bug with state verification --- .changes/next-release/feature-sso-81096.json | 4 +- awscli/botocore/exceptions.py | 2 +- awscli/botocore/utils.py | 30 ++++--- awscli/customizations/sso/login.py | 2 +- awscli/customizations/sso/utils.py | 88 +++++++++++--------- tests/functional/sso/__init__.py | 3 +- tests/unit/customizations/sso/test_utils.py | 19 +++-- 7 files changed, 82 insertions(+), 66 deletions(-) diff --git a/.changes/next-release/feature-sso-81096.json b/.changes/next-release/feature-sso-81096.json index e4419d6bbd19..53acbc72cc32 100644 --- a/.changes/next-release/feature-sso-81096.json +++ b/.changes/next-release/feature-sso-81096.json @@ -1,5 +1,5 @@ { "type": "feature", - "category": "sso", - "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for aws sso login." + "category": "``sso``", + "description": "Add support and default to the OAuth 2.0 Authorization Code Flow with PKCE for ``aws sso login``." } diff --git a/awscli/botocore/exceptions.py b/awscli/botocore/exceptions.py index b92fa6bac0b2..4633141c9dfe 100644 --- a/awscli/botocore/exceptions.py +++ b/awscli/botocore/exceptions.py @@ -674,7 +674,7 @@ class SSOError(BotoCoreError): class PendingAuthorizationExpiredError(SSOError): fmt = ( "The pending authorization to retrieve an SSO token has expired. The " - "device authorization flow to retrieve an SSO token must be restarted." + "login flow to retrieve an SSO token must be restarted." ) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index ca5b2598ab74..d2a9290b5793 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -22,10 +22,10 @@ import os import random import re +import secrets import socket import string import time -import urllib import uuid import warnings import weakref @@ -45,6 +45,7 @@ json, quote, total_seconds, + urlencode, urlparse, urlsplit, urlunsplit, @@ -3176,6 +3177,15 @@ def __init__( sleep = time.sleep self._sleep = sleep + def fetch_token( + self, + start_url, + force_refresh, + registration_scopes, + session_name, + ): + raise NotImplementedError('Must implement fetch_token()') + def _register_client(self, session_name, scopes): register_kwargs = { 'clientName': self._generate_client_name(session_name), @@ -3377,7 +3387,7 @@ def __init__( # Generate the PKCE pair self.code_verifier = ''.join( - random.SystemRandom().choice( + secrets.choice( string.ascii_letters + string.digits + '-._~' ) for _ in range(64) @@ -3392,14 +3402,10 @@ def _register_client(self, session_name, scopes, redirect_uri, issuer_url): 'clientType': self._CLIENT_REGISTRATION_TYPE, 'grantTypes': self._AUTH_GRANT_TYPES, 'redirectUris': [redirect_uri], - 'issuerUrl': issuer_url + 'issuerUrl': issuer_url, + 'scopes': scopes or [self._AUTH_GRANT_DEFAULT_SCOPE], } - if scopes: - register_kwargs['scopes'] = scopes - else: - register_kwargs['scopes'] = [self._AUTH_GRANT_DEFAULT_SCOPE] - response = self._client.register_client(**register_kwargs) expires_at = response['clientSecretExpiresAt'] @@ -3496,7 +3502,7 @@ def _get_authorization_uri( return ( f'{self._get_base_authorization_uri()}/authorize?' - f'{urllib.parse.urlencode(query_params)}' + f'{urlencode(query_params)}' f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' ) @@ -3533,7 +3539,9 @@ def _get_new_token(self, start_url, session_name, registration_scopes): error_msg='Failed to retrieve an authorization code.' ) - if state != expected_state: + # The state we get back from the redirect is just a string, so + # cast our original UUID before comparing + if state != str(expected_state): raise AuthorizationCodeLoadError( error_msg='State parameter does not match expected value.' ) @@ -3564,8 +3572,6 @@ def _create_token_(self, start_url, registration, auth_code): if 'refreshToken' in response: token['refreshToken'] = response['refreshToken'] return token - except self._client.exceptions.InvalidGrantException as error: - print(error) except self._client.exceptions.ExpiredTokenException: raise PendingAuthorizationExpiredError() diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index acd58e287dd4..2c7917b82660 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -52,7 +52,7 @@ def _run_main(self, parsed_args, parsed_globals): force_refresh=True, session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), - fallback_to_device_flow=parsed_args.use_device_code + use_device_code=parsed_args.use_device_code, ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index de7e918866c2..afa0cf94b53d 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -14,11 +14,12 @@ import json import logging import os +import socket import webbrowser from functools import partial from http.server import HTTPServer, BaseHTTPRequestHandler -from urllib.parse import urlparse, parse_qs +from botocore.compat import urlparse, parse_qs from botocore.credentials import JSONFileCache from botocore.exceptions import AuthCodeFetcherError from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth @@ -70,7 +71,7 @@ def _sso_json_dumps(obj): def do_sso_login(session, sso_region, start_url, token_cache=None, on_pending_authorization=None, force_refresh=False, registration_scopes=None, session_name=None, - fallback_to_device_flow=False): + use_device_code=False): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) if on_pending_authorization is None: @@ -80,27 +81,27 @@ def do_sso_login(session, sso_region, start_url, token_cache=None, # For the auth flow, we need a non-legacy sso-session and check that the # user hasn't opted into falling back to the device code flow - if session_name and not fallback_to_device_flow: + if session_name and not use_device_code: token_fetcher = SSOTokenFetcherAuth( sso_region=sso_region, client_creator=session.create_client, auth_code_fetcher=AuthCodeFetcher(), cache=token_cache, - on_pending_authorization=on_pending_authorization + on_pending_authorization=on_pending_authorization, ) else: token_fetcher = SSOTokenFetcher( sso_region=sso_region, client_creator=session.create_client, cache=token_cache, - on_pending_authorization=on_pending_authorization + on_pending_authorization=on_pending_authorization, ) return token_fetcher.fetch_token( start_url=start_url, session_name=session_name, force_refresh=force_refresh, - registration_scopes=registration_scopes + registration_scopes=registration_scopes, ) @@ -198,9 +199,9 @@ def __init__(self): # We do this so that the request handler can have a reference to this # AuthCodeFetcher so that it can pass back the state and auth code try: - handler = partial(OAuthCallbackHandler, self) + handler = partial(self.OAuthCallbackHandler, self) self.http_server = HTTPServer(('', 0), handler) - except Exception as e: + except socket.error as e: raise AuthCodeFetcherError(error_msg=e) def redirect_uri_without_port(self): @@ -219,39 +220,46 @@ def get_auth_code_and_state(self): return self._auth_code, self._state + def set_auth_code_and_state(self, auth_code, state): + self._auth_code = auth_code + self._state = state + self._is_done = True -class OAuthCallbackHandler(BaseHTTPRequestHandler): - """HTTP handler to handle OAuth callback requests, extracting - the auth code and state parameters, and displaying a page directing - the user to return to the CLI. - """ - def __init__(self, auth_code_fetcher, *args, **kwargs): - self._auth_code_fetcher = auth_code_fetcher - super().__init__(*args, **kwargs) - - def log_message(self, format, *args): - # Suppress built-in logging, otherwise it prints - # each request to console - pass - - def do_GET(self): - self.send_response(200) - self.end_headers() - with open( - os.path.join(os.path.dirname(__file__), 'index.html'), - 'rb', - ) as file: - self.wfile.write(file.read()) - - query_params = parse_qs(urlparse(self.path).query) - - if 'error' in query_params: - self._auth_code_fetcher._is_done = True - return - elif 'code' in query_params and 'state' in query_params: - self._auth_code_fetcher._is_done = True - self._auth_code_fetcher._auth_code = query_params['code'][0] - self._auth_code_fetcher._state = query_params['state'][0] + class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb', + ) as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + None, + None, + ) + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + query_params['code'][0], + query_params['state'][0], + ) class InvalidSSOConfigError(ConfigurationError): diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index ace3e5e63b03..ac8e9ef2ecc5 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import time +import uuid from awscli.clidriver import AWSCLIEntryPoint from awscli.customizations.sso.utils import OpenBrowserHandler, AuthCodeFetcher @@ -63,7 +64,7 @@ def setUp(self): self.auth_code_fetcher_patch.start() self.uuid_mock = mock.Mock( - return_value="00000000-0000-0000-0000-000000000000" + return_value=uuid.UUID("00000000-0000-0000-0000-000000000000") ) self.uuid_patch = mock.patch('uuid.uuid4', self.uuid_mock) self.uuid_patch.start() diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 44ec2a7628ac..818de315144c 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -208,12 +208,12 @@ def test_can_patch_env(self): self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) -class TestAuthCodeFetcher(unittest.TestCase): +class TestAuthCodeFetcher: """Tests for the AuthCodeFetcher class, which is the local web server we use to handle the OAuth 2.0 callback """ - def setUp(self): + def setup_method(self): self.fetcher = AuthCodeFetcher() self.url = f'http://127.0.0.1:{self.fetcher.http_server.server_address[1]}/' @@ -233,9 +233,10 @@ def test_expected_auth_code(self): http = urllib3.PoolManager() response = http.request("GET", url) - self.assertEqual(response.status, 200) - self.assertEqual(self.fetcher._auth_code, expected_code) - self.assertEqual(self.fetcher._state, expected_state) + actual_code, actual_state = self.fetcher.get_auth_code_and_state() + assert response.status == 200 + assert actual_code == expected_code + assert actual_state == expected_state def test_error(self): expected_code = 'Failed' @@ -244,7 +245,7 @@ def test_error(self): http = urllib3.PoolManager() response = http.request("GET", url) - self.assertEqual(response.status, 200) - self.assertEqual(self.fetcher._auth_code, None) - self.assertEqual(self.fetcher._state, None) - + actual_code, actual_state = self.fetcher.get_auth_code_and_state() + assert response.status == 200 + assert actual_code is None + assert actual_state is None From 6afacfe2f066d1d3e6ad25b24db727a16d1cc47c Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Wed, 16 Oct 2024 15:46:25 -0400 Subject: [PATCH 06/11] Move handler back and method we're expecting to override --- awscli/botocore/utils.py | 18 ++++---- awscli/customizations/sso/utils.py | 73 +++++++++++++++--------------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index d2a9290b5793..1e16b7c33d4d 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -3085,6 +3085,15 @@ def __init__( cache = {} self._cache = cache + def fetch_token( + self, + start_url, + force_refresh, + registration_scopes, + session_name, + ): + raise NotImplementedError('Must implement fetch_token()') + def _utc_now(self): return datetime.datetime.now(tzutc()) @@ -3177,15 +3186,6 @@ def __init__( sleep = time.sleep self._sleep = sleep - def fetch_token( - self, - start_url, - force_refresh, - registration_scopes, - session_name, - ): - raise NotImplementedError('Must implement fetch_token()') - def _register_client(self, session_name, scopes): register_kwargs = { 'clientName': self._generate_client_name(session_name), diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index afa0cf94b53d..21920a333f97 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -199,7 +199,7 @@ def __init__(self): # We do this so that the request handler can have a reference to this # AuthCodeFetcher so that it can pass back the state and auth code try: - handler = partial(self.OAuthCallbackHandler, self) + handler = partial(OAuthCallbackHandler, self) self.http_server = HTTPServer(('', 0), handler) except socket.error as e: raise AuthCodeFetcherError(error_msg=e) @@ -225,41 +225,42 @@ def set_auth_code_and_state(self, auth_code, state): self._state = state self._is_done = True - class OAuthCallbackHandler(BaseHTTPRequestHandler): - """HTTP handler to handle OAuth callback requests, extracting - the auth code and state parameters, and displaying a page directing - the user to return to the CLI. - """ - def __init__(self, auth_code_fetcher, *args, **kwargs): - self._auth_code_fetcher = auth_code_fetcher - super().__init__(*args, **kwargs) - - def log_message(self, format, *args): - # Suppress built-in logging, otherwise it prints - # each request to console - pass - - def do_GET(self): - self.send_response(200) - self.end_headers() - with open( - os.path.join(os.path.dirname(__file__), 'index.html'), - 'rb', - ) as file: - self.wfile.write(file.read()) - - query_params = parse_qs(urlparse(self.path).query) - - if 'error' in query_params: - self._auth_code_fetcher.set_auth_code_and_state( - None, - None, - ) - elif 'code' in query_params and 'state' in query_params: - self._auth_code_fetcher.set_auth_code_and_state( - query_params['code'][0], - query_params['state'][0], - ) + +class OAuthCallbackHandler(BaseHTTPRequestHandler): + """HTTP handler to handle OAuth callback requests, extracting + the auth code and state parameters, and displaying a page directing + the user to return to the CLI. + """ + def __init__(self, auth_code_fetcher, *args, **kwargs): + self._auth_code_fetcher = auth_code_fetcher + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + # Suppress built-in logging, otherwise it prints + # each request to console + pass + + def do_GET(self): + self.send_response(200) + self.end_headers() + with open( + os.path.join(os.path.dirname(__file__), 'index.html'), + 'rb', + ) as file: + self.wfile.write(file.read()) + + query_params = parse_qs(urlparse(self.path).query) + + if 'error' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + None, + None, + ) + elif 'code' in query_params and 'state' in query_params: + self._auth_code_fetcher.set_auth_code_and_state( + query_params['code'][0], + query_params['state'][0], + ) class InvalidSSOConfigError(ConfigurationError): From 6aa693f583d0bdf30fcd2924e9916e65f82e1506 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Wed, 16 Oct 2024 15:59:35 -0400 Subject: [PATCH 07/11] Add timeout for the OAuth redirect --- awscli/customizations/sso/utils.py | 18 ++++++++++++++++-- tests/unit/customizations/sso/test_utils.py | 19 ++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index 21920a333f97..18fa2be4bcd9 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -15,13 +15,17 @@ import logging import os import socket +import time import webbrowser from functools import partial from http.server import HTTPServer, BaseHTTPRequestHandler from botocore.compat import urlparse, parse_qs from botocore.credentials import JSONFileCache -from botocore.exceptions import AuthCodeFetcherError +from botocore.exceptions import ( + AuthCodeFetcherError, + PendingAuthorizationExpiredError, +) from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth from botocore.utils import original_ld_library_path @@ -191,6 +195,11 @@ class AuthCodeFetcher: """Manages the local web server that will be used to retrieve the authorization code from the OAuth callback """ + # How many seconds handle_request should wait for an incoming request + _REQUEST_TIMEOUT = 10 + # How long we wait overall for the callback + _OVERALL_TIMEOUT = 60 * 10 + def __init__(self): self._auth_code = None self._state = None @@ -201,6 +210,7 @@ def __init__(self): try: handler = partial(OAuthCallbackHandler, self) self.http_server = HTTPServer(('', 0), handler) + self.http_server.timeout = self._REQUEST_TIMEOUT except socket.error as e: raise AuthCodeFetcherError(error_msg=e) @@ -214,10 +224,14 @@ def get_auth_code_and_state(self): """Blocks until the expected redirect request with either the authorization code/state or and error is handled """ - while not self._is_done: + start = time.time() + while not self._is_done and time.time() < start + self._OVERALL_TIMEOUT: self.http_server.handle_request() self.http_server.server_close() + if not self._is_done: + raise PendingAuthorizationExpiredError + return self._auth_code, self._state def set_auth_code_and_state(self, auth_code, state): diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 818de315144c..3efd3cb3a573 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -13,9 +13,10 @@ import os import threading import webbrowser - import pytest import urllib3 + +from botocore.exceptions import PendingAuthorizationExpiredError from botocore.session import Session from awscli.compat import StringIO @@ -249,3 +250,19 @@ def test_error(self): assert response.status == 200 assert actual_code is None assert actual_state is None + + +@mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher._REQUEST_TIMEOUT', + 0.1 +) +@mock.patch( + 'awscli.customizations.sso.utils.AuthCodeFetcher._OVERALL_TIMEOUT', + 0.1 +) +def test_get_auth_code_and_state_timeout(): + """Tests the timeout case separately of TestAuthCodeFetcher, + since we need to override the constants + """ + with pytest.raises(PendingAuthorizationExpiredError): + AuthCodeFetcher().get_auth_code_and_state() From 0f50729d5783d56126baa64989017208dc7602e4 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Wed, 16 Oct 2024 16:11:25 -0400 Subject: [PATCH 08/11] Add tests for just OAuthCallbackHandler without actually starting the server --- tests/unit/customizations/sso/test_utils.py | 57 ++++++++++++++++++++- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/unit/customizations/sso/test_utils.py b/tests/unit/customizations/sso/test_utils.py index 3efd3cb3a573..3a583dc29b39 100644 --- a/tests/unit/customizations/sso/test_utils.py +++ b/tests/unit/customizations/sso/test_utils.py @@ -19,13 +19,13 @@ from botocore.exceptions import PendingAuthorizationExpiredError from botocore.session import Session -from awscli.compat import StringIO +from awscli.compat import BytesIO, StringIO from awscli.customizations.sso.utils import OpenBrowserHandler from awscli.customizations.sso.utils import PrintOnlyHandler from awscli.customizations.sso.utils import do_sso_login from awscli.customizations.sso.utils import open_browser_with_original_ld_path from awscli.customizations.sso.utils import ( - parse_sso_registration_scopes, AuthCodeFetcher + parse_sso_registration_scopes, AuthCodeFetcher, OAuthCallbackHandler ) from awscli.testutils import mock from awscli.testutils import unittest @@ -209,6 +209,59 @@ def test_can_patch_env(self): self.assertIsNone(captured_env.get('LD_LIBRARY_PATH')) +class MockRequest(object): + def __init__(self, request): + self._request = request + + def makefile(self, *args, **kwargs): + return BytesIO(self._request) + + def sendall(self, data): + pass + + +class TestOAuthCallbackHandler: + """Tests for OAuthCallbackHandler, which handles + individual requests that we receive at the callback uri + """ + def test_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?state=123&code=456'), + mock.MagicMock(), + mock.MagicMock(), + ) + fetcher.set_auth_code_and_state.assert_called_once_with('456', '123') + + def test_error(self): + fetcher = mock.Mock(AuthCodeFetcher) + + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /?error=Error%20message'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_called_once_with(None, None) + + def test_missing_expected_query_params(self): + fetcher = mock.Mock(AuthCodeFetcher) + + # We generally don't expect to be missing the expected query params, + # but if we do we expect the server to keep waiting for a valid callback + OAuthCallbackHandler( + fetcher, + MockRequest(b'GET /'), + mock.MagicMock(), + mock.MagicMock(), + ) + + fetcher.set_auth_code_and_state.assert_not_called() + + class TestAuthCodeFetcher: """Tests for the AuthCodeFetcher class, which is the local web server we use to handle the OAuth 2.0 callback From f2dd8e3d1c5fe072eb981be6bd47c329ad71bff1 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Wed, 30 Oct 2024 17:14:47 -0400 Subject: [PATCH 09/11] Remove second redirect and override server's host header --- awscli/customizations/sso/index.html | 9 ++------- awscli/customizations/sso/utils.py | 5 +++++ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/awscli/customizations/sso/index.html b/awscli/customizations/sso/index.html index f1572f4b04f9..214a32fe1fa6 100644 --- a/awscli/customizations/sso/index.html +++ b/awscli/customizations/sso/index.html @@ -138,7 +138,7 @@

Request denied

-

You can close this window and re-start the authorization flow

+

You can close this window and re-start the authorization flow.

@@ -158,12 +158,7 @@

Request denied

).innerText = `${productName} has been given requested permissions` document.getElementById( 'footerText' - ).innerText = `You can close this window and start using the ${productName}` - - const redirectUri = params.get('redirectUri') - if (redirectUri) { - window.location.replace(redirectUri) - } + ).innerText = `You can close this window and start using the ${productName}.` function showErrorMessage(errorText) { document.getElementById('approved-auth').classList.add('hidden') diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index 18fa2be4bcd9..1e9aef74d120 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -29,6 +29,7 @@ from botocore.utils import SSOTokenFetcher, SSOTokenFetcherAuth from botocore.utils import original_ld_library_path +from awscli import __version__ as awscli_version from awscli.customizations.assumerole import CACHE_DIR as AWS_CREDS_CACHE_DIR from awscli.customizations.commands import BasicCommand from awscli.customizations.exceptions import ConfigurationError @@ -254,6 +255,10 @@ def log_message(self, format, *args): # each request to console pass + def version_string(self): + # Override the Host header in case helpful for debugging + return f'AWS CLI/{awscli_version}' + def do_GET(self): self.send_response(200) self.end_headers() From c96d05cbcb85be64fbdbf44944b7f432812dd14a Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Mon, 11 Nov 2024 15:56:34 -0500 Subject: [PATCH 10/11] Fix key for scopes, fix defaulting for scopes, and add assertions on the authorize uris --- awscli/botocore/utils.py | 4 +++- tests/functional/sso/__init__.py | 33 ++++++++++++++++++++++++++++++ tests/functional/sso/test_login.py | 33 ++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 1e16b7c33d4d..368b597d3935 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -3498,7 +3498,9 @@ def _get_authorization_uri( # For the query param, scopes must be space separated before encoding if registration_scopes: - query_params['scope'] = " ".join(registration_scopes) + query_params['scopes'] = " ".join(registration_scopes) + else: + query_params['scopes'] = self._AUTH_GRANT_DEFAULT_SCOPE return ( f'{self._get_base_authorization_uri()}/authorize?' diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index ac8e9ef2ecc5..712cf5003f2c 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -83,6 +83,39 @@ def tearDown(self): def assert_used_expected_sso_region(self, expected_region): self.assertIn(expected_region, self.last_request_dict['url']) + def assert_device_browser_handler_called_with( + self, + userCode, + verificationUri, + verificationUriComplete, + ): + # assert_called_with is matching the __init__ parameters instead of + # __call__, so verify the arguments we're interested in this way + self.open_browser_mock.assert_called_once() + _, kwargs = self.open_browser_mock.return_value.call_args + self.assertEqual(userCode, kwargs['userCode']) + self.assertEqual(verificationUri, kwargs['verificationUri']) + self.assertEqual(verificationUriComplete, kwargs['verificationUriComplete']) + + def assert_auth_browser_handler_called_with(self, expected_scopes): + # The endpoint is subject to the endpoint rules, and the + # code_challenge is not fixed so assert against the rest of the url + expected_url = ( + 'authorize?' + 'response_type=code' + '&client_id=auth-client-id' + '&redirect_uri=http%3A%2F%2F127.0.0.1%3A55555%2Foauth%2Fcallback' + '&state=00000000-0000-0000-0000-000000000000' + '&code_challenge_method=S256' + '&scopes=' + expected_scopes + ) + + self.open_browser_mock.assert_called_once() + _, kwargs = self.open_browser_mock.return_value.call_args + self.assertEqual(None, kwargs['userCode']) + self.assertIn(expected_url, kwargs['verificationUri']) + self.assertIn(expected_url, kwargs['verificationUriComplete']) + def get_legacy_config(self): content = ( f'[default]\n' diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index f8684908db21..d2e79f014a10 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -154,6 +154,11 @@ def test_login_explicit_device(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_token( start_url=self.start_url, expected_token=self.access_token @@ -165,6 +170,11 @@ def test_login_implicit_device(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_token( start_url=self.start_url, expected_token=self.access_token @@ -278,6 +288,11 @@ def test_login_device_minimal_sso_configuration(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_token( start_url=self.start_url, expected_token=self.access_token @@ -317,6 +332,11 @@ def test_login_device_sso_session(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', @@ -335,6 +355,7 @@ def test_login_auth_sso_session(self): self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Aaccount%3Aaccess') self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', @@ -354,6 +375,11 @@ def test_login_device_sso_with_explicit_sso_session_arg(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login --sso-session test-session --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', @@ -373,6 +399,7 @@ def test_login_auth_sso_with_explicit_sso_session_arg(self): self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login --sso-session test-session') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Aaccount%3Aaccess') self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', @@ -392,6 +419,11 @@ def test_login_device_sso_session_with_scopes(self): self.add_oidc_device_responses(self.access_token) self.run_cmd('sso login --use-device-code') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_device_browser_handler_called_with( + 'foo', + 'https://sso.fake/device', + 'https://sso.verify', + ) self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', @@ -414,6 +446,7 @@ def test_login_auth_sso_session_with_scopes(self): self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) + self.assert_auth_browser_handler_called_with('sso%3Afoo+sso%3Abar') self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session', From 0fc2814fa432b45893596267b37e88625af50de9 Mon Sep 17 00:00:00 2001 From: Alex Shovlin Date: Tue, 12 Nov 2024 14:20:32 -0500 Subject: [PATCH 11/11] Switch from urlencode to our percent_encode (from + to %20) --- awscli/botocore/utils.py | 3 +-- tests/functional/sso/test_login.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/awscli/botocore/utils.py b/awscli/botocore/utils.py index 368b597d3935..affcf155baa6 100644 --- a/awscli/botocore/utils.py +++ b/awscli/botocore/utils.py @@ -45,7 +45,6 @@ json, quote, total_seconds, - urlencode, urlparse, urlsplit, urlunsplit, @@ -3504,7 +3503,7 @@ def _get_authorization_uri( return ( f'{self._get_base_authorization_uri()}/authorize?' - f'{urlencode(query_params)}' + f'{percent_encode_sequence(query_params)}' f'&code_challenge={self.code_challenge[:-1]}' # trim final '=' ) diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index d2e79f014a10..a2d6cc025079 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -446,7 +446,7 @@ def test_login_auth_sso_session_with_scopes(self): self.add_oidc_auth_code_responses(self.access_token) self.run_cmd('sso login') self.assert_used_expected_sso_region(expected_region=self.sso_region) - self.assert_auth_browser_handler_called_with('sso%3Afoo+sso%3Abar') + self.assert_auth_browser_handler_called_with('sso%3Afoo%20sso%3Abar') self.assert_cache_contains_registration( start_url=self.start_url, session_name='test-session',