Skip to content

Commit

Permalink
Add support for MAP type to SQLAlchemy dialect
Browse files Browse the repository at this point in the history
Co-authored-by: dudu butbul <[email protected]>
Co-authored-by: Ashhar Hasan <[email protected]>
  • Loading branch information
hashhar and dudu-upstream committed Feb 16, 2024
1 parent 143b6d8 commit b8d3360
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
54 changes: 53 additions & 1 deletion tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import math
import uuid
from decimal import Decimal

import pytest
import sqlalchemy as sqla
from sqlalchemy.sql import and_, not_, or_

from tests.integration.conftest import trino_version
from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON
from trino.sqlalchemy.datatype import JSON, MAP


@pytest.fixture
Expand Down Expand Up @@ -476,6 +478,56 @@ def test_json_column_operations(trino_connection):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
'trino_connection,map_object,sqla_type',
[
('memory', None, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)),
('memory', {}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)),
('memory', {True: False, False: True}, MAP(sqla.sql.sqltypes.Boolean, sqla.sql.sqltypes.Boolean)),
('memory', {1: 1, 2: None}, MAP(sqla.sql.sqltypes.Integer, sqla.sql.sqltypes.Integer)),
('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.Float, sqla.sql.sqltypes.Float)),
('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.REAL, sqla.sql.sqltypes.REAL)),
('memory',
{Decimal("1.2"): Decimal("1.2")},
MAP(sqla.sql.sqltypes.DECIMAL(2, 1), sqla.sql.sqltypes.DECIMAL(2, 1))),
('memory', {"hello": "world"}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.String)),
('memory', {"a ": "a", "null": "n"}, MAP(sqla.sql.sqltypes.CHAR(4), sqla.sql.sqltypes.CHAR(1))),
('memory', {b'': b'eh?', b'\x00': None}, MAP(sqla.sql.sqltypes.BINARY, sqla.sql.sqltypes.BINARY)),
],
indirect=['trino_connection']
)
def test_map_column(trino_connection, map_object, sqla_type):
engine, conn = trino_connection

if not engine.dialect.has_schema(conn, "test"):
with engine.begin() as connection:
connection.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
table_with_map = sqla.Table(
'table_with_map',
metadata,
sqla.Column('id', sqla.Integer),
sqla.Column('map_column', sqla_type),
schema="test"
)
metadata.create_all(engine)
ins = table_with_map.insert()
conn.execute(ins, {"id": 1, "map_column": map_object})
query = sqla.select(table_with_map)
result = conn.execute(query)
rows = result.fetchall()
assert len(rows) == 1
assert rows[0] == (1, map_object)
finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_get_catalog_names(trino_connection):
engine, conn = trino_connection
Expand Down
7 changes: 7 additions & 0 deletions trino/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ def visit_TIME(self, type_, **kw):
def visit_JSON(self, type_, **kw):
return 'JSON'

def visit_MAP(self, type_, **kw):
# the key and value types themselves need to be processed otherwise sqltypes.MAP(Float, Float) will get
# rendered as MAP(FLOAT, FLOAT) instead of MAP(REAL, REAL) or MAP(DOUBLE, DOUBLE)
key_type = self.process(type_.key_type, **kw)
value_type = self.process(type_.value_type, **kw)
return f'MAP({key_type}, {value_type})'


class TrinoIdentifierPreparer(compiler.IdentifierPreparer):
reserved_words = RESERVED_WORDS
Expand Down

0 comments on commit b8d3360

Please sign in to comment.