Skip to content

Commit

Permalink
Patch dbutils.notebook.entry_point... to return current local noteb…
Browse files Browse the repository at this point in the history
…ook path from env var (#618)

## Changes
* Users are using
`dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()`
to get the current notebook path in notebooks (referring
https://stackoverflow.com/questions/53523560/databricks-how-do-i-get-path-of-current-notebook).
* This does not work from our current dbutils proxy. This PR patches
this to read from a DATABRICKS_SOURCE_FILE env var when using local
dbutils.

## Proposal
* Allow users to add their own patches to make is easier for users using
dbutils from sdk to workaround such issues in the future.

## Tests
* Unit tests

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied

---------

Signed-off-by: Kartik Gupta <[email protected]>
Co-authored-by: Miles Yucht <[email protected]>
  • Loading branch information
kartikgupta-db and mgyucht authored May 1, 2024
1 parent 8d25659 commit 2fc049d
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 15 deletions.
84 changes: 81 additions & 3 deletions databricks/sdk/dbutils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import base64
import json
import logging
import os.path
import os
import threading
from collections import namedtuple
from typing import Callable, Dict, List
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from .core import ApiClient, Config, DatabricksError
from .mixins import compute as compute_ext
Expand Down Expand Up @@ -240,6 +241,76 @@ def __getattr__(self, util) -> '_ProxyUtil':
name=util)


@dataclass
class OverrideResult:
result: Any


def get_local_notebook_path():
value = os.getenv("DATABRICKS_SOURCE_FILE")
if value is None:
raise ValueError(
"Getting the current notebook path is only supported when running a notebook using the `Databricks Connect: Run as File` or `Databricks Connect: Debug as File` commands in the Databricks extension for VS Code. To bypass this error, set environment variable `DATABRICKS_SOURCE_FILE` to the desired notebook path."
)

return value


class _OverrideProxyUtil:

@classmethod
def new(cls, path: str):
if len(cls.__get_matching_overrides(path)) > 0:
return _OverrideProxyUtil(path)
return None

def __init__(self, name: str):
self._name = name

# These are the paths that we want to override and not send to remote dbutils. NOTE, for each of these paths, no prefixes
# are sent to remote either. This could lead to unintentional breakage.
# Our current proxy implementation (which sends everything to remote dbutils) uses `{util}.{method}(*args, **kwargs)` ONLY.
# This means, it is completely safe to override paths starting with `{util}.{attribute}.<other_parts>`, since none of the prefixes
# are being proxied to remote dbutils currently.
proxy_override_paths = {
'notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()':
get_local_notebook_path,
}

@classmethod
def __get_matching_overrides(cls, path: str):
return [x for x in cls.proxy_override_paths.keys() if x.startswith(path)]

def __run_override(self, path: str) -> Optional[OverrideResult]:
overrides = self.__get_matching_overrides(path)
if len(overrides) == 1 and overrides[0] == path:
return OverrideResult(self.proxy_override_paths[overrides[0]]())

if len(overrides) > 0:
return OverrideResult(_OverrideProxyUtil(name=path))

return None

def __call__(self, *args, **kwds) -> Any:
if len(args) != 0 or len(kwds) != 0:
raise TypeError(
f"Arguments are not supported for overridden method {self._name}. Invoke as: {self._name}()")

callable_path = f"{self._name}()"
result = self.__run_override(callable_path)
if result:
return result.result

raise TypeError(f"{self._name} is not callable")

def __getattr__(self, method: str) -> Any:
result = self.__run_override(f"{self._name}.{method}")
if result:
return result.result

raise AttributeError(f"module {self._name} has no attribute {method}")


class _ProxyUtil:
"""Enables temporary workaround to call remote in-REPL dbutils without having to re-implement them"""

Expand All @@ -250,7 +321,14 @@ def __init__(self, *, command_execution: compute.CommandExecutionAPI,
self._context_factory = context_factory
self._name = name

def __getattr__(self, method: str) -> '_ProxyCall':
def __call__(self):
raise NotImplementedError(f"dbutils.{self._name} is not callable")

def __getattr__(self, method: str) -> '_ProxyCall | _ProxyUtil | _OverrideProxyUtil':
override = _OverrideProxyUtil.new(f"{self._name}.{method}")
if override:
return override

return _ProxyCall(command_execution=self._commands,
cluster_id=self._cluster_id,
context_factory=self._context_factory,
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from databricks.sdk.core import Config
from databricks.sdk.credentials_provider import credentials_provider

from .integration.conftest import restorable_env # type: ignore


@credentials_provider('noop', [])
def noop_credentials(_: any):
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,15 @@ def _load_debug_env_if_runs_from_ide(key) -> bool:

def _is_in_debug() -> bool:
return os.path.basename(sys.argv[0]) in ['_jb_pytest_runner.py', 'testlauncher.py', ]


@pytest.fixture(scope="function")
def restorable_env():
import os
current_env = os.environ.copy()
yield
for k, v in os.environ.items():
if k not in current_env:
del os.environ[k]
elif v != current_env[k]:
os.environ[k] = current_env[k]
12 changes: 0 additions & 12 deletions tests/integration/test_dbconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,6 @@ def reload_modules(name: str):
print(f"Failed to reload {name}: {e}")


@pytest.fixture(scope="function")
def restorable_env():
import os
current_env = os.environ.copy()
yield
for k, v in os.environ.items():
if k not in current_env:
del os.environ[k]
elif v != current_env[k]:
os.environ[k] = current_env[k]


@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
def setup_dbconnect_test(request, env_or_skip, restorable_env):
dbr = request.param
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,11 @@ def test_jobs_task_values_get_throws(dbutils):
except TypeError as e:
assert str(
e) == 'Must pass debugValue when calling get outside of a job context. debugValue cannot be None.'


def test_dbutils_proxy_overrides(dbutils, mocker, restorable_env):
import os
os.environ["DATABRICKS_SOURCE_FILE"] = "test_source_file"
mocker.patch('databricks.sdk.dbutils.RemoteDbUtils._cluster_id', return_value="test_cluster_id")
assert dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get(
) == "test_source_file"

0 comments on commit 2fc049d

Please sign in to comment.