diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 1051bf3d..872f4a35 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest -from sqlalchemy import Column, Integer, MetaData, String, Table, insert, select +from sqlalchemy import Column, Integer, MetaData, String, Table, func, insert, select from sqlalchemy.schema import CreateTable from sqlalchemy.sql import column, table @@ -113,3 +113,41 @@ def test_table_clause(dialect): statement = select(table("user", column("id"), column("name"), column("description"))) query = statement.compile(dialect=dialect) assert str(query) == 'SELECT user.id, user.name, user.description \nFROM user' + + +@pytest.mark.skipif( + sqlalchemy_version() < "1.4", + reason="columns argument to select() must be a Python list or other iterable" +) +@pytest.mark.parametrize( + 'function,element', + [ + ('first_value', func.first_value), + ('last_value', func.last_value), + ('nth_value', func.nth_value), + ('lead', func.lead), + ('lag', func.lag), + ] +) +def test_ignore_nulls(dialect, function, element): + statement = select( + element( + table_without_catalog.c.id, + ignore_nulls=True, + ).over(partition_by=table_without_catalog.c.name).label('window') + ) + query = statement.compile(dialect=dialect) + assert str(query) == \ + f'SELECT {function}("table".id) IGNORE NULLS OVER (PARTITION BY "table".name) AS window '\ + f'\nFROM "table"' + + statement = select( + element( + table_without_catalog.c.id, + ignore_nulls=False, + ).over(partition_by=table_without_catalog.c.name).label('window') + ) + query = statement.compile(dialect=dialect) + assert str(query) == \ + f'SELECT {function}("table".id) OVER (PARTITION BY "table".name) AS window ' \ + f'\nFROM "table"' diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index fef6beb1..6612be64 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -9,8 +9,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import compiler, sqltypes from sqlalchemy.sql.base import DialectKWArgs +from sqlalchemy.sql.functions import GenericFunction # https://trino.io/docs/current/language/reserved.html RESERVED_WORDS = { @@ -138,6 +140,41 @@ def _render_json_extract_from_binary(self, binary, operator, **kw): self.process(binary.right, **kw), ) + class GenericIgnoreNulls(GenericFunction): + ignore_nulls = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if kwargs.get('ignore_nulls'): + self.ignore_nulls = True + + class FirstValue(GenericIgnoreNulls): + name = 'first_value' + + class LastValue(GenericIgnoreNulls): + name = 'last_value' + + class NthValue(GenericIgnoreNulls): + name = 'nth_value' + + class Lead(GenericIgnoreNulls): + name = 'lead' + + class Lag(GenericIgnoreNulls): + name = 'lag' + + @staticmethod + @compiles(FirstValue) + @compiles(LastValue) + @compiles(NthValue) + @compiles(Lead) + @compiles(Lag) + def compile_ignore_nulls(element, compiler, **kwargs): + compiled = f'{element.name}({compiler.process(element.clauses)})' + if element.ignore_nulls: + compiled += ' IGNORE NULLS' + return compiled + class TrinoDDLCompiler(compiler.DDLCompiler): pass