Skip to content

Commit

Permalink
Update all e2e tests to use SQLAlchemy 2's new declarative syntax
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Sep 29, 2023
1 parent c7f9596 commit 5a2cb9e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/databricks/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ def has_table(
DBR_GT_12_NOT_FOUND_STRING = "TABLE_OR_VIEW_NOT_FOUND"

try:
res = connection.execute(
f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}"
res = connection.execute(sqlalchemy.text(
f"DESCRIBE TABLE {_catalog}.{_schema}.{table_name}")
)
return True
except DatabaseError as e:
Expand Down
92 changes: 58 additions & 34 deletions src/databricks/sqlalchemy/test_local/e2e/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os, datetime, decimal
import pytest
from unittest import skipIf
from sqlalchemy import create_engine, select, insert, Column, MetaData, Table
from sqlalchemy.orm import Session
from sqlalchemy import create_engine, select, insert, Column, MetaData, Table, Text, text
from sqlalchemy.orm import Session, DeclarativeBase, Mapped, mapped_column
from sqlalchemy.types import SMALLINT, Integer, BOOLEAN, String, DECIMAL, Date
from sqlalchemy.engine import Engine

from typing import Tuple
from typing import Tuple, Union

try:
from sqlalchemy.orm import declarative_base
Expand Down Expand Up @@ -70,6 +70,16 @@ def db_engine() -> Engine:
return create_engine(conn_string, connect_args=connect_args)


def run_query(db_engine: Engine, query: Union[str, Text]):

if not isinstance(query, Text):
_query = text(query)
else:
_query = query
with db_engine.begin() as conn:
return conn.execute(_query).fetchall()


@pytest.fixture
def samples_engine() -> Engine:

Expand All @@ -81,22 +91,22 @@ def samples_engine() -> Engine:

@pytest.fixture()
def base(db_engine):
return declarative_base(bind=db_engine)
return declarative_base()


@pytest.fixture()
def session(db_engine):
return Session(bind=db_engine)
return Session(db_engine)


@pytest.fixture()
def metadata_obj(db_engine):
return MetaData(bind=db_engine)
return MetaData()


def test_can_connect(db_engine):
simple_query = "SELECT 1"
result = db_engine.execute(simple_query).fetchall()
result = run_query(db_engine, simple_query)
assert len(result) == 1


Expand Down Expand Up @@ -152,7 +162,7 @@ def test_create_table_not_null(db_engine, metadata_obj: MetaData):
Column("some_bool", BOOLEAN, nullable=False),
)

metadata_obj.create_all()
metadata_obj.create_all(db_engine)

columns = db_engine.dialect.get_columns(
connection=db_engine.connect(), table_name=table_name
Expand All @@ -164,7 +174,7 @@ def test_create_table_not_null(db_engine, metadata_obj: MetaData):
assert name_column_description.get("nullable") is True
assert some_bool_column_description.get("nullable") is False

metadata_obj.drop_all()
metadata_obj.drop_all(db_engine)


def test_bulk_insert_with_core(db_engine, metadata_obj, session):
Expand All @@ -187,10 +197,12 @@ def test_bulk_insert_with_core(db_engine, metadata_obj, session):
for i in range(num_to_insert)
]

metadata_obj.create_all()
db_engine.execute(insert(SampleTable).values(rows))
metadata_obj.create_all(db_engine)
with db_engine.begin() as conn:
conn.execute(insert(SampleTable).values(rows))

rows = db_engine.execute(version_agnostic_select(SampleTable)).fetchall()
with db_engine.begin() as conn:
rows = conn.execute(version_agnostic_select(SampleTable)).fetchall()

assert len(rows) == num_to_insert

Expand All @@ -207,7 +219,7 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
Column("dollars", DECIMAL(10, 2)),
)

metadata_obj.create_all()
metadata_obj.create_all(db_engine)

insert_stmt = insert(SampleTable).values(
name="Bim Adewunmi", episodes=6, some_bool=True, dollars=decimal.Decimal(125)
Expand All @@ -217,40 +229,47 @@ def test_create_insert_drop_table_core(base, db_engine, metadata_obj: MetaData):
conn.execute(insert_stmt)

select_stmt = version_agnostic_select(SampleTable)
resp = db_engine.execute(select_stmt)
with db_engine.begin() as conn:
resp = conn.execute(select_stmt)

result = resp.fetchall()

assert len(result) == 1

metadata_obj.drop_all()
metadata_obj.drop_all(db_engine)


# ORM tests are made following this tutorial
# https://docs.sqlalchemy.org/en/14/orm/quickstart.html


@skipIf(False, "Unity catalog must be supported")
def test_create_insert_drop_table_orm(base, session: Session):
def test_create_insert_drop_table_orm(db_engine):
"""ORM classes built on the declarative base class must have a primary key.
This is restricted to Unity Catalog.
"""

class SampleObject(base):
class Base(DeclarativeBase):
pass

class SampleObject(Base):

__tablename__ = "PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s"))

name = Column(String(255), primary_key=True)
episodes = Column(Integer)
some_bool = Column(BOOLEAN)
name: Mapped[str] = mapped_column(String(255), primary_key=True)
episodes: Mapped[int] = mapped_column(Integer)
some_bool: Mapped[bool] = mapped_column(BOOLEAN)

base.metadata.create_all()
Base.metadata.create_all(db_engine)

sample_object_1 = SampleObject(name="Bim Adewunmi", episodes=6, some_bool=True)
sample_object_2 = SampleObject(name="Miki Meek", episodes=12, some_bool=False)


session = Session(db_engine)
session.add(sample_object_1)
session.add(sample_object_2)
session.commit()
session.flush()

stmt = version_agnostic_select(SampleObject).where(
SampleObject.name.in_(["Bim Adewunmi", "Miki Meek"])
Expand All @@ -263,12 +282,15 @@ class SampleObject(base):

assert len(output) == 2

base.metadata.drop_all()
Base.metadata.drop_all(db_engine)


def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
def test_dialect_type_mappings(db_engine, metadata_obj: MetaData):
"""Confirms that we get back the same time we declared in a model and inserted using Core"""

class Base(DeclarativeBase):
pass

SampleTable = Table(
"PySQLTest_{}".format(datetime.datetime.utcnow().strftime("%s")),
metadata_obj,
Expand All @@ -285,7 +307,7 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
decimal_example = decimal.Decimal(125)
date_example = datetime.date(2013, 1, 1)

metadata_obj.create_all()
metadata_obj.create_all(db_engine)

insert_stmt = insert(SampleTable).values(
string_example=string_example,
Expand All @@ -299,18 +321,19 @@ def test_dialect_type_mappings(base, db_engine, metadata_obj: MetaData):
conn.execute(insert_stmt)

select_stmt = version_agnostic_select(SampleTable)
resp = db_engine.execute(select_stmt)
with db_engine.begin() as conn:
resp = conn.execute(select_stmt)

result = resp.fetchall()
this_row = result[0]

assert this_row["string_example"] == string_example
assert this_row["integer_example"] == integer_example
assert this_row["boolean_example"] == boolean_example
assert this_row["decimal_example"] == decimal_example
assert this_row["date_example"] == date_example
assert this_row.string_example == string_example
assert this_row.integer_example == integer_example
assert this_row.boolean_example == boolean_example
assert this_row.decimal_example == decimal_example
assert this_row.date_example == date_example

metadata_obj.drop_all()
metadata_obj.drop_all(db_engine)


def test_inspector_smoke_test(samples_engine: Engine):
Expand All @@ -337,6 +360,7 @@ def test_inspector_smoke_test(samples_engine: Engine):
assert len(views) == 0, "Views could not be fetched"


@pytest.mark.skip(reason="engine.table_names has been removed in sqlalchemy verison 2")
def test_get_table_names_smoke_test(samples_engine: Engine):

with samples_engine.connect() as conn:
Expand Down Expand Up @@ -367,7 +391,7 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):

# Create a table in a different catalog
with db_engine.connect() as conn:
conn.execute("CREATE TABLE test_has_table (numbers_are_cool INT);")
conn.execute(text("CREATE TABLE test_has_table (numbers_are_cool INT);"))

try:
# Verify that this table is not found in the samples catalog
Expand All @@ -382,4 +406,4 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
catalog=other_catalog,
)
finally:
conn.execute("DROP TABLE test_has_table;")
conn.execute(text("DROP TABLE test_has_table;"))

0 comments on commit 5a2cb9e

Please sign in to comment.