Skip to content

Commit

Permalink
Revert "Revert "Create a method to generate OAuth tokens (#644)" (#653)…
Browse files Browse the repository at this point in the history
…" (#655)

This reverts commit 9c5fae7.

## Changes
Introduce again the method to generate OAuth tokens, including fixes for
impacted auth types

## Tests

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied
  • Loading branch information
hectorcast-db authored May 24, 2024
1 parent 3a9b82a commit 2bafb95
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 79 deletions.
10 changes: 7 additions & 3 deletions .codegen/__init__.py.tmpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import databricks.sdk.core as client
import databricks.sdk.dbutils as dbutils
from databricks.sdk.credentials_provider import CredentialsProvider
from databricks.sdk.credentials_provider import CredentialsStrategy

from databricks.sdk.mixins.files import DbfsExt
from databricks.sdk.mixins.compute import ClustersExt
Expand Down Expand Up @@ -46,10 +46,12 @@ class WorkspaceClient:
debug_headers: bool = None,
product="unknown",
product_version="0.0.0",
credentials_provider: CredentialsProvider = None,
credentials_strategy: CredentialsStrategy = None,
credentials_provider: CredentialsStrategy = None,
config: client.Config = None):
if not config:
config = client.Config({{range $args}}{{.}}={{.}}, {{end}}
credentials_strategy=credentials_strategy,
credentials_provider=credentials_provider,
debug_truncate_bytes=debug_truncate_bytes,
debug_headers=debug_headers,
Expand Down Expand Up @@ -101,10 +103,12 @@ class AccountClient:
debug_headers: bool = None,
product="unknown",
product_version="0.0.0",
credentials_provider: CredentialsProvider = None,
credentials_strategy: CredentialsStrategy = None,
credentials_provider: CredentialsStrategy = None,
config: client.Config = None):
if not config:
config = client.Config({{range $args}}{{.}}={{.}}, {{end}}
credentials_strategy=credentials_strategy,
credentials_provider=credentials_provider,
debug_truncate_bytes=debug_truncate_bytes,
debug_headers=debug_headers,
Expand Down
10 changes: 7 additions & 3 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 20 additions & 7 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import requests

from .clock import Clock, RealClock
from .credentials_provider import CredentialsProvider, DefaultCredentials
from .credentials_provider import CredentialsStrategy, DefaultCredentials
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
DatabricksEnvironment, get_environment_for_hostname)
from .oauth import OidcEndpoints
from .oauth import OidcEndpoints, Token
from .version import __version__

logger = logging.getLogger('databricks.sdk')
Expand Down Expand Up @@ -81,15 +81,25 @@ class Config:

def __init__(self,
*,
credentials_provider: CredentialsProvider = None,
# Deprecated. Use credentials_strategy instead.
credentials_provider: CredentialsStrategy = None,
credentials_strategy: CredentialsStrategy = None,
product="unknown",
product_version="0.0.0",
clock: Clock = None,
**kwargs):
self._header_factory = None
self._inner = {}
self._user_agent_other_info = []
self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials()
if credentials_strategy and credentials_provider:
raise ValueError(
"When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
if credentials_provider:
logger.warning(
"parameter 'credentials_provider' is deprecated. Use 'credentials_strategy' instead.")
self._credentials_strategy = next(
s for s in [credentials_strategy, credentials_provider,
DefaultCredentials()] if s is not None)
if 'databricks_environment' in kwargs:
self.databricks_environment = kwargs['databricks_environment']
del kwargs['databricks_environment']
Expand All @@ -107,6 +117,9 @@ def __init__(self,
message = self.wrap_debug_info(str(e))
raise ValueError(message) from e

def oauth_token(self) -> Token:
return self._credentials_strategy.oauth_token(self)

def wrap_debug_info(self, message: str) -> str:
debug_string = self.debug_string()
if debug_string:
Expand Down Expand Up @@ -436,12 +449,12 @@ def _validate(self):

def init_auth(self):
try:
self._header_factory = self._credentials_provider(self)
self.auth_type = self._credentials_provider.auth_type()
self._header_factory = self._credentials_strategy(self)
self.auth_type = self._credentials_strategy.auth_type()
if not self._header_factory:
raise ValueError('not configured')
except ValueError as e:
raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e
raise ValueError(f'{self._credentials_strategy.auth_type()} auth: {e}') from e

def __repr__(self):
return f'<{self.debug_string()}>'
Expand Down
22 changes: 22 additions & 0 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from json import JSONDecodeError
from types import TracebackType
from typing import Any, BinaryIO, Iterator, Type
from urllib.parse import urlencode

from requests.adapters import HTTPAdapter

Expand All @@ -13,12 +14,17 @@
from .credentials_provider import *
from .errors import DatabricksError, error_mapper
from .errors.private_link import _is_private_link_redirect
from .oauth import retrieve_token
from .retries import retried

__all__ = ['Config', 'DatabricksError']

logger = logging.getLogger('databricks.sdk')

URL_ENCODED_CONTENT_TYPE = "application/x-www-form-urlencoded"
JWT_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
OIDC_TOKEN_PATH = "/oidc/v1/token"


class ApiClient:
_cfg: Config
Expand Down Expand Up @@ -109,6 +115,22 @@ def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
flattened = dict(flatten_dict(with_fixed_bools))
return flattened

def get_oauth_token(self, auth_details: str) -> Token:
if not self._cfg.auth_type:
self._cfg.authenticate()
original_token = self._cfg.oauth_token()
headers = {"Content-Type": URL_ENCODED_CONTENT_TYPE}
params = urlencode({
"grant_type": JWT_BEARER_GRANT_TYPE,
"authorization_details": auth_details,
"assertion": original_token.access_token
})
return retrieve_token(client_id=self._cfg.client_id,
client_secret=self._cfg.client_secret,
token_url=self._cfg.host + OIDC_TOKEN_PATH,
params=params,
headers=headers)

def do(self,
method: str,
path: str,
Expand Down
Loading

0 comments on commit 2bafb95

Please sign in to comment.