From dbd58fb4e15006908a9b6f152b1334adc73c89d6 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Wed, 28 Sep 2022 16:05:19 -0700
Subject: [PATCH] Follow "More flexible cluster configuration". (#194)

### Description

Follows "More flexible cluster configuration" at dbt-labs/dbt-spark#467.

- Reuse `dbt-spark`'s implementation
- Remove the dependency on `databricks-cli`
- Internal refactorings

Co-authored-by: allisonwang-db
---
 CHANGELOG.md                                  |  10 ++
 dbt/adapters/databricks/api_client.py         |  87 --------------
 dbt/adapters/databricks/connections.py        | 108 +++++++++---------
 dbt/adapters/databricks/impl.py               |  14 ++-
 dbt/adapters/databricks/python_submissions.py | 100 ++++++++--------
 mypy.ini                                      |   3 -
 requirements.txt                              |   1 -
 setup.py                                      |   1 -
 8 files changed, 119 insertions(+), 205 deletions(-)
 delete mode 100644 dbt/adapters/databricks/api_client.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90500399c..4ce93543e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,16 @@
 ### Features
 - Support python model through run command API, currently supported materializations are table and incremental. ([dbt-labs/dbt-spark#377](https://github.com/dbt-labs/dbt-spark/pull/377), [#126](https://github.com/databricks/dbt-databricks/pull/126))
 - Enable Pandas and Pandas-on-Spark DataFrames for dbt python models ([dbt-labs/dbt-spark#469](https://github.com/dbt-labs/dbt-spark/pull/469), [#181](https://github.com/databricks/dbt-databricks/pull/181))
+- Support job cluster in notebook submission method ([dbt-labs/dbt-spark#467](https://github.com/dbt-labs/dbt-spark/pull/467), [#194](https://github.com/databricks/dbt-databricks/pull/194))
+  - In `all_purpose_cluster` submission method, a config `http_path` can be specified in Python model config to switch the cluster where Python model runs.
+    ```py
+    def model(dbt, _):
+        dbt.config(
+            materialized='table',
+            http_path='...'
+        )
+        ...
+    ```
 - Use builtin timestampadd and timestampdiff functions for dateadd/datediff macros if available ([#185](https://github.com/databricks/dbt-databricks/pull/185))
 - Implement testing for a test for various Python models ([#189](https://github.com/databricks/dbt-databricks/pull/189))
 - Implement testing for `type_boolean` in Databricks ([dbt-labs/dbt-spark#471](https://github.com/dbt-labs/dbt-spark/pull/471), [#188](https://github.com/databricks/dbt-databricks/pull/188))
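The counterpart to the `http_path` example above is the new `job_cluster` submission method, whose helper (added below in `python_submissions.py`) requires a `job_cluster_config` in the model config. A minimal sketch, assuming the standard dbt `submission_method` model config selects the helper; the cluster-spec keys and values are illustrative and not dictated by this patch:

```py
def model(dbt, session):
    dbt.config(
        materialized='table',
        submission_method='job_cluster',
        # Hypothetical cluster spec; the accepted keys follow the Databricks
        # clusters API and are not defined by this change.
        job_cluster_config={
            'spark_version': '11.3.x-scala2.12',
            'node_type_id': 'i3.xlarge',
            'num_workers': 2,
        },
    )
    # return a DataFrame here as usual
    ...
```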
diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py
deleted file mode 100644
index 373a870ec..000000000
--- a/dbt/adapters/databricks/api_client.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from requests.exceptions import HTTPError
-from typing import Any, Dict
-
-from databricks_cli.sdk.api_client import ApiClient
-
-
-class Api12Client:
-    def __init__(self, host: str, token: str, command_name: str = ""):
-        self._api_client = ApiClient(
-            host=f"https://{host}",
-            token=token,
-            api_version="1.2",
-            command_name=command_name,
-        )
-        self.Context = Context(self._api_client)
-        self.Command = Command(self._api_client)
-
-    def close(self) -> None:
-        self._api_client.close()
-
-
-class Context:
-    def __init__(self, client: ApiClient):
-        self.client = client
-
-    def create(self, cluster_id: str) -> str:
-        # https://docs.databricks.com/dev-tools/api/1.2/index.html#create-an-execution-context
-        try:
-            response = self.client.perform_query(
-                method="POST",
-                path="/contexts/create",
-                data=dict(clusterId=cluster_id, language="python"),
-            )
-            return response["id"]
-        except HTTPError as e:
-            raise HTTPError(
-                f"Error creating an execution context\n {e.response.content!r}", response=e.response
-            ) from e
-
-    def destroy(self, cluster_id: str, context_id: str) -> str:
-        # https://docs.databricks.com/dev-tools/api/1.2/index.html#delete-an-execution-context
-        try:
-            response = self.client.perform_query(
-                method="POST",
-                path="/contexts/destroy",
-                data=dict(clusterId=cluster_id, contextId=context_id),
-            )
-            return response["id"]
-        except HTTPError as e:
-            raise HTTPError(
-                f"Error deleting an execution context\n {e.response.content!r}", response=e.response
-            ) from e
-
-
-class Command:
-    def __init__(self, client: ApiClient):
-        self.client = client
-
-    def execute(self, cluster_id: str, context_id: str, command: str) -> str:
-        # https://docs.databricks.com/dev-tools/api/1.2/index.html#run-a-command
-        try:
-            response = self.client.perform_query(
-                method="POST",
-                path="/commands/execute",
-                data=dict(
-                    clusterId=cluster_id,
-                    contextId=context_id,
-                    language="python",
-                    command=command,
-                ),
-            )
-            return response["id"]
-        except HTTPError as e:
-            raise HTTPError(
-                f"Error creating a command\n {e.response.content!r}", response=e.response
-            ) from e
-
-    def status(self, cluster_id: str, context_id: str, command_id: str) -> Dict[str, Any]:
-        # https://docs.databricks.com/dev-tools/api/1.2/index.html#get-information-about-a-command
-        try:
-            return self.client.perform_query(
-                method="GET",
-                path="/commands/status",
-                data=dict(clusterId=cluster_id, contextId=context_id, commandId=command_id),
-            )
-        except HTTPError as e:
-            return e.response.json()
diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 0b40cd05e..8dede1cf2 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -129,6 +129,50 @@ def __post_init__(self) -> None:
             )
         self.connection_parameters = connection_parameters
 
+    def validate_creds(self) -> None:
+        for key in ["host", "http_path", "token"]:
+            if not getattr(self, key):
+                raise dbt.exceptions.DbtProfileError(
+                    "The config '{}' is required to connect to Databricks".format(key)
+                )
+
+    @classmethod
+    def get_invocation_env(cls) -> Optional[str]:
+        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
+        if invocation_env:
+            # Thrift doesn't allow nested () so we need to ensure
+            # that the passed user agent is valid.
+            if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
+                raise dbt.exceptions.ValidationException(
+                    f"Invalid invocation environment: {invocation_env}"
+                )
+        return invocation_env
+
+    @classmethod
+    def get_all_http_headers(cls, user_http_session_headers: Dict[str, str]) -> Dict[str, str]:
+        http_session_headers_str: Optional[str] = os.environ.get(
+            DBT_DATABRICKS_HTTP_SESSION_HEADERS
+        )
+
+        http_session_headers_dict: Dict[str, str] = (
+            {k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
+            if http_session_headers_str is not None
+            else {}
+        )
+
+        intersect_http_header_keys = (
+            user_http_session_headers.keys() & http_session_headers_dict.keys()
+        )
+
+        if len(intersect_http_header_keys) > 0:
+            raise dbt.exceptions.ValidationException(
+                f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
+            )
+
+        http_session_headers_dict.update(user_http_session_headers)
+
+        return http_session_headers_dict
+
     @property
     def type(self) -> str:
         return "databricks"
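`get_all_http_headers` now lives on the credentials and returns a plain dict, so both the connection manager and the Python submission helpers can reuse it. A standalone sketch of the same merge rules for reference; the function name is illustrative, and it assumes the `DBT_DATABRICKS_HTTP_SESSION_HEADERS` environment variable carries a JSON object, as in the code above:

```py
import json
import os

def merge_http_headers(user_headers: dict) -> dict:
    """Sketch of DatabricksCredentials.get_all_http_headers: headers from the
    environment are JSON-encoded per value and must not collide with the
    profile-supplied http_headers."""
    raw = os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
    env_headers = (
        {k: json.dumps(v) for k, v in json.loads(raw).items()} if raw is not None else {}
    )
    overlap = user_headers.keys() & env_headers.keys()
    if overlap:
        raise ValueError(f"Intersection with reserved http_headers in keys: {overlap}")
    return {**env_headers, **user_headers}

# e.g. with DBT_DATABRICKS_HTTP_SESSION_HEADERS='{"X-Databricks-Org-Id": 1234}',
# merge_http_headers({"X-Custom-Header": "abc"})
# -> {"X-Databricks-Org-Id": "1234", "X-Custom-Header": "abc"}
```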
@@ -165,14 +209,18 @@ def _connection_keys(self, *, with_aliases: bool = False) -> Tuple[str, ...]:
             connection_keys.append("session_properties")
         return tuple(connection_keys)
 
-    @property
-    def cluster_id(self) -> Optional[str]:
-        m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(self.http_path)  # type: ignore[arg-type]
+    @classmethod
+    def extract_cluster_id(cls, http_path: str) -> Optional[str]:
+        m = EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX.match(http_path)
         if m:
             return m.group(1).strip()
         else:
             return None
 
+    @property
+    def cluster_id(self) -> Optional[str]:
+        return self.extract_cluster_id(self.http_path)  # type: ignore[arg-type]
+
 
 class DatabricksSQLConnectionWrapper:
     """Wrap a Databricks SQL connector in a way that no-ops transactions"""
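Splitting `extract_cluster_id` out as a classmethod lets the Python submission helpers (below) resolve a cluster id from a per-model `http_path` as well as from the profile. A sketch of the expected behaviour; the regex itself is defined elsewhere in `connections.py`, so the path shapes here are assumptions based on the usual Databricks conventions:

```py
from dbt.adapters.databricks.connections import DatabricksCredentials

# An all-purpose cluster path is expected to look like
# "sql/protocolv1/o/<workspace-id>/<cluster-id>".
print(DatabricksCredentials.extract_cluster_id("sql/protocolv1/o/1234567890123456/0123-456789-abcdefgh"))
# -> "0123-456789-abcdefgh" (expected)

# A SQL warehouse/endpoint path has no cluster id, so None is expected.
print(DatabricksCredentials.extract_cluster_id("/sql/1.0/endpoints/1234567890abcdef"))
# -> None (expected)
```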
@@ -437,49 +485,6 @@ def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
             lambda cursor: cursor.schemas(catalog_name=database, schema_name=schema),
         )
 
-    @classmethod
-    def validate_creds(cls, creds: DatabricksCredentials, required: List[str]) -> None:
-        for key in required:
-            if not getattr(creds, key):
-                raise dbt.exceptions.DbtProfileError(
-                    "The config '{}' is required to connect to Databricks".format(key)
-                )
-
-    @classmethod
-    def validate_invocation_env(cls, invocation_env: str) -> None:
-        # Thrift doesn't allow nested () so we need to ensure that the passed user agent is valid
-        if not DBT_DATABRICKS_INVOCATION_ENV_REGEX.search(invocation_env):
-            raise dbt.exceptions.ValidationException(
-                f"Invalid invocation environment: {invocation_env}"
-            )
-
-    @classmethod
-    def get_all_http_headers(
-        cls, user_http_session_headers: Dict[str, str]
-    ) -> List[Tuple[str, str]]:
-        http_session_headers_str: Optional[str] = os.environ.get(
-            DBT_DATABRICKS_HTTP_SESSION_HEADERS
-        )
-
-        http_session_headers_dict: Dict[str, str] = (
-            {k: json.dumps(v) for k, v in json.loads(http_session_headers_str).items()}
-            if http_session_headers_str is not None
-            else {}
-        )
-
-        intersect_http_header_keys = (
-            user_http_session_headers.keys() & http_session_headers_dict.keys()
-        )
-
-        if len(intersect_http_header_keys) > 0:
-            raise dbt.exceptions.ValidationException(
-                f"Intersection with reserved http_headers in keys: {intersect_http_header_keys}"
-            )
-
-        http_session_headers_dict.update(user_http_session_headers)
-
-        return list(http_session_headers_dict.items())
-
     @classmethod
     def open(cls, connection: Connection) -> Connection:
         if connection.state == ConnectionState.OPEN:
@@ -487,19 +492,18 @@ def open(cls, connection: Connection) -> Connection:
             return connection
 
         creds: DatabricksCredentials = connection.credentials
-        cls.validate_creds(creds, ["host", "http_path", "token"])
+        creds.validate_creds()
 
         user_agent_entry = f"dbt-databricks/{__version__}"
 
-        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
-        if invocation_env is not None and len(invocation_env) > 0:
-            cls.validate_invocation_env(invocation_env)
+        invocation_env = creds.get_invocation_env()
+        if invocation_env:
             user_agent_entry = f"{user_agent_entry}; {invocation_env}"
 
         connection_parameters = creds.connection_parameters.copy()  # type: ignore[union-attr]
 
-        http_headers: List[Tuple[str, str]] = cls.get_all_http_headers(
-            connection_parameters.pop("http_headers", {})
+        http_headers: List[Tuple[str, str]] = list(
+            creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items()
         )
 
         exc: Optional[Exception] = None
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 5d7f4d04c..5ff1d7dcf 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -26,7 +26,10 @@
 from dbt.adapters.databricks.column import DatabricksColumn
 from dbt.adapters.databricks.connections import DatabricksConnectionManager
-from dbt.adapters.databricks.python_submissions import CommandApiPythonJobHelper
+from dbt.adapters.databricks.python_submissions import (
+    DbtDatabricksAllPurposeClusterPythonJobHelper,
+    DbtDatabricksJobClusterPythonJobHelper,
+)
 from dbt.adapters.databricks.relation import DatabricksRelation
 from dbt.adapters.databricks.utils import undefined_proof
 
@@ -264,13 +267,12 @@ def run_sql_for_tests(
     def valid_incremental_strategies(self) -> List[str]:
         return ["append", "merge", "insert_overwrite"]
 
-    @property
-    def default_python_submission_method(self) -> str:
-        return "commands"
-
     @property
     def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]:
-        return {"commands": CommandApiPythonJobHelper}
+        return {
+            "job_cluster": DbtDatabricksJobClusterPythonJobHelper,
+            "all_purpose_cluster": DbtDatabricksAllPurposeClusterPythonJobHelper,
+        }
 
     @contextmanager
     def _catalog(self, catalog: Optional[str]) -> Iterator[None]:
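With the adapter now delegating to dbt-spark's helpers, the old `commands` method is gone and a model picks between `job_cluster` and `all_purpose_cluster`. A hedged sketch of the `all_purpose_cluster` side pinning an explicit interactive cluster (the id is illustrative; per the `cluster_id` property added in `python_submissions.py` below, an explicit `cluster_id` takes precedence over an `http_path` override):

```py
def model(dbt, session):
    dbt.config(
        materialized='incremental',
        submission_method='all_purpose_cluster',
        # Alternatively set http_path; an explicit cluster_id wins when both are given.
        cluster_id='0123-456789-abcdefgh',
    )
    ...
```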
diff --git a/dbt/adapters/databricks/python_submissions.py b/dbt/adapters/databricks/python_submissions.py
index 94ecf1fff..1365298ff 100644
--- a/dbt/adapters/databricks/python_submissions.py
+++ b/dbt/adapters/databricks/python_submissions.py
@@ -1,18 +1,16 @@
-import os
-from typing import Dict, cast
-import uuid
+from typing import Dict, Optional
 
-from requests import HTTPError
-
-import dbt.exceptions
-from dbt.adapters.spark.python_submissions import BaseDatabricksHelper
+from dbt.adapters.spark.python_submissions import (
+    AllPurposeClusterPythonJobHelper,
+    BaseDatabricksHelper,
+    JobClusterPythonJobHelper,
+)
 
 from dbt.adapters.databricks.__version__ import version
-from dbt.adapters.databricks.api_client import Api12Client
-from dbt.adapters.databricks.connections import DatabricksCredentials, DBT_DATABRICKS_INVOCATION_ENV
+from dbt.adapters.databricks.connections import DatabricksCredentials
 
 
-class CommandApiPythonJobHelper(BaseDatabricksHelper):
+class DbtDatabricksBasePythonJobHelper(BaseDatabricksHelper):
     credentials: DatabricksCredentials  # type: ignore[assignment]
 
     def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None:
@@ -20,56 +18,48 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No
             parsed_model=parsed_model, credentials=credentials  # type: ignore[arg-type]
         )
 
-        command_name = f"dbt-databricks_{version}"
-        invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV)
-        if invocation_env is not None and len(invocation_env) > 0:
-            command_name = f"{command_name}-{invocation_env}"
-        command_name += "-" + str(uuid.uuid1())
+        user_agent = f"dbt-databricks/{version}"
+
+        invocation_env = credentials.get_invocation_env()
+        if invocation_env:
+            user_agent = f"{user_agent} ({invocation_env})"
+
+        connection_parameters = credentials.connection_parameters.copy()  # type: ignore[union-attr]
 
-        self.api_client = Api12Client(
-            host=cast(str, credentials.host),
-            token=cast(str, credentials.token),
-            command_name=command_name,
+        http_headers: Dict[str, str] = credentials.get_all_http_headers(
+            connection_parameters.pop("http_headers", {})
         )
 
-    def check_credentials(self) -> None:
-        if not self.cluster_id:
-            raise ValueError("Databricks cluster is required for commands submission method.")
+        self.auth_header.update({"User-Agent": user_agent, **http_headers})
 
-    def submit(self, compiled_code: str) -> None:
-        cluster_id = self.cluster_id
+    @property
+    def cluster_id(self) -> Optional[str]:  # type: ignore[override]
+        return self.parsed_model["config"].get(
+            "cluster_id",
+            self.credentials.extract_cluster_id(
+                self.parsed_model["config"].get("http_path", self.credentials.http_path)
+            ),
+        )
 
-        try:
-            # Create an execution context
-            context_id = self.api_client.Context.create(cluster_id)
-            try:
-                # Run a command
-                command_id = self.api_client.Command.execute(
-                    cluster_id=cluster_id,
-                    context_id=context_id,
-                    command=compiled_code,
-                )
 
+class DbtDatabricksJobClusterPythonJobHelper(
+    DbtDatabricksBasePythonJobHelper, JobClusterPythonJobHelper
+):
+    def check_credentials(self) -> None:
+        self.credentials.validate_creds()
+        if not self.parsed_model["config"].get("job_cluster_config", None):
+            raise ValueError(
+                "`job_cluster_config` is required for the `job_cluster` submission method."
+            )
 
-                # poll until job finish
-                response = self.polling(
-                    status_func=self.api_client.Command.status,
-                    status_func_kwargs=dict(
-                        cluster_id=cluster_id, context_id=context_id, command_id=command_id
-                    ),
-                    get_state_func=lambda response: response["status"],
-                    terminal_states=("Cancelled", "Error", "Finished"),
-                    expected_end_state="Finished",
-                    get_state_msg_func=lambda response: response.json()["results"]["data"],
-                )
-                if response["results"]["resultType"] == "error":
-                    raise dbt.exceptions.RuntimeException(
-                        f"Python model failed with traceback as:\n"
-                        f"{response['results']['cause']}"
-                    )
-            finally:
-                # Delete the execution context
-                self.api_client.Context.destroy(cluster_id=cluster_id, context_id=context_id)
-        except HTTPError as e:
-            raise dbt.exceptions.RuntimeException(str(e)) from e
+
+class DbtDatabricksAllPurposeClusterPythonJobHelper(
+    DbtDatabricksBasePythonJobHelper, AllPurposeClusterPythonJobHelper
+):
+    def check_credentials(self) -> None:
+        self.credentials.validate_creds()
+        if not self.cluster_id:
+            raise ValueError(
+                "Databricks `http_path` or `cluster_id` of an all-purpose cluster is required "
+                "for the `all_purpose_cluster` submission method."
+            )
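The `cluster_id` property above is what makes per-model cluster switching work: an explicit `cluster_id` config wins, then a cluster id extracted from a per-model `http_path`, then the profile's `http_path`. A small standalone sketch of that resolution; `resolve_cluster_id` is illustrative and the nested `extract` only approximates the real regex in `connections.py`:

```py
from typing import Dict, Optional

def resolve_cluster_id(config: Dict[str, str], profile_http_path: str) -> Optional[str]:
    # Stand-in for DatabricksCredentials.extract_cluster_id; assumes the
    # "sql/protocolv1/o/<workspace-id>/<cluster-id>" path shape.
    def extract(path: str) -> Optional[str]:
        return path.rstrip("/").rsplit("/", 1)[-1] if "/protocolv1/" in path else None

    # Mirrors DbtDatabricksBasePythonJobHelper.cluster_id: explicit cluster_id
    # beats a per-model http_path, which beats the profile-level http_path.
    return config.get("cluster_id", extract(config.get("http_path", profile_http_path)))

profile_path = "sql/protocolv1/o/0/1111-111111-profile1"
assert resolve_cluster_id({"cluster_id": "2222-222222-explicit"}, profile_path) == "2222-222222-explicit"
assert resolve_cluster_id({"http_path": "sql/protocolv1/o/0/3333-333333-permodel"}, profile_path) == "3333-333333-permodel"
assert resolve_cluster_id({}, profile_path) == "1111-111111-profile1"
```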
diff --git a/mypy.ini b/mypy.ini
index dde351666..9e4f4e493 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -10,9 +10,6 @@ disallow_untyped_defs = False
 [mypy-databricks.*]
 ignore_missing_imports = True
 
-[mypy-databricks_cli.*]
-ignore_missing_imports = True
-
 [mypy-agate.*]
 ignore_missing_imports = True
 
diff --git a/requirements.txt b/requirements.txt
index 559b397fc..69633989a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,2 @@
 databricks-sql-connector>=2.0.5
-databricks-cli>=0.17.0
 dbt-spark~=1.3.0b2
diff --git a/setup.py b/setup.py
index 7b2ca9356..ba790aae8 100644
--- a/setup.py
+++ b/setup.py
@@ -57,7 +57,6 @@ def _get_plugin_version():
     install_requires=[
         "dbt-spark~={}".format(dbt_spark_version),
         "databricks-sql-connector>=2.0.5",
-        "databricks-cli>=0.17.0",
     ],
     zip_safe=False,
     classifiers=[