Skip to content

Commit

Permalink
Merge branch 'master' into ignore-nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkaplan authored Dec 26, 2023
2 parents f6ec9ef + 9df2cf2 commit 4e87bc3
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 21 deletions.
10 changes: 4 additions & 6 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,10 +1343,9 @@ def test_cancel_query(trino_connection):
cur.fetchone()
cur.cancel() # would raise an exception if cancel fails

# verify that it doesn't fail in the absence of a previously running query
cur = trino_connection.cursor()
with pytest.raises(Exception) as cancel_error:
cur.cancel()
assert "Cancel query failed; no running query" in str(cancel_error.value)
cur.cancel()


def test_close_cursor(trino_connection):
Expand All @@ -1355,10 +1354,9 @@ def test_close_cursor(trino_connection):
cur.fetchone()
cur.close() # would raise an exception if cancel fails

# verify that it doesn't fail in the absence of a previously running query
cur = trino_connection.cursor()
with pytest.raises(Exception) as cancel_error:
cur.close()
assert "Cancel query failed; no running query" in str(cancel_error.value)
cur.close()


def test_session_properties(run_trino):
Expand Down
66 changes: 66 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def test_json_column(trino_connection, json_object):
ins = table_with_json.insert()
conn.execute(ins, {"id": 1, "json_column": json_object})
query = sqla.select(table_with_json)
assert isinstance(table_with_json.c.json_column.type, JSON)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
Expand All @@ -410,6 +411,71 @@ def test_json_column(trino_connection, json_object):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_json_column_operations(trino_connection):
engine, conn = trino_connection

metadata = sqla.MetaData()

json_object = {
"a": {"c": 1},
100: {"z": 200},
"b": 2,
10: 20,
"foo-bar": {"z": 200}
}

try:
table_with_json = sqla.Table(
'table_with_json',
metadata,
sqla.Column('json_column', JSON),
schema="default"
)
metadata.create_all(engine)
ins = table_with_json.insert()
conn.execute(ins, {"json_column": json_object})

# JSONPathType
query = sqla.select(table_with_json.c.json_column["a", "c"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 1

query = sqla.select(table_with_json.c.json_column[100, "z"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 200

query = sqla.select(table_with_json.c.json_column["foo-bar", "z"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 200

# JSONIndexType
query = sqla.select(table_with_json.c.json_column["b"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 2

query = sqla.select(table_with_json.c.json_column[10])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == 20

query = sqla.select(table_with_json.c.json_column["foo-bar"])
conn.execute(query)
result = conn.execute(query)
assert result.fetchall()[0][0] == {'z': 200}

finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_get_catalog_names(trino_connection):
engine, conn = trino_connection
Expand Down
4 changes: 1 addition & 3 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,7 @@ def fetchall(self) -> List[List[Any]]:

def cancel(self):
if self._query is None:
raise trino.exceptions.OperationalError(
"Cancel query failed; no running query"
)
return
self._query.cancel()

def close(self):
Expand Down
15 changes: 14 additions & 1 deletion trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, sqltypes
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.functions import GenericFunction

Expand Down Expand Up @@ -162,6 +162,19 @@ def compile_ignore_nulls(element, compiler, **kwargs):
compiled += ' IGNORE NULLS'
return compiled

def visit_json_getitem_op_binary(self, binary, operator, **kw):
return self._render_json_extract_from_binary(binary, operator, **kw)

def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
return self._render_json_extract_from_binary(binary, operator, **kw)

def _render_json_extract_from_binary(self, binary, operator, **kw):
if binary.type._type_affinity is sqltypes.JSON:
return "JSON_EXTRACT(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)


class TrinoDDLCompiler(compiler.DDLCompiler):
pass
Expand Down
62 changes: 52 additions & 10 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

import sqlalchemy
from sqlalchemy import util
from sqlalchemy import func, util
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeDecorator, TypeEngine
from sqlalchemy.types import String
from sqlalchemy.types import JSON

SQLType = Union[TypeEngine, Type[TypeEngine]]

Expand Down Expand Up @@ -75,16 +74,59 @@ def __init__(self, precision=None, timezone=False):


class JSON(TypeDecorator):
impl = String
impl = JSON

def process_bind_param(self, value, dialect):
return json.dumps(value)
def bind_expression(self, bindvalue):
return func.JSON_PARSE(bindvalue)

def process_result_value(self, value, dialect):
return json.loads(value)

def get_col_spec(self, **kw):
return 'JSON'
class _FormatTypeMixin:
def _format_value(self, value):
raise NotImplementedError()

def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process

def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)

def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value

return process


class _JSONFormatter:
@staticmethod
def format_index(value):
return "$[\"%s\"]" % value

@staticmethod
def format_path(value):
return "$%s" % (
"".join(["[\"%s\"]" % elem for elem in value])
)


class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
return _JSONFormatter.format_index(value)


class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return _JSONFormatter.format_path(value)


# https://trino.io/docs/current/language/types.html
Expand Down
24 changes: 23 additions & 1 deletion trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from trino import dbapi as trino_dbapi
from trino import logging
Expand All @@ -31,10 +32,25 @@
from trino.dbapi import Cursor
from trino.sqlalchemy import compiler, datatype, error

from .datatype import JSONIndexType, JSONPathType

logger = logging.get_logger(__name__)

colspecs = {
sqltypes.JSON.JSONIndexType: JSONIndexType,
sqltypes.JSON.JSONPathType: JSONPathType,
}


class TrinoDialect(DefaultDialect):
def __init__(self,
json_serializer=None,
json_deserializer=None,
**kwargs):
DefaultDialect.__init__(self, **kwargs)
self._json_serializer = json_serializer
self._json_deserializer = json_deserializer

name = "trino"
driver = "rest"

Expand Down Expand Up @@ -70,6 +86,7 @@ class TrinoDialect(DefaultDialect):

# Support proper ordering of CTEs in regard to an INSERT statement
cte_follows_insert = True
colspecs = colspecs

@classmethod
def dbapi(cls):
Expand Down Expand Up @@ -280,7 +297,12 @@ def get_indexes(self, connection: Connection, table_name: str, schema: str = Non
if not self.has_table(connection, table_name, schema):
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")

partitioned_columns = self._get_columns(connection, f"{table_name}$partitions", schema, **kw)
partitioned_columns = None
try:
partitioned_columns = self._get_columns(connection, f"{table_name}$partitions", schema, **kw)
except Exception as e:
# e.g. it's not a Hive table or an unpartitioned Hive table
logger.debug("Couldn't fetch partition columns. schema: %s, table: %s, error: %s", schema, table_name, e)
if not partitioned_columns:
return []
partition_index = dict(
Expand Down

0 comments on commit 4e87bc3

Please sign in to comment.