Skip to content

Commit

Permalink
Add workloadId and workloadClass request headers (mlflow#13104)
Browse files Browse the repository at this point in the history
Signed-off-by: Arpit Jasapara <[email protected]>
Signed-off-by: mlflow-automation <[email protected]>
Co-authored-by: mlflow-automation <[email protected]>
  • Loading branch information
arpitjasa-db and mlflow-automation authored Sep 10, 2024
1 parent 22f6bc1 commit 4cb8009
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,11 @@ def request_headers(self):
command_run_id = databricks_utils.get_command_run_id()
if command_run_id is not None:
request_headers["command_run_id"] = command_run_id
workload_id = databricks_utils.get_workload_id()
workload_class = databricks_utils.get_workload_class()
if workload_id is not None:
request_headers["workload_id"] = workload_id
if workload_class is not None:
request_headers["workload_class"] = workload_class

return request_headers
16 changes: 16 additions & 0 deletions mlflow/utils/databricks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,22 @@ def get_command_run_id():
return None


@_use_repl_context_if_available("workloadId")
def get_workload_id():
try:
return _get_command_context().workloadId().get()
except Exception:
return _get_context_tag("workloadId")


@_use_repl_context_if_available("workloadClass")
def get_workload_class():
try:
return _get_command_context().workloadClass().get()
except Exception:
return _get_context_tag("workloadClass")


@_use_repl_context_if_available("apiUrl")
def get_webapp_url():
"""Should only be called if is_in_databricks_notebook or is_in_databricks_jobs is true"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def test_databricks_request_header_provider_request_headers(
"mlflow.utils.databricks_utils.get_cluster_id"
) as cluster_id_mock, mock.patch(
"mlflow.utils.databricks_utils.get_command_run_id"
) as command_run_id_mock:
) as command_run_id_mock, mock.patch(
"mlflow.utils.databricks_utils.get_workload_id"
) as workload_id_mock, mock.patch(
"mlflow.utils.databricks_utils.get_workload_class"
) as workload_class_mock:
request_headers = DatabricksRequestHeaderProvider().request_headers()

if is_in_databricks_notebook:
Expand All @@ -81,3 +85,13 @@ def test_databricks_request_header_provider_request_headers(
assert request_headers["command_run_id"] == command_run_id_mock.return_value
else:
assert "command_run_id" not in request_headers

if workload_id_mock.return_value is not None:
assert request_headers["workload_id"] == workload_id_mock.return_value
else:
assert "workload_id" not in request_headers

if workload_class_mock.return_value is not None:
assert request_headers["workload_class"] == workload_class_mock.return_value
else:
assert "workload_class" not in request_headers

0 comments on commit 4cb8009

Please sign in to comment.