diff --git a/src/databricks/sqlalchemy/__init__.py b/src/databricks/sqlalchemy/__init__.py index f5063204..08983da3 100644 --- a/src/databricks/sqlalchemy/__init__.py +++ b/src/databricks/sqlalchemy/__init__.py @@ -282,8 +282,8 @@ def has_table( DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND" try: - res = connection.execute(sqlalchemy.text( - f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}") + res = connection.execute( + sqlalchemy.text(f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}") ) return True except DatabaseError as e: diff --git a/src/databricks/sqlalchemy/test_local/__init__.py b/src/databricks/sqlalchemy/test_local/__init__.py index f3dd3fbb..d6cebbfe 100644 --- a/src/databricks/sqlalchemy/test_local/__init__.py +++ b/src/databricks/sqlalchemy/test_local/__init__.py @@ -2,4 +2,4 @@ This module contains tests entirely maintained by Databricks. These tests do not rely on SQLAlchemy's custom test runner. -""" \ No newline at end of file +""" diff --git a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py index 55780c7a..c8e173b1 100644 --- a/src/databricks/sqlalchemy/test_local/e2e/test_basic.py +++ b/src/databricks/sqlalchemy/test_local/e2e/test_basic.py @@ -1,7 +1,16 @@ import os, datetime, decimal import pytest from unittest import skipIf -from sqlalchemy import create_engine, select, insert, Column, MetaData, Table, Text, text +from sqlalchemy import ( + create_engine, + select, + insert, + Column, + MetaData, + Table, + Text, + text, +) from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String, DECIMAL, Date from sqlalchemy.engine import Engine @@ -71,7 +80,7 @@ def db_engine() -> Engine: def run_query(db_engine: Engine, query: Union[str, Text]): - + if not isinstance(query, Text): _query = text(query) else: @@ -125,14 +134,18 @@ def test_connect_args(db_engine): @pytest.mark.skipif(sqlalchemy_1_3(), reason="Pandas requires SQLAlchemy >= 1.4") -@pytest.mark.skip(reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass.") +@pytest.mark.skip( + reason="DBR is currently limited to 256 parameters per call to .execute(). Test cannot pass." +) def test_pandas_upload(db_engine, metadata_obj): import pandas as pd SCHEMA = os.environ.get("schema") try: - df = pd.read_excel("src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx") + df = pd.read_excel( + "src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx" + ) df.to_sql( "mock_data", db_engine, @@ -264,7 +277,6 @@ class SampleObject(Base): sample_object_1 = SampleObject(name="Bim Adewunmi", episodes=6, some_bool=True) sample_object_2 = SampleObject(name="Miki Meek", episodes=12, some_bool=False) - session = Session(db_engine) session.add(sample_object_1) @@ -370,9 +382,9 @@ def test_get_table_names_smoke_test(samples_engine: Engine): def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine): """For this test to pass these conditions must be met: - - Table samples.nyctaxi.trips must exist - - Table samples.tpch.customer must exist - - The `catalog` and `schema` environment variables must be set and valid + - Table samples.nyctaxi.trips must exist + - Table samples.tpch.customer must exist + - The `catalog` and `schema` environment variables must be set and valid """ with samples_engine.connect() as conn: diff --git a/src/databricks/sqlalchemy/test_local/test_types.py b/src/databricks/sqlalchemy/test_local/test_types.py index 47384652..2e4ea704 100644 --- a/src/databricks/sqlalchemy/test_local/test_types.py +++ b/src/databricks/sqlalchemy/test_local/test_types.py @@ -1,6 +1,43 @@ -import unittest import pytest -class DatabricksTypeTests(unittest.TestCase): +from databricks.sqlalchemy import DatabricksDialect +from sqlalchemy.types import ( + BigInteger, + Boolean, + Date, + DateTime, + Double, + Enum, + Float, + Integer, + Interval, + LargeBinary, + MatchType, + Numeric, + PickleType, + SchemaType, + SmallInteger, + String, + Text, + Time, + Unicode, + UnicodeText, + Uuid, +) + + +class TesteDatabricksTypeTests: def test_basic_example(self): assert False + + +class TestCamelCaseTypes: + """Per the sqlalchemy documentation here: https://docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types + + These are the default types that are expected to work across all dialects. These tests verify that they render as expected. + """ + + def test_bigint(self): + + target = BigInteger().compile(dialect=DatabricksDialect) + assert False