Skip to content

Commit

Permalink
[Fix] Decouple OAuth functionality from Config (#784)
Browse files Browse the repository at this point in the history
## Changes
### OAuth Refactoring
Currently, OAuthClient uses Config internally to resolve the OIDC
endpoints by passing the client ID and host to an internal Config
instance and calling its `oidc_endpoints` method. This has a few
drawbacks:
1. There is nearly a cyclical dependency: `Config` depends on methods in
`oauth.py`, and `OAuthClient` depends on `Config`. This currently
doesn't break because the `Config` import is done at runtime in the
`OAuthClient` constructor.
2. Databricks supports both in-house OAuth and Azure Entra ID OAuth.
Currently, the choice between these options depends on whether a user
specifies the azure_client_id or client_id parameter in the Config.
Because Config is used within OAuthClient, this means that OAuthClient
needs to expose a parameter to configure either client_id or
azure_client_id.

Rather than having these classes deeply coupled to one another, we can
allow users to fetch the OIDC endpoints for a given account/workspace as
a top-level functionality and provide this to `OAuthClient`. This breaks
the cyclic dependency and doesn't require `OAuthClient` to expose any
unnecessary parameters.

Further, I've also tried to remove the coupling of the other classes in
`oauth.py` to `OAuthClient`. Currently, `OAuthClient` serves both as the
mechanism to initialize OAuth and as a kind of configuration object,
capturing OAuth endpoint URLs, client ID/secret, redirect URL, and
scopes. Now, the parameters for each of these classes are explicit,
removing all unnecessary coupling between them. One nice advantage is
that the Consent can be serialized/deserialized without any reference to
the `OAuthClient` anymore.

There is definitely more work to be done to simplify and clean up the
OAuth implementation, but this should at least unblock users who need to
use Azure Entra ID U2M OAuth in the SDK.

## Tests
The new OIDC endpoint methods are tested, and those tests also verify
that requests to those endpoints are retried in case of rate limiting.

I ran the flask app example against an AWS workspace, and I ran the
external-browser demo example against AWS, Azure and GCP workspaces with
the default client ID and with a newly created OAuth app with and
without credentials.

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
mgyucht authored Oct 21, 2024
1 parent 15257eb commit 32ba221
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 157 deletions.
20 changes: 20 additions & 0 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import urllib.parse
from datetime import timedelta
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand All @@ -17,6 +18,25 @@
logger = logging.getLogger('databricks.sdk')


def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
if not host:
return host

# Add a default scheme if it's missing
if '://' not in host:
host = 'https://' + host

o = urllib.parse.urlparse(host)
# remove trailing slash
path = o.path.rstrip('/')
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]

return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))


class _BaseClient:

def __init__(self,
Expand Down
44 changes: 10 additions & 34 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import requests

from . import useragent
from ._base_client import _fix_host_if_needed
from .clock import Clock, RealClock
from .credentials_provider import CredentialsStrategy, DefaultCredentials
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import OidcEndpoints, Token
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
get_azure_entra_id_workspace_endpoints,
get_workspace_endpoints)

logger = logging.getLogger('databricks.sdk')

Expand Down Expand Up @@ -254,24 +257,10 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
if not self.host:
return None
if self.is_azure and self.azure_client_id:
# Retrieve authorize endpoint to retrieve token endpoint after
res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
real_auth_url = res.headers.get('location')
if not real_auth_url:
return None
return OidcEndpoints(authorization_endpoint=real_auth_url,
token_endpoint=real_auth_url.replace('/authorize', '/token'))
return get_azure_entra_id_workspace_endpoints(self.host)
if self.is_account_client and self.account_id:
prefix = f'{self.host}/oidc/accounts/{self.account_id}'
return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize',
token_endpoint=f'{prefix}/v1/token')
oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server'
res = requests.get(oidc)
if res.status_code != 200:
return None
auth_metadata = res.json()
return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'),
token_endpoint=auth_metadata.get('token_endpoint'))
return get_account_endpoints(self.host, self.account_id)
return get_workspace_endpoints(self.host)

def debug_string(self) -> str:
""" Returns log-friendly representation of configured attributes """
Expand Down Expand Up @@ -346,22 +335,9 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
return cls._attributes

def _fix_host_if_needed(self):
if not self.host:
return

# Add a default scheme if it's missing
if '://' not in self.host:
self.host = 'https://' + self.host

o = urllib.parse.urlparse(self.host)
# remove trailing slash
path = o.path.rstrip('/')
# remove port if 443
netloc = o.netloc
if o.port == 443:
netloc = netloc.split(':')[0]

self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
updated_host = _fix_host_if_needed(self.host)
if updated_host:
self.host = updated_host

def load_azure_tenant_id(self):
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.
Expand Down
31 changes: 18 additions & 13 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,30 +187,35 @@ def token() -> Token:
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
if cfg.auth_type != 'external-browser':
return None
client_id, client_secret = None, None
if cfg.client_id:
client_id = cfg.client_id
elif cfg.is_aws:
client_secret = cfg.client_secret
elif cfg.azure_client_id:
client_id = cfg.azure_client
client_secret = cfg.azure_client_secret

if not client_id:
client_id = 'databricks-cli'
elif cfg.is_azure:
# Use Azure AD app for cases when Azure CLI is not available on the machine.
# App has to be registered as Single-page multi-tenant to support PKCE
# TODO: temporary app ID, change it later.
client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
else:
raise ValueError(f'local browser SSO is not supported')
oauth_client = OAuthClient(host=cfg.host,
client_id=client_id,
redirect_url='http://localhost:8020',
client_secret=cfg.client_secret)

# Load cached credentials from disk if they exist.
# Note that these are local to the Python SDK and not reused by other SDKs.
token_cache = TokenCache(oauth_client)
oidc_endpoints = cfg.oidc_endpoints
redirect_url = 'http://localhost:8020'
token_cache = TokenCache(host=cfg.host,
oidc_endpoints=oidc_endpoints,
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url)
credentials = token_cache.load()
if credentials:
# Force a refresh in case the loaded credentials are expired.
credentials.token()
else:
oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
client_id=client_id,
redirect_url=redirect_url,
client_secret=client_secret)
consent = oauth_client.initiate_consent()
if not consent:
return None
Expand Down
Loading

0 comments on commit 32ba221

Please sign in to comment.