From 32ba2214547e7c551689a4a0fc1d9ecbc2cc34ca Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 21 Oct 2024 09:47:42 +0200 Subject: [PATCH 1/2] [Fix] Decouple OAuth functionality from `Config` (#784) ## Changes ### OAuth Refactoring Currently, OAuthClient uses Config internally to resolve the OIDC endpoints by passing the client ID and host to an internal Config instance and calling its `oidc_endpoints` method. This has a few drawbacks: 1. There is nearly a cyclical dependency: `Config` depends on methods in `oauth.py`, and `OAuthClient` depends on `Config`. This currently doesn't break because the `Config` import is done at runtime in the `OAuthClient` constructor. 2. Databricks supports both in-house OAuth and Azure Entra ID OAuth. Currently, the choice between these options depends on whether a user specifies the azure_client_id or client_id parameter in the Config. Because Config is used within OAuthClient, this means that OAuthClient needs to expose a parameter to configure either client_id or azure_client_id. Rather than having these classes deeply coupled to one another, we can allow users to fetch the OIDC endpoints for a given account/workspace as a top-level functionality and provide this to `OAuthClient`. This breaks the cyclic dependency and doesn't require `OAuthClient` to expose any unnecessary parameters. Further, I've also tried to remove the coupling of the other classes in `oauth.py` to `OAuthClient`. Currently, `OAuthClient` serves both as the mechanism to initialize OAuth and as a kind of configuration object, capturing OAuth endpoint URLs, client ID/secret, redirect URL, and scopes. Now, the parameters for each of these classes are explicit, removing all unnecessary coupling between them. One nice advantage is that the Consent can be serialized/deserialized without any reference to the `OAuthClient` anymore. 
There is definitely more work to be done to simplify and clean up the OAuth implementation, but this should at least unblock users who need to use Azure Entra ID U2M OAuth in the SDK. ## Tests The new OIDC endpoint methods are tested, and those tests also verify that those endpoints are retried in case of rate limiting. I ran the flask app example against an AWS workspace, and I ran the external-browser demo example against AWS, Azure and GCP workspaces with the default client ID and with a newly created OAuth app with and without credentials. - [ ] `make test` run locally - [ ] `make fmt` applied - [ ] relevant integration tests applied --- databricks/sdk/_base_client.py | 20 +++ databricks/sdk/config.py | 44 ++--- databricks/sdk/credentials_provider.py | 31 ++-- databricks/sdk/oauth.py | 230 +++++++++++++++++++------ examples/external_browser_auth.py | 72 ++++++++ examples/flask_app_with_oauth.py | 64 +++---- tests/test_oauth.py | 155 +++++++++++++---- 7 files changed, 459 insertions(+), 157 deletions(-) create mode 100644 examples/external_browser_auth.py diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 62c2974ec..95ce39cbe 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -1,4 +1,5 @@ import logging +import urllib.parse from datetime import timedelta from types import TracebackType from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, @@ -17,6 +18,25 @@ logger = logging.getLogger('databricks.sdk') +def _fix_host_if_needed(host: Optional[str]) -> Optional[str]: + if not host: + return host + + # Add a default scheme if it's missing + if '://' not in host: + host = 'https://' + host + + o = urllib.parse.urlparse(host) + # remove trailing slash + path = o.path.rstrip('/') + # remove port if 443 + netloc = o.netloc + if o.port == 443: + netloc = netloc.split(':')[0] + + return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + + class 
_BaseClient: def __init__(self, diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 5cae1b2b4..b4efdf603 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -10,11 +10,14 @@ import requests from . import useragent +from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints, Token +from .oauth import (OidcEndpoints, Token, get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints) logger = logging.getLogger('databricks.sdk') @@ -254,24 +257,10 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]: if not self.host: return None if self.is_azure and self.azure_client_id: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) + return get_azure_entra_id_workspace_endpoints(self.host) if self.is_account_client and self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) + return get_account_endpoints(self.host, self.account_id) + return get_workspace_endpoints(self.host) def debug_string(self) -> str: """ Returns 
log-friendly representation of configured attributes """ @@ -346,22 +335,9 @@ def attributes(cls) -> Iterable[ConfigAttribute]: return cls._attributes def _fix_host_if_needed(self): - if not self.host: - return - - # Add a default scheme if it's missing - if '://' not in self.host: - self.host = 'https://' + self.host - - o = urllib.parse.urlparse(self.host) - # remove trailing slash - path = o.path.rstrip('/') - # remove port if 443 - netloc = o.netloc - if o.port == 443: - netloc = netloc.split(':')[0] - - self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + updated_host = _fix_host_if_needed(self.host) + if updated_host: + self.host = updated_host def load_azure_tenant_id(self): """[Internal] Load the Azure tenant ID from the Azure Databricks login page. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 232465dab..a79151b5a 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -187,30 +187,35 @@ def token() -> Token: def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: if cfg.auth_type != 'external-browser': return None + client_id, client_secret = None, None if cfg.client_id: client_id = cfg.client_id - elif cfg.is_aws: + client_secret = cfg.client_secret + elif cfg.azure_client_id: + client_id = cfg.azure_client + client_secret = cfg.azure_client_secret + + if not client_id: client_id = 'databricks-cli' - elif cfg.is_azure: - # Use Azure AD app for cases when Azure CLI is not available on the machine. - # App has to be registered as Single-page multi-tenant to support PKCE - # TODO: temporary app ID, change it later. 
- client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' - else: - raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) + oidc_endpoints = cfg.oidc_endpoints + redirect_url = 'http://localhost:8020' + token_cache = TokenCache(host=cfg.host, + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) credentials = token_cache.load() if credentials: # Force a refresh in case the loaded credentials are expired. credentials.token() else: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + redirect_url=redirect_url, + client_secret=client_secret) consent = oauth_client.initiate_consent() if not consent: return None diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e9a3afb90..6cac45afc 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -17,6 +17,8 @@ import requests import requests.auth +from ._base_client import _BaseClient, _fix_host_if_needed + # Error code for PKCE flow in Azure Active Directory, that gets additional retry. # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' @@ -46,8 +48,24 @@ def __call__(self, r): @dataclass class OidcEndpoints: + """ + The endpoints used for OAuth-based authentication in Databricks. + """ + authorization_endpoint: str # ../v1/authorize + """The authorization endpoint for the OAuth flow. 
The user-agent should be directed to this endpoint in order for + the user to login and authorize the client for user-to-machine (U2M) flows.""" + token_endpoint: str # ../v1/token + """The token endpoint for the OAuth flow.""" + + @staticmethod + def from_dict(d: dict) -> 'OidcEndpoints': + return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'), + token_endpoint=d.get('token_endpoint')) + + def as_dict(self) -> dict: + return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint} @dataclass @@ -220,18 +238,76 @@ def do_GET(self): self.wfile.write(b'You can close this tab.') +def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given account. + :param host: The Databricks account host. + :param account_id: The account ID. + :return: The account's OIDC endpoints. + """ + host = _fix_host_if_needed(host) + oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given workspace. + :param host: The Databricks workspace host. + :return: The workspace's OIDC endpoints. + """ + host = _fix_host_if_needed(host) + oidc = f'{host}/oidc/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]: + """ + Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks + using an application registered in Azure Entra ID. + :param host: The Databricks workspace host. + :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. 
+ """ + # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint + host = _fix_host_if_needed(host) + res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + + class SessionCredentials(Refreshable): - def __init__(self, client: 'OAuthClient', token: Token): - self._client = client + def __init__(self, + token: Token, + token_endpoint: str, + client_id: str, + client_secret: str = None, + redirect_url: str = None): + self._token_endpoint = token_endpoint + self._client_id = client_id + self._client_secret = client_secret + self._redirect_url = redirect_url super().__init__(token) def as_dict(self) -> dict: return {'token': self._token.as_dict()} @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials': - return SessionCredentials(client=client, token=Token.from_dict(raw['token'])) + def from_dict(raw: dict, + token_endpoint: str, + client_id: str, + client_secret: str = None, + redirect_url: str = None) -> 'SessionCredentials': + return SessionCredentials(token=Token.from_dict(raw['token']), + token_endpoint=token_endpoint, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) def auth_type(self): """Implementing CredentialsProvider protocol""" @@ -252,13 +328,13 @@ def refresh(self) -> Token: raise ValueError('oauth2: token expired and refresh token is not set') params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token} headers = {} - if 'microsoft' in self._client.token_url: + if 'microsoft' in self._token_endpoint: # Tokens issued for the 'Single-Page Application' client-type may # only be redeemed via cross-origin requests - headers = {'Origin': self._client.redirect_url} - return retrieve_token(client_id=self._client.client_id, - 
client_secret=self._client.client_secret, - token_url=self._client.token_url, + headers = {'Origin': self._redirect_url} + return retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._token_endpoint, params=params, use_params=True, headers=headers) @@ -266,27 +342,53 @@ def refresh(self) -> Token: class Consent: - def __init__(self, client: 'OAuthClient', state: str, verifier: str, auth_url: str = None) -> None: - self.auth_url = auth_url - + def __init__(self, + state: str, + verifier: str, + authorization_url: str, + redirect_url: str, + token_endpoint: str, + client_id: str, + client_secret: str = None) -> None: self._verifier = verifier self._state = state - self._client = client + self._authorization_url = authorization_url + self._redirect_url = redirect_url + self._token_endpoint = token_endpoint + self._client_id = client_id + self._client_secret = client_secret def as_dict(self) -> dict: - return {'state': self._state, 'verifier': self._verifier} + return { + 'state': self._state, + 'verifier': self._verifier, + 'authorization_url': self._authorization_url, + 'redirect_url': self._redirect_url, + 'token_endpoint': self._token_endpoint, + 'client_id': self._client_id, + } + + @property + def authorization_url(self) -> str: + return self._authorization_url @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent': - return Consent(client, raw['state'], raw['verifier']) + def from_dict(raw: dict, client_secret: str = None) -> 'Consent': + return Consent(raw['state'], + raw['verifier'], + authorization_url=raw['authorization_url'], + redirect_url=raw['redirect_url'], + token_endpoint=raw['token_endpoint'], + client_id=raw['client_id'], + client_secret=client_secret) def launch_external_browser(self) -> SessionCredentials: - redirect_url = urllib.parse.urlparse(self._client.redirect_url) + redirect_url = urllib.parse.urlparse(self._redirect_url) if redirect_url.hostname not in ('localhost', 
'127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') feedback = [] - logger.info(f'Opening {self.auth_url} in a browser') - webbrowser.open_new(self.auth_url) + logger.info(f'Opening {self._authorization_url} in a browser') + webbrowser.open_new(self._authorization_url) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) with HTTPServer(("localhost", port), handler_factory) as httpd: @@ -308,7 +410,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: if self._state != state: raise ValueError('state mismatch') params = { - 'redirect_uri': self._client.redirect_url, + 'redirect_uri': self._redirect_url, 'grant_type': 'authorization_code', 'code_verifier': self._verifier, 'code': code @@ -316,19 +418,20 @@ def exchange(self, code: str, state: str) -> SessionCredentials: headers = {} while True: try: - token = retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + token = retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._token_endpoint, params=params, headers=headers, use_params=True) - return SessionCredentials(self._client, token) + return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret, + self._redirect_url) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): # Retry in cases of 'Single-Page Application' client-type with # 'Origin' header equal to client's redirect URL. 
- headers['Origin'] = self._client.redirect_url - msg = f'Retrying OAuth token exchange with {self._client.redirect_url} origin' + headers['Origin'] = self._redirect_url + msg = f'Retrying OAuth token exchange with {self._redirect_url} origin' logger.debug(msg) continue raise e @@ -354,13 +457,28 @@ class OAuthClient: """ def __init__(self, - host: str, - client_id: str, + oidc_endpoints: OidcEndpoints, redirect_url: str, - *, + client_id: str, scopes: List[str] = None, client_secret: str = None): - # TODO: is it a circular dependency?.. + + if not scopes: + scopes = ['all-apis'] + + self.redirect_url = redirect_url + self._client_id = client_id + self._client_secret = client_secret + self._oidc_endpoints = oidc_endpoints + self._scopes = scopes + + @staticmethod + def from_host(host: str, + client_id: str, + redirect_url: str, + *, + scopes: List[str] = None, + client_secret: str = None) -> 'OAuthClient': from .core import Config from .credentials_provider import credentials_strategy @@ -374,18 +492,7 @@ def noop_credentials(_: any): oidc = config.oidc_endpoints if not oidc: raise ValueError(f'{host} does not support OAuth') - - self.host = host - self.redirect_url = redirect_url - self.client_id = client_id - self.client_secret = client_secret - self.token_url = oidc.token_endpoint - self.is_aws = config.is_aws - self.is_azure = config.is_azure - self.is_gcp = config.is_gcp - - self._auth_url = oidc.authorization_endpoint - self._scopes = scopes + return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret) def initiate_consent(self) -> Consent: state = secrets.token_urlsafe(16) @@ -397,18 +504,24 @@ def initiate_consent(self) -> Consent: params = { 'response_type': 'code', - 'client_id': self.client_id, + 'client_id': self._client_id, 'redirect_uri': self.redirect_url, 'scope': ' '.join(self._scopes), 'state': state, 'code_challenge': challenge, 'code_challenge_method': 'S256' } - url = f'{self._auth_url}?{urllib.parse.urlencode(params)}' - return 
Consent(self, state, verifier, auth_url=url) + auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' + return Consent(state, + verifier, + authorization_url=auth_url, + redirect_url=self.redirect_url, + token_endpoint=self._oidc_endpoints.token_endpoint, + client_id=self._client_id, + client_secret=self._client_secret) def __repr__(self) -> str: - return f'' + return f'' @dataclass @@ -448,17 +561,28 @@ def refresh(self) -> Token: use_header=self.use_header) -class TokenCache(): +class TokenCache: BASE_PATH = "~/.config/databricks-sdk-py/oauth" - def __init__(self, client: OAuthClient) -> None: - self.client = client + def __init__(self, + host: str, + oidc_endpoints: OidcEndpoints, + client_id: str, + redirect_url: str = None, + client_secret: str = None, + scopes: List[str] = None) -> None: + self._host = host + self._client_id = client_id + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_secret = client_secret + self._scopes = scopes or [] @property def filename(self) -> str: # Include host, client_id, and scopes in the cache filename to make it unique. 
hash = hashlib.sha256() - for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]: + for chunk in [self._host, self._client_id, ",".join(self._scopes), ]: hash.update(chunk.encode('utf-8')) return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) @@ -472,7 +596,11 @@ def load(self) -> Optional[SessionCredentials]: try: with open(self.filename, 'r') as f: raw = json.load(f) - return SessionCredentials.from_dict(self.client, raw) + return SessionCredentials.from_dict(raw, + token_endpoint=self._oidc_endpoints.token_endpoint, + client_id=self._client_id, + client_secret=self._client_secret, + redirect_url=self._redirect_url) except Exception: return None diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py new file mode 100644 index 000000000..061ff60c7 --- /dev/null +++ b/examples/external_browser_auth.py @@ -0,0 +1,72 @@ +from databricks.sdk import WorkspaceClient +import argparse +import logging + +logging.basicConfig(level=logging.DEBUG) + + +def register_custom_app(confidential: bool) -> tuple[str, str]: + """Creates new Custom OAuth App in Databricks Account""" + logging.info("No OAuth custom app client/secret provided, creating new app") + + from databricks.sdk import AccountClient + + account_client = AccountClient() + + custom_app = account_client.custom_app_integration.create( + name="external-browser-demo", + redirect_urls=[ + f"http://localhost:8020", + ], + confidential=confidential, + scopes=["all-apis"], + ) + logging.info(f"Created new custom app: " + f"--client_id {custom_app.client_id} " + f"{'--client_secret ' + custom_app.client_secret if confidential else ''}") + + return custom_app.client_id, custom_app.client_secret + + +def delete_custom_app(client_id: str): + """Creates new Custom OAuth App in Databricks Account""" + logging.info(f"Deleting custom app {client_id}") + from databricks.sdk import AccountClient + account_client = AccountClient() + 
account_client.custom_app_integration.delete(client_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", help="Databricks host", required=True) + parser.add_argument("--client_id", help="Databricks client_id", default=None) + parser.add_argument("--azure_client_id", help="Databricks azure_client_id", default=None) + parser.add_argument("--client_secret", help="Databricks client_secret", default=None) + parser.add_argument("--azure_client_secret", help="Databricks azure_client_secret", default=None) + parser.add_argument("--register-custom-app", action="store_true", help="Register a new custom app") + parser.add_argument("--register-custom-app-confidential", action="store_true", help="Register a new custom app") + namespace = parser.parse_args() + if namespace.register_custom_app and (namespace.client_id is not None or namespace.azure_client_id is not None): + raise ValueError("Cannot register custom app and provide --client_id/--azure_client_id at the same time") + if not namespace.register_custom_app and namespace.client_id is None and namespace.azure_client_secret is None: + raise ValueError("Must provide --client_id/--azure_client_id or register a custom app") + if namespace.register_custom_app: + client_id, client_secret = register_custom_app(namespace.register_custom_app_confidential) + else: + client_id, client_secret = namespace.client_id, namespace.client_secret + + w = WorkspaceClient( + host=namespace.host, + client_id=client_id, + client_secret=client_secret, + azure_client_id=namespace.azure_client_id, + azure_client_secret=namespace.azure_client_secret, + auth_type="external-browser", + ) + me = w.current_user.me() + print(me) + + if namespace.register_custom_app: + delete_custom_app(client_id) + + diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index 4128de5ca..7c18eadc7 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -31,20 
+31,21 @@ import logging import sys -from databricks.sdk.oauth import OAuthClient +from databricks.sdk.oauth import OAuthClient, get_workspace_endpoints +from databricks.sdk.service.compute import ListClustersFilterBy, State APP_NAME = "flask-demo" all_clusters_template = """""" -def create_flask_app(oauth_client: OAuthClient): +def create_flask_app(workspace_host: str, client_id: str, client_secret: str): """The create_flask_app function creates a Flask app that is enabled with OAuth. It initializes the app and web session secret keys with a randomly generated token. It defines two routes for @@ -64,7 +65,7 @@ def callback(): the callback parameters, and redirects the user to the index page.""" from databricks.sdk.oauth import Consent - consent = Consent.from_dict(oauth_client, session["consent"]) + consent = Consent.from_dict(session["consent"], client_secret=client_secret) session["creds"] = consent.exchange_callback_parameters(request.args).as_dict() return redirect(url_for("index")) @@ -72,21 +73,34 @@ def callback(): def index(): """The index page checks if the user has already authenticated and retrieves the user's credentials using the Databricks SDK WorkspaceClient. 
It then renders the template with the clusters' list.""" + oidc_endpoints = get_workspace_endpoints(workspace_host) + port = request.environ.get("SERVER_PORT") + redirect_url=f"http://localhost:{port}/callback" if "creds" not in session: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) consent = oauth_client.initiate_consent() session["consent"] = consent.as_dict() - return redirect(consent.auth_url) + return redirect(consent.authorization_url) from databricks.sdk import WorkspaceClient from databricks.sdk.oauth import SessionCredentials - credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"]) - workspace_client = WorkspaceClient(host=oauth_client.host, + credentials_strategy = SessionCredentials.from_dict(session["creds"], + token_endpoint=oidc_endpoints.token_endpoint, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) + workspace_client = WorkspaceClient(host=workspace_host, product=APP_NAME, - credentials_provider=credentials_provider, + credentials_strategy=credentials_strategy, ) - - return render_template_string(all_clusters_template, w=workspace_client) + clusters = workspace_client.clusters.list( + filter_by=ListClustersFilterBy(cluster_states=[State.RUNNING, State.PENDING]) + ) + return render_template_string(all_clusters_template, workspace_host=workspace_host, clusters=clusters) return app @@ -100,7 +114,11 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: account_client = AccountClient(profile=args.profile) custom_app = account_client.custom_app_integration.create( - name=APP_NAME, redirect_urls=[f"http://localhost:{args.port}/callback"], confidential=True, + name=APP_NAME, + redirect_urls=[ + f"http://localhost:{args.port}/callback", + ], + confidential=True, scopes=["all-apis"], ) logging.info(f"Created new custom app: " @@ -110,22 +128,6 @@ def register_custom_app(args: 
argparse.Namespace) -> tuple[str, str]: return custom_app.client_id, custom_app.client_secret -def init_oauth_config(args) -> OAuthClient: - """Creates Databricks SDK configuration for OAuth""" - oauth_client = OAuthClient(host=args.host, - client_id=args.client_id, - client_secret=args.client_secret, - redirect_url=f"http://localhost:{args.port}/callback", - scopes=["all-apis"], - ) - if not oauth_client.client_id: - client_id, client_secret = register_custom_app(args) - oauth_client.client_id = client_id - oauth_client.client_secret = client_secret - - return oauth_client - - def parse_arguments() -> argparse.Namespace: """Parses arguments for this demo""" parser = argparse.ArgumentParser(prog=APP_NAME, description=__doc__.strip()) @@ -145,8 +147,10 @@ def parse_arguments() -> argparse.Namespace: logging.getLogger("databricks.sdk").setLevel(logging.DEBUG) args = parse_arguments() - oauth_cfg = init_oauth_config(args) - app = create_flask_app(oauth_cfg) + client_id, client_secret = args.client_id, args.client_secret + if not client_id: + client_id, client_secret = register_custom_app(args) + app = create_flask_app(args.host, client_id, client_secret) app.run( host="localhost", diff --git a/tests/test_oauth.py b/tests/test_oauth.py index ce2d514ff..a637a5508 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,29 +1,126 @@ -from databricks.sdk.core import Config -from databricks.sdk.oauth import OAuthClient, OidcEndpoints, TokenCache - - -def test_token_cache_unique_filename_by_host(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(host="http://localhost:", **common_args) - c2 = OAuthClient(host="https://bar.cloud.databricks.com", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_client_id(mocker): - 
mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", redirect_url="http://localhost:8020") - c1 = OAuthClient(client_id="abc", **common_args) - c2 = OAuthClient(client_id="def", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_scopes(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(scopes=["foo"], **common_args) - c2 = OAuthClient(scopes=["bar"], **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename +from databricks.sdk._base_client import _BaseClient +from databricks.sdk.oauth import (OidcEndpoints, TokenCache, + get_account_endpoints, + get_workspace_endpoints) + +from .clock import FakeClock + + +def test_token_cache_unique_filename_by_host(): + common_args = dict(client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(host="http://localhost:", + **common_args).filename != TokenCache("https://bar.cloud.databricks.com", + **common_args).filename + + +def test_token_cache_unique_filename_by_client_id(): + common_args = dict(host="http://localhost:", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", + **common_args).filename + + +def test_token_cache_unique_filename_by_scopes(): + common_args = dict(host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(scopes=["foo"], 
**common_args).filename != TokenCache(scopes=["bar"], + **common_args).filename + + +def test_account_oidc_endpoints(requests_mock): + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_account_oidc_endpoints_retry_on_429(requests_mock): + # It doesn't seem possible to use requests_mock to return different responses for the same request, e.g. when + # simulating a transient failure. Instead, the nth_request matcher increments a test-wide counter and only matches + # the nth request. 
+ request_count = 0 + + def nth_request(n): + + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + + return observe_request + + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={ + "authorization_endpoint": + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_workspace_oidc_endpoints(requests_mock): + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + + +def test_workspace_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count 
+= 1 + return is_match + + return observe_request + + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") From d3b85cb867137657a875ceb18192e06456e39952 Mon Sep 17 00:00:00 2001 From: Omer Lachish <289488+rauchy@users.noreply.github.com> Date: Tue, 22 Oct 2024 15:33:14 +0200 Subject: [PATCH 2/2] [Release] Release v0.36.0 (#798) ### Breaking Changes * `external_browser` now uses the `databricks-cli` app instead of the third-party "6128a518-99a9-425b-8333-4cc94f04cacd" application when performing the U2M login flow for Azure workspaces when a client ID is not otherwise specified. This matches the AWS behavior. * The signatures of several OAuth-related constructors have changed to support U2M OAuth with Azure Entra ID application registrations. See https://github.com/databricks/databricks-sdk-py/blob/main/examples/flask_app_with_oauth.py for examples of how to use these classes. * `OAuthClient()`: renamed to `OAuthClient.from_host()` * `SessionCredentials()` and `SessionCredentials.from_dict()`: now accept `token_endpoint`, `client_id`, `client_secret`, and `redirect_url` as parameters, rather than accepting the `OAuthClient`.
* `TokenCache()`: now accepts `host`, `oidc_endpoints`, `client_id`, `client_secret`, and `redirect_url` as parameters, rather than accepting the `OAuthClient`. ### Bug Fixes * Decouple OAuth functionality from `Config` ([#784](https://github.com/databricks/databricks-sdk-py/pull/784)). ### Release * Release v0.35.0 ([#793](https://github.com/databricks/databricks-sdk-py/pull/793)). Co-authored-by: Omer Lachish --- CHANGELOG.md | 20 ++++++++++++++++++++ databricks/sdk/version.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 344e975d9..458921ee0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Version changelog +## [Release] Release v0.36.0 + +### Breaking Changes +* `external_browser` now uses the `databricks-cli` app instead of the third-party "6128a518-99a9-425b-8333-4cc94f04cacd" application when performing the U2M login flow for Azure workspaces when a client ID is not otherwise specified. This matches the AWS behavior. +* The signatures of several OAuth-related constructors have changed to support U2M OAuth with Azure Entra ID application registrations. See https://github.com/databricks/databricks-sdk-py/blob/main/examples/flask_app_with_oauth.py for examples of how to use these classes. + * `OAuthClient()`: renamed to `OAuthClient.from_host()` + * `SessionCredentials()` and `SessionCredentials.from_dict()`: now accept `token_endpoint`, `client_id`, `client_secret`, and `redirect_url` as parameters, rather than accepting the `OAuthClient`. + * `TokenCache()`: now accepts `host`, `oidc_endpoints`, `client_id`, `client_secret`, and `redirect_url` as parameters, rather than accepting the `OAuthClient`. + +### Bug Fixes + + * Decouple OAuth functionality from `Config` ([#784](https://github.com/databricks/databricks-sdk-py/pull/784)). + + +### Release + + * Release v0.35.0 ([#793](https://github.com/databricks/databricks-sdk-py/pull/793)).
+ + + ## [Release] Release v0.35.0 ### New Features and Improvements diff --git a/databricks/sdk/version.py b/databricks/sdk/version.py index 2670d0523..aae5aca67 100644 --- a/databricks/sdk/version.py +++ b/databricks/sdk/version.py @@ -1 +1 @@ -__version__ = '0.35.0' +__version__ = '0.36.0'