Skip to content

Commit

Permalink
Support for TIME(p) and TIMESTAMP(p) to SQLAlchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed May 31, 2022
1 parent 771eec3 commit 2eeb2dc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 19 deletions.
11 changes: 10 additions & 1 deletion tests/unit/sqlalchemy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from sqlalchemy.sql.sqltypes import ARRAY

from trino.sqlalchemy.datatype import MAP, ROW, SQLType
from trino.sqlalchemy.datatype import MAP, ROW, SQLType, TIMESTAMP, TIME


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -40,6 +40,15 @@ def _assert_sqltype(this: SQLType, that: SQLType):
for (this_attr, that_attr) in zip(this.attr_types, that.attr_types):
assert this_attr[0] == that_attr[0]
_assert_sqltype(this_attr[1], that_attr[1])

elif isinstance(this, TIME):
assert this.precision == that.precision
assert this.timezone == that.timezone

elif isinstance(this, TIMESTAMP):
assert this.precision == that.precision
assert this.timezone == that.timezone

else:
assert str(this) == str(that)

Expand Down
30 changes: 18 additions & 12 deletions tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
ARRAY,
INTEGER,
DECIMAL,
DATE,
TIME,
TIMESTAMP,
DATE
)
from sqlalchemy.sql.type_api import TypeEngine

from trino.sqlalchemy import datatype
from trino.sqlalchemy.datatype import MAP, ROW
from trino.sqlalchemy.datatype import (
MAP,
ROW,
TIME,
TIMESTAMP
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -65,8 +68,7 @@ def test_parse_cases(type_str: str, sql_type: TypeEngine, assert_sqltype):
"CHAR(10)": CHAR(10),
"VARCHAR(10)": VARCHAR(10),
"DECIMAL(20)": DECIMAL(20),
"DECIMAL(20, 3)": DECIMAL(20, 3),
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"DECIMAL(20, 3)": DECIMAL(20, 3)
}


Expand Down Expand Up @@ -142,8 +144,8 @@ def test_parse_map(type_str: str, sql_type: ARRAY, assert_sqltype):
),
"row(min timestamp(6) with time zone, max timestamp(6) with time zone)": ROW(
attr_types=[
("min", TIMESTAMP(timezone=True)),
("max", TIMESTAMP(timezone=True)),
("min", TIMESTAMP(6, timezone=True)),
("max", TIMESTAMP(6, timezone=True)),
]
),
'row("first name" varchar, "last name" varchar)': ROW(
Expand Down Expand Up @@ -173,12 +175,16 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):


parse_datetime_testcases = {
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
"date": DATE(),
"time": TIME(),
"time(3)": TIME(3, timezone=False),
"time(6)": TIME(6),
"time(12) with time zone": TIME(12, timezone=True),
"time with time zone": TIME(timezone=True),
"timestamp": TIMESTAMP(),
"timestamp with time zone": TIMESTAMP(timezone=True),
"timestamp(3)": TIMESTAMP(3, timezone=False),
"timestamp(6)": TIMESTAMP(6),
"timestamp(12) with time zone": TIMESTAMP(12, timezone=True),
"timestamp with time zone": TIMESTAMP(timezone=True)
}


Expand All @@ -187,6 +193,6 @@ def test_parse_row(type_str: str, sql_type: ARRAY, assert_sqltype):
parse_datetime_testcases.items(),
ids=parse_datetime_testcases.keys(),
)
def test_parse_datetime(type_str: str, sql_type: ARRAY, assert_sqltype):
def test_parse_datetime(type_str: str, sql_type: TypeEngine, assert_sqltype):
actual_type = datatype.parse_sqltype(type_str)
assert_sqltype(actual_type, sql_type)
20 changes: 20 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,26 @@ def visit_BLOB(self, type_, **kw):
def visit_DATETIME(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)

def visit_TIMESTAMP(self, type_, **kw):
return "TIMESTAMP%s%s" % (
"(%d)" % type_.precision
if getattr(type_, "precision", None) is not None
else "",
" WITH TIME ZONE"
if getattr(type_, "timezone", False)
else ""
)

def visit_TIME(self, type_, **kw):
return "TIME%s %s" % (
"(%d)" % type_.precision
if getattr(type_, "precision", None) is not None
else "",
" WITH TIME ZONE"
if getattr(type_, "timezone", False)
else ""
)


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
32 changes: 26 additions & 6 deletions trino/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Iterator, List, Optional, Tuple, Type, Union
from typing import Iterator, List, Optional, Tuple, Type, Union, Dict, Any

from sqlalchemy import util
from sqlalchemy.sql import sqltypes
Expand Down Expand Up @@ -55,6 +55,22 @@ def python_type(self):
return list


class TIME(sqltypes.TIME):
__visit_name__ = "TIME"

def __init__(self, precision=None, timezone=False):
super(TIME, self).__init__(timezone=timezone)
self.precision = precision


class TIMESTAMP(sqltypes.TIMESTAMP):
__visit_name__ = "TIMESTAMP"

def __init__(self, precision=None, timezone=False):
super(TIMESTAMP, self).__init__(timezone=timezone)
self.precision = precision


# https://trino.io/docs/current/language/types.html
_type_map = {
# === Boolean ===
Expand All @@ -77,8 +93,10 @@ def python_type(self):
"json": sqltypes.JSON,
# === Date and time ===
"date": sqltypes.DATE,
"time": sqltypes.TIME,
"timestamp": sqltypes.TIMESTAMP,
"time": TIME,
"time with time zone": TIME,
"timestamp": TIMESTAMP,
"timestamp with time zone": TIMESTAMP,
# 'interval year to month':
# 'interval day to second':
#
Expand Down Expand Up @@ -193,7 +211,9 @@ def parse_sqltype(type_str: str) -> TypeEngine:
type_class = _type_map[type_name]
type_args = [int(o.strip()) for o in type_opts.split(",")] if type_opts else []
if type_name in ("time", "timestamp"):
type_kwargs = dict(timezone=type_str.endswith("with time zone"))
# TODO: support parametric timestamps (https://github.com/trinodb/trino-python-client/issues/107)
return type_class(**type_kwargs)
if type_str.endswith("with time zone"):
type_kwargs: Dict[str, Any] = dict(timezone=True)
if type_opts is not None:
type_kwargs["precision"] = int(type_opts)
return type_class(**type_kwargs)
return type_class(*type_args)

0 comments on commit 2eeb2dc

Please sign in to comment.