Skip to content

Commit

Permalink
Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed May 31, 2022
1 parent 771eec3 commit bd14188
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
11 changes: 10 additions & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType
from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType):
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

elif isinstance(this, TIME):
assert this.precision == that.precision
assert this.timezone == that.timezone

elif isinstance(this, TIMESTAMP):
assert this.precision == that.precision
assert this.timezone == that.timezone

else:
assert str(this) == str(that)

Expand Down
30 changes: 18 additions & 12 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
ARRAY,
INTEGER,
DECIMAL,
DATE,
TIME,
TIMESTAMP,
DATE
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW
from trino.sqlalchemy.datatype import (
MAP,
ROW,
TIME,
TIMESTAMP
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
"CHAR(10)": CHAR(10),
"VARCHAR(10)": VARCHAR(10),
"DECIMAL(20)": DECIMAL(20),
"DECIMAL(20, 3)": DECIMAL(20, 3),
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"DECIMAL(20, 3)": DECIMAL(20, 3)
}


Expand Down Expand Up @@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
),
"row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
("min", TIMESTAMP(6, timezone=True)),
("max", TIMESTAMP(6, timezone=True)),
]
),
'row("first name" varchar, "last name" varchar)': ROW(
Expand Down Expand Up @@ -173,12 +175,16 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):


parse_datetime_testcases = {
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"date": DATE(),
"time": TIME(),
"time(3)": TIME(3, timezone=False),
"time(6)": TIME(6),
"time(12) with time zone": TIME(12, timezone=True),
"time with time zone": TIME(timezone=True),
"timestamp": TIMESTAMP(),
"timestamp with time zone": TIMESTAMP(timezone=True),
"timestamp(3)": TIMESTAMP(3, timezone=False),
"timestamp(6)": TIMESTAMP(6),
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
"timestamp with time zone": TIMESTAMP(timezone=True)
}


Expand All @@ -187,6 +193,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
actual_type = datatype.parse_sqltype(type_str)
assert_sqltype(actual_type, sql_type)
20 changes: 20 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ def visit_BLOB(self, type_, **kw):
def visit_DATETIME(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s%s" % (
"(%d)" % type_.precision
if getattr(type_, "precision", None) is not None
else "",
" WITH TIME ZONE"
if getattr(type_, "timezone", False)
else ""
)

def visit_TIME(self, type_, **kw):
return "TIME%s %s" % (
"(%d)" % type_.precision
if getattr(type_, "precision", None) is not None
else "",
" WITH TIME ZONE"
if getattr(type_, "timezone", False)
else ""
)


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
30 changes: 25 additions & 5 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,22 @@ def python_type(self):
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


# https://trino.io/docs/current/language/types.html
_type_map = {
# === Boolean ===
Expand All @@ -77,8 +93,10 @@ def python_type(self):
"json": sqltypes.JSON,
# === Date and time ===
"date": sqltypes.DATE,
"time": sqltypes.TIME,
"timestamp": sqltypes.TIMESTAMP,
"time": TIME,
"time with time zone": TIME,
"timestamp": TIMESTAMP,
"timestamp with time zone": TIMESTAMP,
# 'interval year to month':
# 'interval day to second':
#
Expand Down Expand Up @@ -193,7 +211,9 @@ def parse_sqltype(type_str: str) -> TypeEngine:
type_class = _type_map[type_name]
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
if type_name in ("time", "timestamp"):
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
return type_class(**type_kwargs)
if type_str.endswith("with time zone"):
type_kwargs = dict(timezone=True)
if type_opts is not None:
type_kwargs["precision"] = int(type_opts)
return type_class(**type_kwargs)
return type_class(*type_args)

0 comments on commit bd14188

Please sign in to comment.