Skip to content

Commit

Permalink
Added w.config.account_host to get the relevant account host from a…
Browse files Browse the repository at this point in the history
… workspace client

This simplifies downstream integration testing as well as configuration

introduce `DatabricksEnvironment`

..
  • Loading branch information
nfx committed Nov 9, 2023
1 parent a862adc commit cb830fc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 13 deletions.
84 changes: 71 additions & 13 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import subprocess
import sys
import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from json import JSONDecodeError
from types import TracebackType
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Expand Down Expand Up @@ -577,6 +579,56 @@ def __repr__(self) -> str:
return f"<ConfigAttribute '{self.name}' {self.transform.__name__}>"


class Cloud(Enum):
AWS = 'aws'
AZURE = 'azure'
GCP = 'gcp'


@dataclass
class DatabricksEnvironment:
cloud: Cloud

# domain suffixes are not very secret: https://crt.sh/?q=databricks
tld: str

azure_environment: AzureEnvironment = None

# The application (client) ID isn't a secret. See
# https://learn.microsoft.com/en-us/entra/identity-platform/developer-glossary#application-client-id
azure_application_id: str = None

def deployment(self, deployment_name: str) -> str:
return f'https://{deployment_name}{self.tld}'


_DATABRICKS_ENVIRONMENTS = [
DatabricksEnvironment(Cloud.AWS, '.dev.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.staging.cloud.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.cloud.databricks.com'),
DatabricksEnvironment(Cloud.AWS, '.cloud.databricks.us'),
DatabricksEnvironment(Cloud.AZURE,
'.dev.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id='62a912ac-b58e-4c1d-89ea-b2dbfc7358fc'),
DatabricksEnvironment(Cloud.AZURE,
'.staging.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id='4a67d088-db5c-48f1-9ff2-0aace800ae68'),
DatabricksEnvironment(Cloud.AZURE,
'.azuredatabricks.net',
azure_environment=ENVIRONMENTS['PUBLIC'],
azure_application_id=ARM_DATABRICKS_RESOURCE_ID),
DatabricksEnvironment(Cloud.AZURE,
'.databricks.azure.us',
azure_environment=ENVIRONMENTS['USGOVERNMENT'],
azure_application_id=ARM_DATABRICKS_RESOURCE_ID),
DatabricksEnvironment(Cloud.GCP, '.dev.gcp.databricks.com'),
DatabricksEnvironment(Cloud.GCP, '.staging.gcp.databricks.com'),
DatabricksEnvironment(Cloud.GCP, '.gcp.databricks.com'),
]


class Config:
host = ConfigAttribute(env='DATABRICKS_HOST')
account_id = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID')
Expand Down Expand Up @@ -667,16 +719,12 @@ def as_dict(self) -> dict:
@property
def is_azure(self) -> bool:
has_resource_id = self.azure_workspace_resource_id is not None
has_host = self.host is not None
is_public_cloud = has_host and ".azuredatabricks.net" in self.host
is_china_cloud = has_host and ".databricks.azure.cn" in self.host
is_gov_cloud = has_host and ".databricks.azure.us" in self.host
is_valid_cloud = is_public_cloud or is_china_cloud or is_gov_cloud
return has_resource_id or (has_host and is_valid_cloud)
azure_environment = self.environment.azure_environment is not None
return has_resource_id or azure_environment is not None

@property
def is_gcp(self) -> bool:
return self.host and ".gcp.databricks.com" in self.host
return self.host and self.environment.cloud == Cloud.GCP

@property
def is_aws(self) -> bool:
Expand All @@ -688,20 +736,30 @@ def is_account_client(self) -> bool:
return False
return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")

@property
def account_host(self) -> str:
if self.is_account_client:
return self.host
return self.environment.deployment('accounts')

@property
def arm_environment(self) -> AzureEnvironment:
env = self.azure_environment if self.azure_environment else "PUBLIC"
try:
return ENVIRONMENTS[env]
except KeyError:
raise ValueError(f"Cannot find Azure {env} Environment")
return self.environment.azure_environment

@property
def effective_azure_login_app_id(self):
app_id = self.azure_login_app_id
if app_id:
return app_id
return ARM_DATABRICKS_RESOURCE_ID
return self.environment.azure_application_id

@property
def environment(self) -> DatabricksEnvironment:
hostname = self.hostname
for env in _DATABRICKS_ENVIRONMENTS:
if hostname.endswith(env.tld):
return env
raise ValueError(f"Cannot find DatabricksEnvironment for {hostname}")

@property
def hostname(self) -> str:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,18 @@ def inner(h: BaseHTTPRequestHandler):
assert len(requests) == 2


@pytest.mark.parametrize(
"host,account_host",
[('https://accounts.cloud.databricks.com', 'https://accounts.cloud.databricks.com'),
('https://dbc-ldflSlsd.cloud.databricks.com', 'https://accounts.cloud.databricks.com'),
('https://abd-23424234234.12.azuredatabricks.net', 'https://accounts.azuredatabricks.net'),
('https://abd-23424234234.12.databricks.azure.us', 'https://accounts.databricks.azure.us'),
('https://23423423.gcp.databricks.com', 'https://accounts.gcp.databricks.com'), ])
def test_get_account_host(host, account_host):
cfg = Config(host=host, token=...)
assert account_host == cfg.account_host


def test_github_oidc_flow_works_with_azure(monkeypatch):

def inner(h: BaseHTTPRequestHandler):
Expand Down

0 comments on commit cb830fc

Please sign in to comment.