diff --git a/mlflow/tracking/request_header/databricks_request_header_provider.py b/mlflow/tracking/request_header/databricks_request_header_provider.py index 4607c64d0cc6d..420e4aeca3be5 100644 --- a/mlflow/tracking/request_header/databricks_request_header_provider.py +++ b/mlflow/tracking/request_header/databricks_request_header_provider.py @@ -28,5 +28,11 @@ def request_headers(self): command_run_id = databricks_utils.get_command_run_id() if command_run_id is not None: request_headers["command_run_id"] = command_run_id + workload_id = databricks_utils.get_workload_id() + workload_class = databricks_utils.get_workload_class() + if workload_id is not None: + request_headers["workload_id"] = workload_id + if workload_class is not None: + request_headers["workload_class"] = workload_class return request_headers diff --git a/mlflow/utils/databricks_utils.py b/mlflow/utils/databricks_utils.py index 8f1985612dd91..fd5f0263aff4e 100644 --- a/mlflow/utils/databricks_utils.py +++ b/mlflow/utils/databricks_utils.py @@ -402,6 +402,22 @@ def get_command_run_id(): return None +@_use_repl_context_if_available("workloadId") +def get_workload_id(): + try: + return _get_command_context().workloadId().get() + except Exception: + return _get_context_tag("workloadId") + + +@_use_repl_context_if_available("workloadClass") +def get_workload_class(): + try: + return _get_command_context().workloadClass().get() + except Exception: + return _get_context_tag("workloadClass") + + @_use_repl_context_if_available("apiUrl") def get_webapp_url(): """Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true""" diff --git a/tests/tracking/request_header/test_databricks_request_header_provider.py b/tests/tracking/request_header/test_databricks_request_header_provider.py index c7018cb9673b5..2c98afd7d7e4f 100644 --- a/tests/tracking/request_header/test_databricks_request_header_provider.py +++ b/tests/tracking/request_header/test_databricks_request_header_provider.py @@ -55,7 +55,11 @@ def test_databricks_request_header_provider_request_headers( "mlflow.utils.databricks_utils.get_cluster_id" ) as cluster_id_mock, mock.patch( "mlflow.utils.databricks_utils.get_command_run_id" - ) as command_run_id_mock: + ) as command_run_id_mock, mock.patch( + "mlflow.utils.databricks_utils.get_workload_id" + ) as workload_id_mock, mock.patch( + "mlflow.utils.databricks_utils.get_workload_class" + ) as workload_class_mock: request_headers = DatabricksRequestHeaderProvider().request_headers() if is_in_databricks_notebook: @@ -81,3 +85,13 @@ def test_databricks_request_header_provider_request_headers( assert request_headers["command_run_id"] == command_run_id_mock.return_value else: assert "command_run_id" not in request_headers + + if workload_id_mock.return_value is not None: + assert request_headers["workload_id"] == workload_id_mock.return_value + else: + assert "workload_id" not in request_headers + + if workload_class_mock.return_value is not None: + assert request_headers["workload_class"] == workload_class_mock.return_value + else: + assert "workload_class" not in request_headers