From 3bf489f17bc7b1ce2d7e32f8b5ee2b8e3a4f9b84 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 18:26:56 +0100 Subject: [PATCH 01/26] Make a best effort attempt to initialise all Databricks globals --- databricks/sdk/runtime/__init__.py | 90 ++++++++++++++++++++--- setup.py | 3 +- tests/integration/test_runtime_globals.py | 12 +++ 3 files changed, 92 insertions(+), 13 deletions(-) create mode 100644 tests/integration/test_runtime_globals.py diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index c92bd1d50..627073a85 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -1,7 +1,10 @@ from __future__ import annotations import logging -from typing import Dict, Union +from types import FunctionType +from typing import Callable, Dict, Union + +from databricks.sdk.service import sql logger = logging.getLogger('databricks.sdk') is_local_implementation = True @@ -86,23 +89,86 @@ def inner() -> Dict[str, str]: _globals[var] = userNamespaceGlobals[var] is_local_implementation = False except ImportError: - from typing import cast - # OSS implementation is_local_implementation = True - from databricks.sdk.dbutils import RemoteDbUtils + try: + # We expect this to fail and only do this for providing types + from pyspark.sql.context import SQLContext + sqlContext: SQLContext = None # type: ignore + sql = sqlContext.sql + table = sqlContext.table + except Exception: + pass + + # The next few try-except blocks are for initialising globals in a best effort + # mannaer. We separate them to try to get as many of them working as possible + try: + from pyspark.sql.functions import udf # type: ignore + except ImportError: + pass + + try: + from databricks.connect import DatabricksSession # type: ignore + spark = DatabricksSession.builder.getOrCreate() + sc = spark.sparkContext + except Exception: + # We are ignoring all failures here because user might want to initialize + # spark session themselves and we don't want to interfere with that + pass + + try: + from IPython import display as IPDisplay + + def display(input=None, *args, **kwargs) -> None : # type: ignore + """ + Display plots or data. + Display plot: + - display() # no-op + - display(matplotlib.figure.Figure) + Display dataset: + - display(spark.DataFrame) + - display(list) # if list can be converted to DataFrame, e.g., list of named tuples + - display(pandas.DataFrame) + - display(koalas.DataFrame) + - display(pyspark.pandas.DataFrame) + Display any other value that has a _repr_html_() method + For Spark 2.0 and 2.1: + - display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger, + checkpointLocation='optional') + For Spark 2.2+: + - display(DataFrame, streamName='optional', trigger=optional interval like '1 second', + checkpointLocation='optional') + """ + return IPDisplay.display(input, *args, **kwargs) # type: ignore + + def displayHTML(html) -> None: # type: ignore + """ + Display HTML data. + Parameters + ---------- + data : URL or HTML string + If data is a URL, display the resource at that URL, the resource is loaded dynamically by the browser. + Otherwise data should be the HTML to be displayed. + See also: + IPython.display.HTML + IPython.display.display_html + """ + return IPDisplay.display_html(html, raw=True) # type: ignore + + except ImportError: + pass + + # We want to propagate the error in initialising dbutils because this is a core + # functionality of the sdk + from databricks.sdk.dbutils import RemoteDbUtils from . import dbutils_stub - + from typing import cast dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils] - try: - from .stub import * - except (ImportError, NameError): - # this assumes that all environment variables are set - dbutils = RemoteDbUtils() - + dbutils = RemoteDbUtils() dbutils = cast(dbutils_type, dbutils) + getArgument = dbutils.widgets.getArgument -__all__ = ['dbutils'] if is_local_implementation else dbruntime_objects +__all__ = dbruntime_objects diff --git a/setup.py b/setup.py index 6af948fc7..0d3e8c008 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,8 @@ install_requires=["requests>=2.28.1,<3", "google-auth~=2.0"], extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock", "yapf", "pycodestyle", "autoflake", "isort", "wheel", - "ipython", "ipywidgets", "requests-mock", "pyfakefs"], + "ipython", "ipywidgets", "requests-mock", "pyfakefs", + "databricks-connect", "ipython"], "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]}, author="Serge Smertin", author_email="serge.smertin@databricks.com", diff --git a/tests/integration/test_runtime_globals.py b/tests/integration/test_runtime_globals.py new file mode 100644 index 000000000..6636fbcf9 --- /dev/null +++ b/tests/integration/test_runtime_globals.py @@ -0,0 +1,12 @@ +def test_runtime_spark(w, env_or_skip): + env_or_skip("SPARK_CONNECT_CLUSTER_ID") + + from databricks.sdk.runtime import spark + assert spark.sql("SELECT 1").collect()[0][0] == 1 + +def test_runtime_display(w, env_or_skip): + from databricks.sdk.runtime import display, displayHTML + + # assert no errors + display("test") + displayHTML("test") From 1e5561991d4da8b41f8ecd61538beeb84ee6644b Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 18:31:27 +0100 Subject: [PATCH 02/26] lint --- databricks/sdk/runtime/__init__.py | 21 +++++++++++---------- tests/integration/test_runtime_globals.py | 1 + 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 627073a85..28b0f3109 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -101,26 +101,26 @@ def inner() -> Dict[str, str]: except Exception: pass - # The next few try-except blocks are for initialising globals in a best effort + # The next few try-except blocks are for initialising globals in a best effort # mannaer. We separate them to try to get as many of them working as possible try: - from pyspark.sql.functions import udf # type: ignore + from pyspark.sql.functions import udf # type: ignore except ImportError: pass - + try: - from databricks.connect import DatabricksSession # type: ignore + from databricks.connect import DatabricksSession # type: ignore spark = DatabricksSession.builder.getOrCreate() sc = spark.sparkContext except Exception: # We are ignoring all failures here because user might want to initialize # spark session themselves and we don't want to interfere with that pass - - try: + + try: from IPython import display as IPDisplay - def display(input=None, *args, **kwargs) -> None : # type: ignore + def display(input=None, *args, **kwargs) -> None: # type: ignore """ Display plots or data. Display plot: @@ -155,16 +155,17 @@ def displayHTML(html) -> None: # type: ignore IPython.display.display_html """ return IPDisplay.display_html(html, raw=True) # type: ignore - + except ImportError: pass - # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk + from typing import cast + from databricks.sdk.dbutils import RemoteDbUtils + from . import dbutils_stub - from typing import cast dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils] dbutils = RemoteDbUtils() diff --git a/tests/integration/test_runtime_globals.py b/tests/integration/test_runtime_globals.py index 6636fbcf9..01b26ee44 100644 --- a/tests/integration/test_runtime_globals.py +++ b/tests/integration/test_runtime_globals.py @@ -4,6 +4,7 @@ def test_runtime_spark(w, env_or_skip): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 + def test_runtime_display(w, env_or_skip): from databricks.sdk.runtime import display, displayHTML From ebe0fc0869b903e5c8f753e1f61b4709883ac6cc Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 18:38:06 +0100 Subject: [PATCH 03/26] remove stubs file --- databricks/sdk/runtime/stub.py | 48 ---------------------------------- 1 file changed, 48 deletions(-) delete mode 100644 databricks/sdk/runtime/stub.py diff --git a/databricks/sdk/runtime/stub.py b/databricks/sdk/runtime/stub.py deleted file mode 100644 index 38d748bde..000000000 --- a/databricks/sdk/runtime/stub.py +++ /dev/null @@ -1,48 +0,0 @@ -from pyspark.sql.context import SQLContext -from pyspark.sql.functions import udf as U -from pyspark.sql.session import SparkSession - -udf = U -spark: SparkSession -sc = spark.sparkContext -sqlContext: SQLContext -sql = sqlContext.sql -table = sqlContext.table - - -def displayHTML(html): - """ - Display HTML data. - Parameters - ---------- - data : URL or HTML string - If data is a URL, display the resource at that URL, the resource is loaded dynamically by the browser. - Otherwise data should be the HTML to be displayed. - See also: - IPython.display.HTML - IPython.display.display_html - """ - ... - - -def display(input=None, *args, **kwargs): - """ - Display plots or data. - Display plot: - - display() # no-op - - display(matplotlib.figure.Figure) - Display dataset: - - display(spark.DataFrame) - - display(list) # if list can be converted to DataFrame, e.g., list of named tuples - - display(pandas.DataFrame) - - display(koalas.DataFrame) - - display(pyspark.pandas.DataFrame) - Display any other value that has a _repr_html_() method - For Spark 2.0 and 2.1: - - display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger, - checkpointLocation='optional') - For Spark 2.2+: - - display(DataFrame, streamName='optional', trigger=optional interval like '1 second', - checkpointLocation='optional') - """ - ... From 7296d83d471474be35afafabb812277437325dba Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:15:19 +0100 Subject: [PATCH 04/26] remove uneeded imports --- databricks/sdk/runtime/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 28b0f3109..b03733b36 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -1,10 +1,7 @@ from __future__ import annotations import logging -from types import FunctionType -from typing import Callable, Dict, Union - -from databricks.sdk.service import sql +from typing import Dict, Union logger = logging.getLogger('databricks.sdk') is_local_implementation = True From dd5cb295fabebf8387eadbdb0791bc4401e0248d Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:18:22 +0100 Subject: [PATCH 05/26] remove uneeded imports --- databricks/sdk/runtime/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index b03733b36..308e635f9 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -102,17 +102,18 @@ def inner() -> Dict[str, str]: # mannaer. We separate them to try to get as many of them working as possible try: from pyspark.sql.functions import udf # type: ignore - except ImportError: - pass + except ImportError as e: + logging.debug(f"Failed to initialise udf global: {e}") + try: from databricks.connect import DatabricksSession # type: ignore spark = DatabricksSession.builder.getOrCreate() sc = spark.sparkContext - except Exception: + except Exception as e: # We are ignoring all failures here because user might want to initialize # spark session themselves and we don't want to interfere with that - pass + logging.debug(f"Failed to initialize spark session: {e}") try: from IPython import display as IPDisplay @@ -153,7 +154,8 @@ def displayHTML(html) -> None: # type: ignore """ return IPDisplay.display_html(html, raw=True) # type: ignore - except ImportError: + except ImportError as e: + logging.debug(f"Failed to initialise display globals: {e}") pass # We want to propagate the error in initialising dbutils because this is a core From 62cb41590ce37636eb6ef9406421619b771eb1fa Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:21:28 +0100 Subject: [PATCH 06/26] rename --- Makefile | 2 +- .../{test_runtime_globals.py => test_local_globals.py} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename tests/integration/{test_runtime_globals.py => test_local_globals.py} (74%) diff --git a/Makefile b/Makefile index 21c28d324..86c7fd2f1 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html tests + pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html -k 'test_local_globals.py' tests benchmark: pytest -m 'benchmark' tests diff --git a/tests/integration/test_runtime_globals.py b/tests/integration/test_local_globals.py similarity index 74% rename from tests/integration/test_runtime_globals.py rename to tests/integration/test_local_globals.py index 01b26ee44..a3c168c36 100644 --- a/tests/integration/test_runtime_globals.py +++ b/tests/integration/test_local_globals.py @@ -1,11 +1,11 @@ -def test_runtime_spark(w, env_or_skip): +def test_local_global_spark(w, env_or_skip): env_or_skip("SPARK_CONNECT_CLUSTER_ID") from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 -def test_runtime_display(w, env_or_skip): +def test_local_global_display(w, env_or_skip): from databricks.sdk.runtime import display, displayHTML # assert no errors From 5c3a78c1565cc8a9bda53f562d3950eba2d8c699 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:22:05 +0100 Subject: [PATCH 07/26] lint --- databricks/sdk/runtime/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 308e635f9..5d836982a 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -104,7 +104,6 @@ def inner() -> Dict[str, str]: from pyspark.sql.functions import udf # type: ignore except ImportError as e: logging.debug(f"Failed to initialise udf global: {e}") - try: from databricks.connect import DatabricksSession # type: ignore @@ -156,7 +155,6 @@ def displayHTML(html) -> None: # type: ignore except ImportError as e: logging.debug(f"Failed to initialise display globals: {e}") - pass # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk From 7ce8fa358b692d7589252fdfd6077d5d8ab1a20e Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:43:20 +0100 Subject: [PATCH 08/26] fix test --- tests/integration/test_local_globals.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_local_globals.py b/tests/integration/test_local_globals.py index a3c168c36..f859d1372 100644 --- a/tests/integration/test_local_globals.py +++ b/tests/integration/test_local_globals.py @@ -1,8 +1,25 @@ -def test_local_global_spark(w, env_or_skip): - env_or_skip("SPARK_CONNECT_CLUSTER_ID") +from contextlib import contextmanager + + +@contextmanager +def restorable_env(): + import os + current_env = os.environ.copy() - from databricks.sdk.runtime import spark - assert spark.sql("SELECT 1").collect()[0][0] == 1 + try: + yield + finally: + os.environ.clear() + os.environ.update(current_env) + + +def test_local_global_spark(w, env_or_skip): + cluster_id = env_or_skip("SPARK_CONNECT_CLUSTER_ID") + with restorable_env(): + import os + os.environ["DATABRICKS_CLUSTER_ID"] = cluster_id + from databricks.sdk.runtime import spark + assert spark.sql("SELECT 1").collect()[0][0] == 1 def test_local_global_display(w, env_or_skip): From f90d141d414bc75b485a0eedd049d401ffb6c650 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Thu, 22 Feb 2024 19:45:06 +0100 Subject: [PATCH 09/26] revert make file change --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 86c7fd2f1..21c28d324 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html -k 'test_local_globals.py' tests + pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html tests benchmark: pytest -m 'benchmark' tests From 95508c1d6bf9bdbea7605b23539de6fdeafe05fd Mon Sep 17 00:00:00 2001 From: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:39:17 +0100 Subject: [PATCH 10/26] Update databricks/sdk/runtime/__init__.py Co-authored-by: Miles Yucht Signed-off-by: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> --- databricks/sdk/runtime/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 5d836982a..76f8873d6 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -112,7 +112,7 @@ def inner() -> Dict[str, str]: except Exception as e: # We are ignoring all failures here because user might want to initialize # spark session themselves and we don't want to interfere with that - logging.debug(f"Failed to initialize spark session: {e}") + logging.debug(f"Failed to initialize globals 'spark' and 'sc', continuing. Cause: {e}") try: from IPython import display as IPDisplay From ebac19dc580cf7f32ee5cbcfbd360b1bd1d25c5d Mon Sep 17 00:00:00 2001 From: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:48:28 +0100 Subject: [PATCH 11/26] Update databricks/sdk/runtime/__init__.py Co-authored-by: Miles Yucht Signed-off-by: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> --- databricks/sdk/runtime/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 76f8873d6..807a8ca65 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -154,7 +154,7 @@ def displayHTML(html) -> None: # type: ignore return IPDisplay.display_html(html, raw=True) # type: ignore except ImportError as e: - logging.debug(f"Failed to initialise display globals: {e}") + logging.debug(f"Failed to initialise globals 'display' and 'displayHTML', continuing. Cause: {e}") # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk From 7afdc6aff7e8313b11c4d5b5ff93880695e3db61 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 13:50:05 +0100 Subject: [PATCH 12/26] trigger imports only on function calls --- databricks/sdk/_widgets/__init__.py | 4 ++-- databricks/sdk/runtime/__init__.py | 15 +++++++++++---- databricks/sdk/runtime/dbutils_stub.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/databricks/sdk/_widgets/__init__.py b/databricks/sdk/_widgets/__init__.py index 4fef42696..0cd033a55 100644 --- a/databricks/sdk/_widgets/__init__.py +++ b/databricks/sdk/_widgets/__init__.py @@ -13,11 +13,11 @@ def get(self, name: str): def _get(self, name: str) -> str: pass - def getArgument(self, name: str, default_value: typing.Optional[str] = None): + def getArgument(self, name: str, defaultValue: typing.Optional[str] = None): try: return self.get(name) except Exception: - return default_value + return defaultValue def remove(self, name: str): self._remove(name) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 5d836982a..449d13a13 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -115,8 +115,6 @@ def inner() -> Dict[str, str]: logging.debug(f"Failed to initialize spark session: {e}") try: - from IPython import display as IPDisplay - def display(input=None, *args, **kwargs) -> None: # type: ignore """ Display plots or data. @@ -137,6 +135,8 @@ def display(input=None, *args, **kwargs) -> None: # type: ignore - display(DataFrame, streamName='optional', trigger=optional interval like '1 second', checkpointLocation='optional') """ + # Import inside the function so that imports are only triggered on usage. + from IPython import display as IPDisplay return IPDisplay.display(input, *args, **kwargs) # type: ignore def displayHTML(html) -> None: # type: ignore @@ -151,6 +151,8 @@ def displayHTML(html) -> None: # type: ignore IPython.display.HTML IPython.display.display_html """ + # Import inside the function so that imports are only triggered on usage. + from IPython import display as IPDisplay return IPDisplay.display_html(html, raw=True) # type: ignore except ImportError as e: @@ -158,7 +160,7 @@ def displayHTML(html) -> None: # type: ignore # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk - from typing import cast + from typing import cast, Optional from databricks.sdk.dbutils import RemoteDbUtils @@ -167,6 +169,11 @@ def displayHTML(html) -> None: # type: ignore dbutils = RemoteDbUtils() dbutils = cast(dbutils_type, dbutils) - getArgument = dbutils.widgets.getArgument + + # We do this to prevent importing widgets implementation prematurely + # The widget import should prompt users to use the implementation + # which has ipywidget support. + def getArgument(name: str, defaultValue: Optional[str] = None): + return dbutils.widgets.getArgument(name, defaultValue) __all__ = dbruntime_objects diff --git a/databricks/sdk/runtime/dbutils_stub.py b/databricks/sdk/runtime/dbutils_stub.py index 12eff7b6d..8584b6080 100644 --- a/databricks/sdk/runtime/dbutils_stub.py +++ b/databricks/sdk/runtime/dbutils_stub.py @@ -288,7 +288,7 @@ def get(name: str) -> str: ... @staticmethod - def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str: + def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str | None: """Returns the current value of a widget with give name. :param name: Name of the argument to be accessed :param defaultValue: (Deprecated) default value From db36f520bb3f650635f6fc7867d62266a3303335 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 13:54:35 +0100 Subject: [PATCH 13/26] trigger imports only on function calls --- databricks/sdk/runtime/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 449d13a13..bb1de29e8 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -93,7 +93,6 @@ def inner() -> Dict[str, str]: # We expect this to fail and only do this for providing types from pyspark.sql.context import SQLContext sqlContext: SQLContext = None # type: ignore - sql = sqlContext.sql table = sqlContext.table except Exception: pass @@ -108,6 +107,7 @@ def inner() -> Dict[str, str]: try: from databricks.connect import DatabricksSession # type: ignore spark = DatabricksSession.builder.getOrCreate() + sql = spark.sql # type: ignore sc = spark.sparkContext except Exception as e: # We are ignoring all failures here because user might want to initialize From 4b9d8e54637a09b2b8a60f13ad745058a6b71c11 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 13:58:15 +0100 Subject: [PATCH 14/26] address feedback --- databricks/sdk/runtime/__init__.py | 96 +++++++++++++++--------------- 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index dbf7b97a6..e5d474357 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -89,16 +89,16 @@ def inner() -> Dict[str, str]: # OSS implementation is_local_implementation = True + # The next few try-except blocks are for initialising globals in a best effort + # mannaer. We separate them to try to get as many of them working as possible try: # We expect this to fail and only do this for providing types from pyspark.sql.context import SQLContext sqlContext: SQLContext = None # type: ignore table = sqlContext.table - except Exception: - pass + except Exception as e: + logging.debug(f"Failed to initialize globals 'sqlContext' and 'table', continuing. Cause: {e}") - # The next few try-except blocks are for initialising globals in a best effort - # mannaer. We separate them to try to get as many of them working as possible try: from pyspark.sql.functions import udf # type: ignore except ImportError as e: @@ -108,55 +108,57 @@ def inner() -> Dict[str, str]: from databricks.connect import DatabricksSession # type: ignore spark = DatabricksSession.builder.getOrCreate() sql = spark.sql # type: ignore - sc = spark.sparkContext except Exception as e: # We are ignoring all failures here because user might want to initialize # spark session themselves and we don't want to interfere with that - logging.debug(f"Failed to initialize globals 'spark' and 'sc', continuing. Cause: {e}") + logging.debug(f"Failed to initialize globals 'spark' and 'sql', continuing. Cause: {e}") try: - def display(input=None, *args, **kwargs) -> None: # type: ignore - """ - Display plots or data. - Display plot: - - display() # no-op - - display(matplotlib.figure.Figure) - Display dataset: - - display(spark.DataFrame) - - display(list) # if list can be converted to DataFrame, e.g., list of named tuples - - display(pandas.DataFrame) - - display(koalas.DataFrame) - - display(pyspark.pandas.DataFrame) - Display any other value that has a _repr_html_() method - For Spark 2.0 and 2.1: - - display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger, - checkpointLocation='optional') - For Spark 2.2+: - - display(DataFrame, streamName='optional', trigger=optional interval like '1 second', - checkpointLocation='optional') - """ - # Import inside the function so that imports are only triggered on usage. - from IPython import display as IPDisplay - return IPDisplay.display(input, *args, **kwargs) # type: ignore - - def displayHTML(html) -> None: # type: ignore - """ - Display HTML data. - Parameters - ---------- - data : URL or HTML string - If data is a URL, display the resource at that URL, the resource is loaded dynamically by the browser. - Otherwise data should be the HTML to be displayed. - See also: - IPython.display.HTML - IPython.display.display_html - """ - # Import inside the function so that imports are only triggered on usage. - from IPython import display as IPDisplay - return IPDisplay.display_html(html, raw=True) # type: ignore + # We expect this to fail locally since dbconnect does not support sparkcontext. This is just for typing + sc = spark.sparkContext + except Exception as e: + logging.debug(f"Failed to initialize global 'sc', continuing. Cause: {e}") + + def display(input=None, *args, **kwargs) -> None: # type: ignore + """ + Display plots or data. + Display plot: + - display() # no-op + - display(matplotlib.figure.Figure) + Display dataset: + - display(spark.DataFrame) + - display(list) # if list can be converted to DataFrame, e.g., list of named tuples + - display(pandas.DataFrame) + - display(koalas.DataFrame) + - display(pyspark.pandas.DataFrame) + Display any other value that has a _repr_html_() method + For Spark 2.0 and 2.1: + - display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger, + checkpointLocation='optional') + For Spark 2.2+: + - display(DataFrame, streamName='optional', trigger=optional interval like '1 second', + checkpointLocation='optional') + """ + # Import inside the function so that imports are only triggered on usage. + from IPython import display as IPDisplay + return IPDisplay.display(input, *args, **kwargs) # type: ignore + + def displayHTML(html) -> None: # type: ignore + """ + Display HTML data. + Parameters + ---------- + data : URL or HTML string + If data is a URL, display the resource at that URL, the resource is loaded dynamically by the browser. + Otherwise data should be the HTML to be displayed. + See also: + IPython.display.HTML + IPython.display.display_html + """ + # Import inside the function so that imports are only triggered on usage. + from IPython import display as IPDisplay + return IPDisplay.display_html(html, raw=True) # type: ignore - except ImportError as e: - logging.debug(f"Failed to initialise globals 'display' and 'displayHTML', continuing. Cause: {e}") # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk From 70a609e3b335f670936a34e604e4426668b94081 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 17:18:30 +0100 Subject: [PATCH 15/26] Run integration test against multiple dbconnect install --- Makefile | 2 +- tests/integration/conftest.py | 2 +- tests/integration/test_dbconnect.py | 48 +++++++++++++++++++++++++ tests/integration/test_local_globals.py | 30 ---------------- 4 files changed, 50 insertions(+), 32 deletions(-) create mode 100644 tests/integration/test_dbconnect.py delete mode 100644 tests/integration/test_local_globals.py diff --git a/Makefile b/Makefile index 21c28d324..554d81536 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html tests + pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html -k 'test_dbconnect.py' tests benchmark: pytest -m 'benchmark' tests diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 5b0811048..0cae805aa 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -81,7 +81,7 @@ def ucws(env_or_skip) -> WorkspaceClient: @pytest.fixture(scope='session') def env_or_skip(): - def inner(var) -> str: + def inner(var: str) -> str: if var not in os.environ: pytest.skip(f'Environment variable {var} is missing') return os.environ[var] diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py new file mode 100644 index 000000000..c90e3e100 --- /dev/null +++ b/tests/integration/test_dbconnect.py @@ -0,0 +1,48 @@ +from contextlib import contextmanager + +import pytest + +DBCONNECT_DBR_CLIENT = { + "13.3": "13.3.3", + "14.3": "14.3.1", +} + +@pytest.fixture(scope="function") +def restorable_env(): + import os + current_env = os.environ.copy() + yield + for k, v in os.environ.items(): + if k not in current_env: + del os.environ[k] + elif v != current_env[k]: + os.environ[k] = current_env[k] + +@pytest.fixture() +def setup_dbconnect_test(dbr: str, env_or_skip, restorable_env): + assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT." + + import os + os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip(f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID") + + import sys + import subprocess + lib = f"databricks-dbconnect=={DBCONNECT_DBR_CLIENT[dbr]}" + subprocess.check_call([sys.executable, "-m", "pip", "install", lib]) + + yield + + subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", lib]) + + +@pytest.mark.parametrize("dbr", DBCONNECT_DBR_CLIENT.keys(), indirect=True) +def test_dbconnect_initialisation(w, setup_dbconnect_test): + from databricks.connect import DatabricksSession + spark = DatabricksSession.builder.getOrCreate() + assert spark.sql("SELECT 1").collect()[0][0] == 1 + +@pytest.mark.parametrize("dbr", DBCONNECT_DBR_CLIENT.keys(), indirect=True) +def test_dbconnect_runtime_import(w, setup_dbconnect_test): + from databricks.sdk.runtime import * + assert spark.sql("SELECT 1").collect()[0][0] == 1 + diff --git a/tests/integration/test_local_globals.py b/tests/integration/test_local_globals.py deleted file mode 100644 index f859d1372..000000000 --- a/tests/integration/test_local_globals.py +++ /dev/null @@ -1,30 +0,0 @@ -from contextlib import contextmanager - - -@contextmanager -def restorable_env(): - import os - current_env = os.environ.copy() - - try: - yield - finally: - os.environ.clear() - os.environ.update(current_env) - - -def test_local_global_spark(w, env_or_skip): - cluster_id = env_or_skip("SPARK_CONNECT_CLUSTER_ID") - with restorable_env(): - import os - os.environ["DATABRICKS_CLUSTER_ID"] = cluster_id - from databricks.sdk.runtime import spark - assert spark.sql("SELECT 1").collect()[0][0] == 1 - - -def test_local_global_display(w, env_or_skip): - from databricks.sdk.runtime import display, displayHTML - - # assert no errors - display("test") - displayHTML("test") From 82e60fa24e5ab432436cbccbcd08738d3fdc4374 Mon Sep 17 00:00:00 2001 From: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:19:03 +0100 Subject: [PATCH 16/26] Update setup.py Co-authored-by: Miles Yucht Signed-off-by: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com> --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0d3e8c008..cae84949c 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock", "yapf", "pycodestyle", "autoflake", "isort", "wheel", "ipython", "ipywidgets", "requests-mock", "pyfakefs", - "databricks-connect", "ipython"], + "databricks-connect"], "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]}, author="Serge Smertin", author_email="serge.smertin@databricks.com", From d3a484212b80e4cea292fd524b53638649a1081b Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 17:22:26 +0100 Subject: [PATCH 17/26] Run integration test against multiple dbconnect install --- tests/integration/test_dbconnect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index c90e3e100..f2a25e5cf 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -43,6 +43,6 @@ def test_dbconnect_initialisation(w, setup_dbconnect_test): @pytest.mark.parametrize("dbr", DBCONNECT_DBR_CLIENT.keys(), indirect=True) def test_dbconnect_runtime_import(w, setup_dbconnect_test): - from databricks.sdk.runtime import * + from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 From 69803733d65df0b71a531041521389a4f811f67e Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 17:36:11 +0100 Subject: [PATCH 18/26] Run integration test against multiple dbconnect install --- tests/integration/test_dbconnect.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index f2a25e5cf..87a4fbc63 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -18,8 +18,9 @@ def restorable_env(): elif v != current_env[k]: os.environ[k] = current_env[k] -@pytest.fixture() -def setup_dbconnect_test(dbr: str, env_or_skip, restorable_env): +@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys())) +def setup_dbconnect_test(request, env_or_skip, restorable_env): + dbr = request.param assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT." import os @@ -35,13 +36,11 @@ def setup_dbconnect_test(dbr: str, env_or_skip, restorable_env): subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", lib]) -@pytest.mark.parametrize("dbr", DBCONNECT_DBR_CLIENT.keys(), indirect=True) def test_dbconnect_initialisation(w, setup_dbconnect_test): from databricks.connect import DatabricksSession spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 -@pytest.mark.parametrize("dbr", DBCONNECT_DBR_CLIENT.keys(), indirect=True) def test_dbconnect_runtime_import(w, setup_dbconnect_test): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 From b3b2154a1843a8301180a8d57f5543d5d35ab497 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 18:06:27 +0100 Subject: [PATCH 19/26] Run integration test against multiple dbconnect install --- tests/integration/test_dbconnect.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 87a4fbc63..68563c8d9 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -28,12 +28,12 @@ def setup_dbconnect_test(request, env_or_skip, restorable_env): import sys import subprocess - lib = f"databricks-dbconnect=={DBCONNECT_DBR_CLIENT[dbr]}" - subprocess.check_call([sys.executable, "-m", "pip", "install", lib]) + lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}" + subprocess.check_call([sys.executable, "-m", "pip", "install", lib], stdout=sys.stdout, stderr=sys.stderr) yield - subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", lib]) + subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", lib], stdout=sys.stdout, stderr=sys.stderr) def test_dbconnect_initialisation(w, setup_dbconnect_test): From 09370ebd057849530e13cc548e5e2309ee223896 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 18:48:16 +0100 Subject: [PATCH 20/26] Run integration test against multiple dbconnect install --- Makefile | 2 +- tests/integration/test_dbconnect.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 554d81536..00b491d14 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html -k 'test_dbconnect.py' tests + pytest -n auto -m 'integration and not benchmark' --dist loadgroup --cov=databricks --cov-report html -k 'test_dbconnect.py' tests benchmark: pytest -m 'benchmark' tests diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 68563c8d9..46f99856e 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -1,5 +1,4 @@ from contextlib import contextmanager - import pytest DBCONNECT_DBR_CLIENT = { @@ -29,18 +28,20 @@ def setup_dbconnect_test(request, env_or_skip, restorable_env): import sys import subprocess lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}" - subprocess.check_call([sys.executable, "-m", "pip", "install", lib], stdout=sys.stdout, stderr=sys.stderr) + subprocess.check_call([sys.executable, "-m", "pip", "install", lib]) yield - subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", lib], stdout=sys.stdout, stderr=sys.stderr) + subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "databricks-connect"]) +@pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_initialisation(w, setup_dbconnect_test): from databricks.connect import DatabricksSession spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 +@pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_runtime_import(w, setup_dbconnect_test): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 From 0a585a9476637fcdd8e3ed791ccb1f6cbd2d115d Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 18:48:54 +0100 Subject: [PATCH 21/26] fmt --- databricks/sdk/runtime/__init__.py | 4 ++-- tests/integration/test_dbconnect.py | 21 ++++++++++----------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index e5d474357..19616e041 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -159,10 +159,9 @@ def displayHTML(html) -> None: # type: ignore from IPython import display as IPDisplay return IPDisplay.display_html(html, raw=True) # type: ignore - # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk - from typing import cast, Optional + from typing import Optional, cast from databricks.sdk.dbutils import RemoteDbUtils @@ -178,4 +177,5 @@ def displayHTML(html) -> None: # type: ignore def getArgument(name: str, defaultValue: Optional[str] = None): return dbutils.widgets.getArgument(name, defaultValue) + __all__ = dbruntime_objects diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 46f99856e..613c8eaeb 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -1,10 +1,7 @@ -from contextlib import contextmanager import pytest -DBCONNECT_DBR_CLIENT = { - "13.3": "13.3.3", - "14.3": "14.3.1", -} +DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1", } + @pytest.fixture(scope="function") def restorable_env(): @@ -15,7 +12,8 @@ def restorable_env(): if k not in current_env: del os.environ[k] elif v != current_env[k]: - os.environ[k] = current_env[k] + os.environ[k] = current_env[k] + @pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys())) def setup_dbconnect_test(request, env_or_skip, restorable_env): @@ -23,17 +21,18 @@ def setup_dbconnect_test(request, env_or_skip, restorable_env): assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT." import os - os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip(f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID") - - import sys + os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip( + f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID") + import subprocess + import sys lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}" subprocess.check_call([sys.executable, "-m", "pip", "install", lib]) yield subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "databricks-connect"]) - + @pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_initialisation(w, setup_dbconnect_test): @@ -41,8 +40,8 @@ def test_dbconnect_initialisation(w, setup_dbconnect_test): spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 + @pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_runtime_import(w, setup_dbconnect_test): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 - From 18a022a8dfa3ffe71970a8b2a5441437f187352f Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 18:59:11 +0100 Subject: [PATCH 22/26] fmt --- Makefile | 2 +- databricks/sdk/runtime/__init__.py | 4 +--- tests/integration/test_dbconnect.py | 5 +++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 00b491d14..f65f98731 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests integration: - pytest -n auto -m 'integration and not benchmark' --dist loadgroup --cov=databricks --cov-report html -k 'test_dbconnect.py' tests + pytest -n auto -m 'integration and not benchmark' --dist loadgroup --cov=databricks --cov-report html tests benchmark: pytest -m 'benchmark' tests diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 19616e041..b289e3532 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Dict, Union +from typing import Dict, Union, Optional, cast logger = logging.getLogger('databricks.sdk') is_local_implementation = True @@ -161,8 +161,6 @@ def displayHTML(html) -> None: # type: ignore # We want to propagate the error in initialising dbutils because this is a core # functionality of the sdk - from typing import Optional, cast - from databricks.sdk.dbutils import RemoteDbUtils from . import dbutils_stub diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 613c8eaeb..6ccd6d793 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -45,3 +45,8 @@ def test_dbconnect_initialisation(w, setup_dbconnect_test): def test_dbconnect_runtime_import(w, setup_dbconnect_test): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 + +@pytest.mark.xdist_group(name="databricks-connect") +def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w): + from databricks.sdk.runtime import spark + assert spark is None From 5fea7a92f28431eaa6f19db5d57a82c78ea21ce7 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 18:59:50 +0100 Subject: [PATCH 23/26] fmt --- databricks/sdk/runtime/__init__.py | 2 +- tests/integration/test_dbconnect.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index b289e3532..3b12a5f52 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Dict, Union, Optional, cast +from typing import Dict, Optional, Union, cast logger = logging.getLogger('databricks.sdk') is_local_implementation = True diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index 6ccd6d793..f3f9d1c4a 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -46,6 +46,7 @@ def test_dbconnect_runtime_import(w, setup_dbconnect_test): from databricks.sdk.runtime import spark assert spark.sql("SELECT 1").collect()[0][0] == 1 + @pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w): from databricks.sdk.runtime import spark From ad501a50af39bae34af9305d5c2b4cee5cc68c39 Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 19:03:46 +0100 Subject: [PATCH 24/26] fix no error test --- databricks/sdk/runtime/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index 3b12a5f52..9230c7a83 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -89,6 +89,9 @@ def inner() -> Dict[str, str]: # OSS implementation is_local_implementation = True + for var in dbruntime_objects: + globals()[var] = None + # The next few try-except blocks are for initialising globals in a best effort # mannaer. We separate them to try to get as many of them working as possible try: From 87989f15988f81090fc063e462171db5972886ec Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Mon, 26 Feb 2024 19:47:48 +0100 Subject: [PATCH 25/26] fix no error test --- tests/integration/test_dbconnect.py | 36 ++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index f3f9d1c4a..ddfdec607 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -3,6 +3,29 @@ DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1", } +@pytest.fixture(scope="function") +def reload_modules(): + """ + Returns a function that can be used to reload modules. This is useful when testing + Databricks Connect, since both the `databricks.connect` and `databricks.sdk.runtime` modules + are stateful, and we need to reload these modules to reset the state cache between test runs. + """ + import importlib + import sys + + def reload(name: str): + v = sys.modules.get(name) + if v is None: + return + try: + print(f"Reloading {name}") + importlib.reload(v) + except Exception as e: + print(f"Failed to reload {name}: {e}") + + return reload + + @pytest.fixture(scope="function") def restorable_env(): import os @@ -35,19 +58,26 @@ def setup_dbconnect_test(request, env_or_skip, restorable_env): @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_initialisation(w, setup_dbconnect_test): +def test_dbconnect_initialisation(w, setup_dbconnect_test, reload_modules): + reload_modules("databricks.connect") from databricks.connect import DatabricksSession + reload_modules("databricks.connect") + spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_runtime_import(w, setup_dbconnect_test): +def test_dbconnect_runtime_import(w, setup_dbconnect_test, reload_modules): + reload_modules("databricks.sdk.runtime") from databricks.sdk.runtime import spark + assert spark.sql("SELECT 1").collect()[0][0] == 1 @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w): +def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w, reload_modules): + reload_modules("databricks.sdk.runtime") from databricks.sdk.runtime import spark + assert spark is None From 70dba8c144e76ec2a8483ef604fb4bd307015daf Mon Sep 17 00:00:00 2001 From: kartikgupta-db Date: Wed, 28 Feb 2024 13:17:46 +0100 Subject: [PATCH 26/26] Address feedback --- .vscode/settings.json | 6 ++++- tests/integration/test_dbconnect.py | 36 +++++++++++++---------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index a3a183836..c36b4db6c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,9 @@ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "python.envFile": "${workspaceFolder}/.databricks/.databricks.env", + "databricks.python.envFile": "${workspaceFolder}/.env", + "jupyter.interactiveWindow.cellMarker.codeRegex": "^# COMMAND ----------|^# Databricks notebook source|^(#\\s*%%|#\\s*\\|#\\s*In\\[\\d*?\\]|#\\s*In\\[ \\])", + "jupyter.interactiveWindow.cellMarker.default": "# COMMAND ----------" } diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index ddfdec607..e08f451c6 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -3,27 +3,24 @@ DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1", } -@pytest.fixture(scope="function") -def reload_modules(): +def reload_modules(name: str): """ - Returns a function that can be used to reload modules. This is useful when testing - Databricks Connect, since both the `databricks.connect` and `databricks.sdk.runtime` modules - are stateful, and we need to reload these modules to reset the state cache between test runs. + Reloads the specified module. This is useful when testing Databricks Connect, since both + the `databricks.connect` and `databricks.sdk.runtime` modules are stateful, and we need + to reload these modules to reset the state cache between test runs. """ + import importlib import sys - def reload(name: str): - v = sys.modules.get(name) - if v is None: - return - try: - print(f"Reloading {name}") - importlib.reload(v) - except Exception as e: - print(f"Failed to reload {name}: {e}") - - return reload + v = sys.modules.get(name) + if v is None: + return + try: + print(f"Reloading {name}") + importlib.reload(v) + except Exception as e: + print(f"Failed to reload {name}: {e}") @pytest.fixture(scope="function") @@ -58,17 +55,16 @@ def setup_dbconnect_test(request, env_or_skip, restorable_env): @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_initialisation(w, setup_dbconnect_test, reload_modules): +def test_dbconnect_initialisation(w, setup_dbconnect_test): reload_modules("databricks.connect") from databricks.connect import DatabricksSession - reload_modules("databricks.connect") spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_runtime_import(w, setup_dbconnect_test, reload_modules): +def test_dbconnect_runtime_import(w, setup_dbconnect_test): reload_modules("databricks.sdk.runtime") from databricks.sdk.runtime import spark @@ -76,7 +72,7 @@ def test_dbconnect_runtime_import(w, setup_dbconnect_test, reload_modules): @pytest.mark.xdist_group(name="databricks-connect") -def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w, reload_modules): +def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w): reload_modules("databricks.sdk.runtime") from databricks.sdk.runtime import spark