From c7ca7cd4017e1e09825194221cb791a68ac67dbc Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Mon, 2 Oct 2023 11:52:33 -0400
Subject: [PATCH] refactor(clickhouse): use more sqlglot constructs

---
 ibis/backends/base/sqlglot/__init__.py      |  19 ++
 ibis/backends/clickhouse/__init__.py        | 105 +++++---
 ibis/backends/clickhouse/compiler/values.py | 285 ++++++++++----------
 3 files changed, 230 insertions(+), 179 deletions(-)

diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py
index 3381d584a24f..d16500d05662 100644
--- a/ibis/backends/base/sqlglot/__init__.py
+++ b/ibis/backends/base/sqlglot/__init__.py
@@ -52,11 +52,30 @@ def exists(self, query):
     def concat(self, *args):
         return sg.exp.Concat.from_arg_list(list(map(_to_sqlglot, args)))
 
+    def map(self, keys, values):
+        return sg.exp.Map(keys=keys, values=values)
+
+
+class ColGen:
+    __slots__ = ()
+
+    def __getattr__(self, name: str) -> sg.exp.Column:
+        return sg.column(name)
+
+    def __getitem__(self, key: str) -> sg.exp.Column:
+        return sg.column(key)
+
 
 def lit(val):
     return sg.exp.Literal(this=str(val), is_string=isinstance(val, str))
 
 
+def interval(value, *, unit):
+    return sg.exp.Interval(this=_to_sqlglot(value), unit=sg.exp.var(unit))
+
+
+F = FuncGen()
+C = ColGen()
 NULL = sg.exp.NULL
 FALSE = sg.exp.FALSE
 TRUE = sg.exp.TRUE
diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py
index f6484cb953ad..a5a26df44745 100644
--- a/ibis/backends/clickhouse/__init__.py
+++ b/ibis/backends/clickhouse/__init__.py
@@ -23,6 +23,7 @@ import ibis.expr.types as ir
 from ibis import util
 from ibis.backends.base import BaseBackend, CanCreateDatabase
+from ibis.backends.base.sqlglot import STAR, C, F, lit
 from ibis.backends.clickhouse.compiler import translate
 from ibis.backends.clickhouse.datatypes import ClickhouseType
@@ -176,12 +177,14 @@ def version(self) -> str:
 
     @property
     def current_database(self) -> str:
-        with closing(self.raw_sql("SELECT currentDatabase()")) as result:
+        with closing(self.raw_sql(sg.select(F.currentDatabase()))) as result:
             [(db,)] = result.result_rows
         return db
 
     def list_databases(self, like: str | None = None) -> list[str]:
-        with closing(self.raw_sql("SELECT name FROM system.databases")) as result:
+        with closing(
+            self.raw_sql(sg.select(C.name).from_(sg.table("databases", db="system")))
+        ) as result:
             results = result.result_columns
 
         if results:
@@ -193,14 +196,14 @@ def list_databases(self, like: str | None = None) -> list[str]:
     def list_tables(
         self, like: str | None = None, database: str | None = None
     ) -> list[str]:
-        query = "SELECT name FROM system.tables WHERE"
+        query = sg.select(C.name).from_(sg.table("tables", db="system"))
 
         if database is None:
-            database = "currentDatabase()"
+            database = F.currentDatabase()
         else:
-            database = f"'{database}'"
+            database = lit(database)
 
-        query += f" database = {database} OR is_temporary"
+        query = query.where(C.database.eq(database).or_(C.is_temporary))
 
         with closing(self.raw_sql(query)) as result:
             results = result.result_columns
@@ -377,7 +380,10 @@ def execute(
         # in single column conversion and whole table conversion
         return expr.__pandas_result__(table.__pandas_result__(df))
 
-    def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any):
+    def _to_sqlglot(
+        self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
+    ):
+        """Compile an Ibis expression to a sqlglot object."""
         table_expr = expr.as_table()
 
         if limit == "default":
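The `F` (function) and `C` (column) generators added above turn attribute access into sqlglot expression nodes, so queries like the one in `list_tables` are built as trees rather than interpolated strings. A minimal sketch of the equivalent raw sqlglot calls (illustrative, not part of the patch; assumes sqlglot 17.x):

    import sqlglot as sg

    query = (
        sg.select(sg.column("name"))
        .from_(sg.table("tables", db="system"))
        .where(sg.column("database").eq(sg.exp.Literal(this="ibis", is_string=True)))
    )
    # SELECT name FROM system.tables WHERE database = 'ibis'
    print(query.sql(dialect="clickhouse"))
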
@@ -392,13 +398,21 @@ def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any):
         assert not isinstance(sql, sg.exp.Subquery)
 
         if isinstance(sql, sg.exp.Table):
-            sql = sg.select("*").from_(sql)
+            sql = sg.select(STAR).from_(sql)
 
         assert not isinstance(sql, sg.exp.Subquery)
-        return sql.sql(dialect="clickhouse", pretty=True)
+        return sql
+
+    def compile(
+        self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
+    ):
+        """Compile an Ibis expression to a ClickHouse SQL string."""
+        return self._to_sqlglot(expr, limit=limit, params=params, **kwargs).sql(
+            dialect=self.name, pretty=True
+        )
 
     def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
-        return str(self.compile(expr, **kwargs))
+        return self.compile(expr, **kwargs)
 
     def table(self, name: str, database: str | None = None) -> ir.Table:
         """Construct a table expression.
@@ -444,7 +458,7 @@ def insert(
 
     def raw_sql(
         self,
-        query: str,
+        query: str | sg.exp.Expression,
        external_tables: Mapping[str, pd.DataFrame] | None = None,
         **kwargs,
     ) -> Any:
@@ -467,6 +481,8 @@ def raw_sql(
         """
         external_tables = toolz.valmap(_to_memtable, external_tables or {})
         external_data = self._normalize_external_tables(external_tables)
+        with suppress(AttributeError):
+            query = query.sql(dialect=self.name, pretty=True)
         self._log(query)
         return self.con.query(query, external_data=external_data, **kwargs)
@@ -501,8 +517,7 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema:
         sch.Schema
             Ibis schema
         """
-        qualified_name = self._fully_qualified_name(table_name, database)
-        query = f"DESCRIBE {qualified_name}"
+        query = sg.exp.Describe(this=sg.table(table_name, db=database))
         with closing(self.raw_sql(query)) as results:
             names, types, *_ = results.result_columns
         return sch.Schema(dict(zip(names, map(ClickhouseType.from_string, types))))
@@ -528,34 +543,39 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
     def create_database(
         self, name: str, *, force: bool = False, engine: str = "Atomic"
     ) -> None:
-        if_not_exists = "IF NOT EXISTS " * force
-        with closing(
-            self.raw_sql(f"CREATE DATABASE {if_not_exists}{name} ENGINE = {engine}")
-        ):
+        src = sg.exp.Create(
+            this=sg.to_identifier(name),
+            kind="DATABASE",
+            exists=force,
+            properties=sg.exp.Properties(
+                expressions=[sg.exp.EngineProperty(this=sg.to_identifier(engine))]
+            ),
+        )
+        with closing(self.raw_sql(src)):
             pass
 
     def drop_database(self, name: str, *, force: bool = False) -> None:
-        if_exists = "IF EXISTS " * force
-        with closing(self.raw_sql(f"DROP DATABASE {if_exists}{name}")):
+        src = sg.exp.Drop(this=sg.to_identifier(name), kind="DATABASE", exists=force)
+        with closing(self.raw_sql(src)):
             pass
 
     def truncate_table(self, name: str, database: str | None = None) -> None:
-        ident = self._fully_qualified_name(name, database)
+        ident = sg.table(name, db=database).sql(self.name)
         with closing(self.raw_sql(f"TRUNCATE TABLE {ident}")):
             pass
 
     def drop_table(
         self, name: str, database: str | None = None, force: bool = False
     ) -> None:
-        ident = self._fully_qualified_name(name, database)
-        with closing(self.raw_sql(f"DROP TABLE {'IF EXISTS ' * force}{ident}")):
+        src = sg.exp.Drop(this=sg.table(name, db=database), kind="TABLE", exists=force)
+        with closing(self.raw_sql(src)):
            pass
 
     def read_parquet(
         self,
         path: str | Path,
         table_name: str | None = None,
-        engine: str = "File(Native)",
+        engine: str = "MergeTree",
         **kwargs: Any,
     ) -> ir.Table:
         import pyarrow.dataset as ds
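Building DDL as `sg.exp.Create`/`sg.exp.Drop` nodes, as above, delegates identifier quoting and `IF [NOT] EXISTS` handling to the dialect instead of string multiplication tricks. A sketch of what the `create_database` construction renders to (illustrative, not part of the patch):

    import sqlglot as sg

    src = sg.exp.Create(
        this=sg.to_identifier("mydb"),
        kind="DATABASE",
        exists=True,  # emits IF NOT EXISTS
        properties=sg.exp.Properties(
            expressions=[sg.exp.EngineProperty(this=sg.to_identifier("Atomic"))]
        ),
    )
    # roughly: CREATE DATABASE IF NOT EXISTS mydb ENGINE=Atomic
    print(src.sql(dialect="clickhouse"))

    drop = sg.exp.Drop(this=sg.to_identifier("mydb"), kind="DATABASE", exists=True)
    # roughly: DROP DATABASE IF EXISTS mydb
    print(drop.sql(dialect="clickhouse"))
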
= "File(Native)", + engine: str = "MergeTree", **kwargs: Any, ) -> ir.Table: import pyarrow.dataset as ds @@ -611,7 +631,7 @@ def create_table( temp: bool = False, overwrite: bool = False, # backend specific arguments - engine: str = "File(Native)", + engine: str = "MergeTree", order_by: Iterable[str] | None = None, partition_by: Iterable[str] | None = None, sample_by: str | None = None, @@ -636,7 +656,9 @@ def create_table( Whether to overwrite the table engine The table engine to use. See [ClickHouse's `CREATE TABLE` documentation](https://clickhouse.com/docs/en/sql-reference/statements/create/table) - for specifics. + for specifics. Defaults to [`MergeTree`](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree) + with `ORDER BY tuple()` because `MergeTree` is the most + feature-complete engine. order_by String column names to order by. Required for some table engines like `MergeTree`. partition_by @@ -681,6 +703,10 @@ def create_table( if order_by is not None: code += f" ORDER BY {', '.join(util.promote_list(order_by))}" + elif engine == "MergeTree": + # empty tuple to indicate no specific order when engine is + # MergeTree + code += " ORDER BY tuple()" if partition_by is not None: code += f" PARTITION BY {', '.join(util.promote_list(partition_by))}" @@ -713,21 +739,22 @@ def create_view( database: str | None = None, overwrite: bool = False, ) -> ir.Table: - qualname = self._fully_qualified_name(name, database) - replace = "OR REPLACE " * overwrite - query = self.compile(obj) - code = f"CREATE {replace}VIEW {qualname} AS {query}" + src = sg.exp.Create( + this=sg.table(name, db=database), + kind="VIEW", + replace=overwrite, + expression=self._to_sqlglot(obj), + ) external_tables = self._collect_in_memory_tables(obj) - with closing(self.raw_sql(code, external_tables=external_tables)): + with closing(self.raw_sql(src, external_tables=external_tables)): pass return self.table(name, database=database) def drop_view( self, name: str, *, database: str | None = None, force: bool = False ) -> None: - name = self._fully_qualified_name(name, database) - if_exists = "IF EXISTS " * force - with closing(self.raw_sql(f"DROP VIEW {if_exists}{name}")): + src = sg.exp.Drop(this=sg.table(name, db=database), kind="VIEW", exists=force) + with closing(self.raw_sql(src)): pass def _load_into_cache(self, name, expr): @@ -742,12 +769,9 @@ def _create_temp_view(self, table_name, source): f"{table_name} already exists as a non-temporary table or view" ) src = sg.exp.Create( - this=sg.table(table_name), # CREATE ... 'table_name' - kind="VIEW", # VIEW - replace=True, # OR REPLACE - expression=source, # AS ... 
@@ -742,12 +769,9 @@ def _create_temp_view(self, table_name, source):
                 f"{table_name} already exists as a non-temporary table or view"
             )
         src = sg.exp.Create(
-            this=sg.table(table_name),  # CREATE ... 'table_name'
-            kind="VIEW",  # VIEW
-            replace=True,  # OR REPLACE
-            expression=source,  # AS ...
+            this=sg.table(table_name), kind="VIEW", replace=True, expression=source
         )
-        self.raw_sql(src.sql(dialect=self.name, pretty=True))
+        self.raw_sql(src)
         self._temp_views.add(table_name)
         self._register_temp_view_cleanup(table_name)
 
@@ -756,6 +780,5 @@ def drop(self, name: str, query: str):
             self.raw_sql(query)
             self._temp_views.discard(name)
 
-        src = sg.exp.Drop(this=sg.table(name), kind="VIEW", exists=True)
-        query = src.sql(self.name, pretty=True)
+        query = sg.exp.Drop(this=sg.table(name), kind="VIEW", exists=True)
         atexit.register(drop, self, name=name, query=query)
diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py
index 5dcecc1299f7..85b225399aba 100644
--- a/ibis/backends/clickhouse/compiler/values.py
+++ b/ibis/backends/clickhouse/compiler/values.py
@@ -13,7 +13,16 @@
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis import util
-from ibis.backends.base.sqlglot import NULL, STAR, AggGen, FuncGen, lit, make_cast
+from ibis.backends.base.sqlglot import (
+    NULL,
+    STAR,
+    AggGen,
+    C,
+    F,
+    interval,
+    lit,
+    make_cast,
+)
 from ibis.backends.clickhouse.datatypes import ClickhouseType
 
 # TODO: This is a hack to get around the fact that sqlglot 17.8.6 is broken for
@@ -28,14 +37,13 @@ def _aggregate(funcname, *args, where):
     has_filter = where is not None
-    func = f[funcname + "If" * has_filter]
+    func = F[funcname + "If" * has_filter]
     args += (where,) * has_filter
     return func(*args)
 
 
-f = FuncGen()
 agg = AggGen(aggfunc=_aggregate)
-if_ = f["if"]
+if_ = F["if"]
 cast = make_cast(ClickhouseType)
@@ -52,7 +60,7 @@ def _column(op, *, table, name, **_):
 
 @translate_val.register(ops.Alias)
 def _alias(op, *, arg, name, **_):
-    return sg.alias(arg, name)
+    return arg.as_(name)
 
 
 _interval_cast_suffixes = {
@@ -71,17 +79,17 @@ def _alias(op, *, arg, name, **_):
 def _cast(op, *, arg, to, **_):
     if to.is_interval():
         suffix = _interval_cast_suffixes[to.unit.short]
-        return f[f"toInterval{suffix}"](arg)
+        return F[f"toInterval{suffix}"](arg)
 
     result = cast(arg, to)
     if (timezone := getattr(to, "timezone", None)) is not None:
-        return f.toTimeZone(result, timezone)
+        return F.toTimeZone(result, timezone)
     return result
 
 
 @translate_val.register(ops.TryCast)
 def _try_cast(op, *, arg, to, **_):
-    return f.accurateCastOrNull(arg, ClickhouseType.to_string(to))
+    return F.accurateCastOrNull(arg, ClickhouseType.to_string(to))
 
 
 @translate_val.register(ops.Between)
@@ -100,25 +108,25 @@ def _not(op, *, arg, **_):
 
 
 def _parenthesize(op, arg):
-    # function calls don't need parens
     if isinstance(op, (ops.Binary, ops.Unary)):
         return sg.exp.Paren(this=arg)
     else:
+        # function calls don't need parens
         return arg
 
 
 @translate_val.register(ops.ArrayIndex)
 def _array_index_op(op, *, arg, index, **_):
-    return f.arrayElement(arg, if_(index >= 0, index + 1, index))
+    return arg[if_(index >= 0, index + 1, index)]
 
 
 @translate_val.register(ops.ArrayRepeat)
 def _array_repeat_op(op, *, arg, times, **_):
     return (
-        sg.select(f.arrayFlatten(f.groupArray(sg.column("arr"))))
+        sg.select(F.arrayFlatten(F.groupArray(C.arr)))
         .from_(
-            sg.select(sg.alias(arg, "arr"))
-            .from_(sg.table(table="numbers", catalog="system"))
+            sg.select(arg.as_("arr"))
+            .from_(sg.table("numbers", db="system"))
             .limit(times)
             .subquery()
         )
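The `_aggregate` helper above leans on ClickHouse's `-If` combinator: any aggregate gains a filtered form by appending `If` to its name and passing the predicate as a trailing argument. A standalone sketch of the pattern (not part of the patch):

    import sqlglot as sg

    def aggregate(funcname, *args, where=None):
        # "sum" becomes "sumIf" when a filter is present
        if where is not None:
            funcname += "If"
            args += (where,)
        return sg.func(funcname, *args)

    expr = aggregate("sum", sg.column("x"), where=sg.column("x") > 0)
    print(expr.sql(dialect="clickhouse"))  # sumIf(x, x > 0)
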
@@ -134,33 +142,36 @@ def _array_slice_op(op, *, arg, start, stop, **_):
 
     if stop is not None:
         stop = _parenthesize(op.stop, stop)
 
-        neg_start = f.length(arg) + start
-        diff = lambda v: f.greatest(0, stop - v)
-
-        length = if_(stop < 0, stop, if_(start < 0, diff(neg_start), diff(start)))
-        return f.arraySlice(arg, start_correct, length)
+        length = if_(
+            stop < 0,
+            stop,
+            if_(
+                start < 0,
+                F.greatest(0, stop - (F.length(arg) + start)),
+                F.greatest(0, stop - start),
+            ),
+        )
+        return F.arraySlice(arg, start_correct, length)
     else:
-        return f.arraySlice(arg, start_correct)
+        return F.arraySlice(arg, start_correct)
 
 
 @translate_val.register(ops.CountStar)
 def _count_star(op, *, where, **_):
     if where is not None:
-        return f.countIf(where)
+        return F.countIf(where)
     return sg.exp.Count(this=STAR)
 
 
-def _quantile(func):
+def _quantile(func: str):
     def _compile(op, *, arg, quantile, where, **_):
-        funcname = func
-        args = [arg]
-
-        if where is not None:
-            funcname += "If"
-            args.append(where)
+        if where is None:
+            return agg.quantile(arg, quantile, where=where)
 
         return sg.exp.ParameterizedAgg(
-            this=funcname, expressions=util.promote_list(quantile), params=args
+            this=f"{func}If",
+            expressions=util.promote_list(quantile),
+            params=[arg, where],
         )
 
     return _compile
@@ -202,8 +213,8 @@ def _arbitrary(op, *, arg, how, where, **_):
 def _substring(op, *, arg, start, length, **_):
     # Clickhouse is 1-indexed
     suffix = (length,) * (length is not None)
-    if_pos = f.substring(arg, start + 1, *suffix)
-    if_neg = f.substring(arg, f.length(arg) + start + 1, *suffix)
+    if_pos = F.substring(arg, start + 1, *suffix)
+    if_neg = F.substring(arg, F.length(arg) + start + 1, *suffix)
 
     return if_(start >= 0, if_pos, if_neg)
@@ -213,9 +224,9 @@ def _string_find(op, *, arg, substr, start, end, **_):
         raise com.UnsupportedOperationError("String find doesn't support end argument")
 
     if start is not None:
-        return f.locate(arg, substr, start) - 1
+        return F.locate(arg, substr, start) - 1
 
-    return f.locate(arg, substr) - 1
+    return F.locate(arg, substr) - 1
 
 
 @translate_val.register(ops.RegexSearch)
@@ -227,39 +238,39 @@ def _regex_search(op, *, arg, pattern, **_):
 def _regex_extract(op, *, arg, pattern, index, **_):
     arg = cast(arg, dt.String(nullable=False))
 
-    pattern = f.concat("(", pattern, ")")
+    pattern = F.concat("(", pattern, ")")
 
     if index is None:
         index = 0
 
     index += 1
 
-    then = f.extractGroups(arg, pattern)[index]
+    then = F.extractGroups(arg, pattern)[index]
 
-    return if_(f.notEmpty(then), then, NULL)
+    return if_(F.notEmpty(then), then, NULL)
 
 
 @translate_val.register(ops.FindInSet)
 def _index_of(op, *, needle, values, **_):
-    return f.indexOf(f.array(*values), needle) - 1
+    return F.indexOf(F.array(*values), needle) - 1
 
 
 @translate_val.register(ops.Round)
 def _round(op, *, arg, digits, **_):
     if digits is not None:
-        return f.round(arg, digits)
-    return f.round(arg)
+        return F.round(arg, digits)
+    return F.round(arg)
 
 
 @translate_val.register(ops.Sign)
 def _sign(op, *, arg, **_):
     """Workaround for missing sign function."""
-    return f.intDivOrZero(arg, f.abs(arg))
+    return F.intDivOrZero(arg, F.abs(arg))
 
 
 @translate_val.register(ops.Hash)
 def _hash(op, *, arg, **_):
-    return f.sipHash64(arg)
+    return F.sipHash64(arg)
 
 
 _SUPPORTED_ALGORITHMS = frozenset(
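The `_quantile` rewrite above uses sqlglot's `ParameterizedAgg` for ClickHouse's two-argument-list aggregate syntax, where the quantile level sits in its own parameter list ahead of the arguments. An illustrative sketch (not part of the patch; exact rendering depends on the sqlglot version):

    import sqlglot as sg

    q = sg.exp.ParameterizedAgg(
        this="quantileIf",
        expressions=[sg.exp.Literal(this="0.5", is_string=False)],  # parameter list
        params=[sg.column("x"), sg.column("x") > 0],  # argument list
    )
    # roughly: quantileIf(0.5)(x, x > 0)
    print(q.sql(dialect="clickhouse"))
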
@@ -283,28 +294,28 @@ def _hash_bytes(op, *, arg, how, **_):
     if how not in _SUPPORTED_ALGORITHMS:
         raise com.UnsupportedOperationError(f"Unsupported hash algorithm {how}")
 
-    return f[how](arg)
+    return F[how](arg)
 
 
 @translate_val.register(ops.Log)
 def _log(op, *, arg, base, **_):
     if base is None:
-        return f.ln(arg)
+        return F.ln(arg)
     elif str(base) in ("2", "10"):
-        return f[f"log{base}"](arg)
+        return F[f"log{base}"](arg)
     else:
-        return f.ln(arg) / f.ln(base)
+        return F.ln(arg) / F.ln(base)
 
 
 @translate_val.register(ops.IntervalFromInteger)
 def _interval_from_integer(op, *, arg, unit, **_):
     dtype = op.dtype
-    if dtype.unit.short in {"ms", "us", "ns"}:
+    if dtype.unit.short in ("ms", "us", "ns"):
         raise com.UnsupportedOperationError(
             "Clickhouse doesn't support subsecond interval resolutions"
         )
 
-    return sg.exp.Interval(this=arg, unit=sg.exp.var(dtype.resolution.upper()))
+    return interval(arg, unit=dtype.resolution.upper())
 
 
 @translate_val.register(ops.Literal)
@@ -317,7 +328,7 @@ def _literal(op, *, value, dtype, **kw):
         return lit(bool(value))
     elif dtype.is_inet():
         v = str(value)
-        return f.toIPv6(v) if ":" in v else f.toIPv4(v)
+        return F.toIPv6(v) if ":" in v else F.toIPv4(v)
     elif dtype.is_string():
         return lit(str(value).replace(r"\0", r"\\0"))
     elif dtype.is_macaddr():
@@ -330,13 +341,13 @@ def _literal(op, *, value, dtype, **kw):
         )
 
         if 1 <= precision <= 9:
-            type_name = f.toDecimal32
+            type_name = F.toDecimal32
         elif 10 <= precision <= 18:
-            type_name = f.toDecimal64
+            type_name = F.toDecimal64
         elif 19 <= precision <= 38:
-            type_name = f.toDecimal128
+            type_name = F.toDecimal128
         else:
-            type_name = f.toDecimal256
+            type_name = F.toDecimal256
         return type_name(value, dtype.scale)
     elif dtype.is_numeric():
         return lit(value)
@@ -347,9 +358,7 @@ def _literal(op, *, value, dtype, **kw):
             "Clickhouse doesn't support subsecond interval resolutions"
         )
 
-        return sg.exp.Interval(
-            this=lit(value), unit=sg.exp.var(dtype.resolution.upper())
-        )
+        return interval(value, unit=dtype.resolution.upper())
     elif dtype.is_timestamp():
         funcname = "toDateTime"
         fmt = "%Y-%m-%dT%H:%M:%S"
@@ -368,16 +377,16 @@ def _literal(op, *, value, dtype, **kw):
         if (timezone := dtype.timezone) is not None:
             args.append(timezone)
 
-        return f[funcname](*args)
+        return F[funcname](*args)
     elif dtype.is_date():
-        return f.toDate(value.strftime("%Y-%m-%d"))
+        return F.toDate(value.strftime("%Y-%m-%d"))
     elif dtype.is_array():
         value_type = dtype.value_type
         values = [
             _literal(ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw)
             for v in value
         ]
-        return f.array(*values)
+        return F.array(*values)
     elif dtype.is_map():
         value_type = dtype.value_type
         keys = []
@@ -391,13 +400,13 @@ def _literal(op, *, value, dtype, **kw):
             )
         )
 
-        return sg.exp.Map(keys=f.array(*keys), values=f.array(*values))
+        return F.map(F.array(*keys), F.array(*values))
     elif dtype.is_struct():
         fields = [
             _literal(ops.Literal(v, dtype=field_type), value=v, dtype=field_type, **kw)
             for field_type, v in zip(dtype.types, value.values())
         ]
-        return f.tuple(*fields)
+        return F.tuple(*fields)
     else:
         raise NotImplementedError(f"Unsupported type: {dtype!r}")
@@ -417,7 +426,7 @@ def _table_array_view(op, *, table, **_):
 def _timestamp_from_unix(op, *, arg, unit, **_):
     if (unit := unit.short) in {"ms", "us", "ns"}:
         raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!")
-    return f.toDateTime(arg)
+    return F.toDateTime(arg)
 
 
 @translate_val.register(ops.DateTruncate)
@@ -425,13 +434,13 @@ def _timestamp_from_unix(op, *, arg, unit, **_):
 @translate_val.register(ops.TimeTruncate)
 def _truncate(op, *, arg, unit, **_):
     converters = {
-        "Y": f.toStartOfYear,
-        "M": f.toStartOfMonth,
-        "W": f.toMonday,
-        "D": f.toDate,
-        "h": f.toStartOfHour,
-        "m": f.toStartOfMinute,
-        "s": f.toDateTime,
+        "Y": F.toStartOfYear,
+        "M": F.toStartOfMonth,
+        "W": F.toMonday,
+        "D": F.toDate,
+        "h": F.toStartOfHour,
+        "m": F.toStartOfMinute,
+        "s": F.toDateTime,
     }
 
     unit = unit.short
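The decimal branch in `_literal` above picks a fixed-width ClickHouse constructor from the type's precision. The selection ladder, extracted as a pure-Python sketch (hypothetical helper name):

    def decimal_func_name(precision: int) -> str:
        # matches the toDecimal32/64/128/256 ladder in _literal above
        if 1 <= precision <= 9:
            return "toDecimal32"
        elif 10 <= precision <= 18:
            return "toDecimal64"
        elif 19 <= precision <= 38:
            return "toDecimal128"
        else:
            return "toDecimal256"

    assert decimal_func_name(12) == "toDecimal64"
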
@@ -443,36 +452,36 @@ def _truncate(op, *, arg, unit, **_):
 
 @translate_val.register(ops.DateFromYMD)
 def _date_from_ymd(op, *, year, month, day, **_):
-    return f.toDate(
-        f.concat(
-            f.toString(year),
+    return F.toDate(
+        F.concat(
+            F.toString(year),
             "-",
-            f.leftPad(f.toString(month), 2, "0"),
+            F.leftPad(F.toString(month), 2, "0"),
             "-",
-            f.leftPad(f.toString(day), 2, "0"),
+            F.leftPad(F.toString(day), 2, "0"),
         )
     )
 
 
 @translate_val.register(ops.TimestampFromYMDHMS)
 def _timestamp_from_ymdhms(op, *, year, month, day, hours, minutes, seconds, **_):
-    to_datetime = f.toDateTime(
-        f.concat(
-            f.toString(year),
+    to_datetime = F.toDateTime(
+        F.concat(
+            F.toString(year),
             "-",
-            f.leftPad(f.toString(month), 2, "0"),
+            F.leftPad(F.toString(month), 2, "0"),
             "-",
-            f.leftPad(f.toString(day), 2, "0"),
+            F.leftPad(F.toString(day), 2, "0"),
             " ",
-            f.leftPad(f.toString(hours), 2, "0"),
+            F.leftPad(F.toString(hours), 2, "0"),
             ":",
-            f.leftPad(f.toString(minutes), 2, "0"),
+            F.leftPad(F.toString(minutes), 2, "0"),
             ":",
-            f.leftPad(f.toString(seconds), 2, "0"),
+            F.leftPad(F.toString(seconds), 2, "0"),
         )
     )
     if timezone := op.dtype.timezone:
-        return f.toTimeZone(to_datetime, timezone)
+        return F.toTimeZone(to_datetime, timezone)
     return to_datetime
@@ -482,22 +491,22 @@ def _exists_subquery(op, *, foreign_table, predicates, **_):
     #
     # this would work if clickhouse supported correlated subqueries
     subq = sg.select(1).from_(foreign_table).where(sg.condition(predicates)).subquery()
-    return f.exists(subq)
+    return F.exists(subq)
 
 
 @translate_val.register(ops.StringSplit)
 def _string_split(op, *, arg, delimiter, **_):
-    return f.splitByString(delimiter, cast(arg, dt.String(nullable=False)))
+    return F.splitByString(delimiter, cast(arg, dt.String(nullable=False)))
 
 
 @translate_val.register(ops.StringJoin)
 def _string_join(op, *, sep, arg, **_):
-    return f.arrayStringConcat(f.array(*arg), sep)
+    return F.arrayStringConcat(F.array(*arg), sep)
 
 
 @translate_val.register(ops.StringConcat)
 def _string_concat(op, *, arg, **_):
-    return f.concat(*arg)
+    return F.concat(*arg)
 
 
 @translate_val.register(ops.StringSQLLike)
@@ -512,31 +521,31 @@ def _string_ilike(op, *, arg, pattern, **_):
 
 @translate_val.register(ops.Capitalize)
 def _string_capitalize(op, *, arg, **_):
-    return f.concat(f.upper(f.substr(arg, 1, 1)), f.lower(f.substr(arg, 2)))
+    return F.concat(F.upper(F.substr(arg, 1, 1)), F.lower(F.substr(arg, 2)))
 
 
 @translate_val.register(ops.GroupConcat)
 def _group_concat(op, *, arg, sep, where, **_):
     call = agg.groupArray(arg, where=where)
-    return if_(f.empty(call), NULL, f.arrayStringConcat(call, sep))
+    return if_(F.empty(call), NULL, F.arrayStringConcat(call, sep))
 
 
 @translate_val.register(ops.StrRight)
 def _string_right(op, *, arg, nchars, **_):
     nchars = _parenthesize(op.nchars, nchars)
-    return f.substring(arg, -nchars)
+    return F.substring(arg, -nchars)
 
 
 @translate_val.register(ops.Cot)
 def _cotangent(op, *, arg, **_):
-    return 1.0 / f.tan(arg)
+    return 1.0 / F.tan(arg)
 
 
 def _bit_agg(func: str):
     def _translate(op, *, arg, where, **_):
         if not (dtype := op.arg.dtype).is_unsigned_integer():
             nbits = dtype.nbytes * 8
-            arg = f[f"reinterpretAsUInt{nbits}"](arg)
+            arg = F[f"reinterpretAsUInt{nbits}"](arg)
         return agg[func](arg, where=where)
 
     return _translate
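`_date_from_ymd` and `_timestamp_from_ymdhms` above build an ISO-formatted string with zero-padded components before handing it to toDate/toDateTime. The padding, as a pure-Python sketch (hypothetical helper name):

    def ymd_string(year: int, month: int, day: int) -> str:
        # mirrors the leftPad(toString(...), 2, '0') calls above
        return f"{year}-{str(month).zfill(2)}-{str(day).zfill(2)}"

    assert ymd_string(2023, 2, 3) == "2023-02-03"
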
@@ -544,23 +553,23 @@ def _translate(op, *, arg, where, **_):
 
 @translate_val.register(ops.ArrayColumn)
 def _array_column(op, *, cols, **_):
-    return f.array(*cols)
+    return F.array(*cols)
 
 
 @translate_val.register(ops.StructColumn)
 def _struct_column(op, *, values, **_):
     # ClickHouse struct types cannot be nullable
     # (non-nested fields can be nullable)
-    return cast(f.tuple(*values), op.dtype.copy(nullable=False))
+    return cast(F.tuple(*values), op.dtype.copy(nullable=False))
 
 
 @translate_val.register(ops.Clip)
 def _clip(op, *, arg, lower, upper, **_):
     if upper is not None:
-        arg = if_(f.isNull(arg), NULL, f.least(upper, arg))
+        arg = if_(F.isNull(arg), NULL, F.least(upper, arg))
 
     if lower is not None:
-        arg = if_(f.isNull(arg), NULL, f.greatest(lower, arg))
+        arg = if_(F.isNull(arg), NULL, F.greatest(lower, arg))
 
     return arg
@@ -574,22 +583,22 @@ def _struct_field(op, *, arg, field: str, **_):
 
 @translate_val.register(ops.NthValue)
 def _nth_value(op, *, arg, nth, **_):
-    return f.nth_value(arg, _parenthesize(op.nth, nth) + 1)
+    return F.nth_value(arg, _parenthesize(op.nth, nth) + 1)
 
 
 @translate_val.register(ops.Repeat)
 def _repeat(op, *, arg, times, **_):
-    return f.repeat(arg, f.accurateCast(times, "UInt64"))
+    return F.repeat(arg, F.accurateCast(times, "UInt64"))
 
 
 @translate_val.register(ops.FloorDivide)
 def _floor_divide(op, *, left, right, **_):
-    return f.floor(left / right)
+    return F.floor(left / right)
 
 
 @translate_val.register(ops.StringContains)
 def _string_contains(op, haystack, needle, **_):
-    return f.locate(haystack, needle) > 0
+    return F.locate(haystack, needle) > 0
 
 
 @translate_val.register(ops.InValues)
@@ -608,7 +617,7 @@ def _in_column(op, *, value, options, **_):
 @translate_val.register(ops.DayOfWeekIndex)
 def _day_of_week_index(op, *, arg, **_):
     weekdays = _NUM_WEEKDAYS
-    return (((f.toDayOfWeek(arg) - 1) % weekdays) + weekdays) % weekdays
+    return (((F.toDayOfWeek(arg) - 1) % weekdays) + weekdays) % weekdays
 
 
 @translate_val.register(ops.DayOfWeekName)
@@ -624,8 +633,8 @@ def day_of_week_name(op, *, arg, **_):
     # We test against 20 in CI, so we implement day_of_week_name as follows
     num_weekdays = _NUM_WEEKDAYS
     weekdays = range(num_weekdays)
-    base = (((f.toDayOfWeek(arg) - 1) % num_weekdays) + num_weekdays) % num_weekdays
-    return f.nullIf(
+    base = (((F.toDayOfWeek(arg) - 1) % num_weekdays) + num_weekdays) % num_weekdays
+    return F.nullIf(
         sg.exp.Case(
             this=base,
             ifs=[if_(day, calendar.day_name[day]) for day in weekdays],
@@ -639,23 +648,23 @@ def day_of_week_name(op, *, arg, **_):
 @translate_val.register(ops.Least)
 @translate_val.register(ops.Coalesce)
 def _vararg_func(op, *, arg, **_):
-    return f[op.__class__.__name__.lower()](*arg)
+    return F[op.__class__.__name__.lower()](*arg)
 
 
 @translate_val.register(ops.Map)
 def _map(op, *, keys, values, **_):
     # cast here to allow lookups of nullable columns
-    return cast(f.tuple(keys, values), op.dtype)
+    return cast(F.tuple(keys, values), op.dtype)
 
 
 @translate_val.register(ops.MapGet)
 def _map_get(op, *, arg, key, default, **_):
-    return if_(f.mapContains(arg, key), arg[key], default)
+    return if_(F.mapContains(arg, key), arg[key], default)
 
 
 @translate_val.register(ops.ArrayConcat)
 def _array_concat(op, *, arg, **_):
-    return f.arrayConcat(*arg)
+    return F.arrayConcat(*arg)
 
 
 def _binary_infix(func):
@@ -684,7 +693,7 @@ def formatter(op, *, left, right, **_):
     # Boolean comparisons
     ops.And: operator.and_,
     ops.Or: operator.or_,
-    ops.Xor: f.xor,
+    ops.Xor: F.xor,
     ops.DateAdd: operator.add,
     ops.DateSub: operator.sub,
     ops.DateDiff: operator.sub,
@@ -824,7 +833,7 @@ def _fmt(_, _name: str = _name, *, where, **kw):
 
         @translate_val.register(_op)
         def _fmt(_, _name: str = _name, **kw):
-            return f[_name](*kw.values())
+            return F[_name](*kw.values())
 
 
 del _fmt, _name, _op
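The day-of-week translators above normalize ClickHouse's 1-based (Monday-first) toDayOfWeek to a 0-based index with a double modulus, which keeps the result non-negative even under SQL semantics where `%` can return negative values. A pure-Python sketch of the formula:

    def day_of_week_index(to_day_of_week: int, num_weekdays: int = 7) -> int:
        # ((x - 1) % n + n) % n: shift to 0-based, then force non-negative
        return (((to_day_of_week - 1) % num_weekdays) + num_weekdays) % num_weekdays

    assert day_of_week_index(1) == 0  # Monday
    assert day_of_week_index(7) == 6  # Sunday
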
@@ -832,15 +841,15 @@ def _fmt(_, _name: str = _name, **kw):
 
 @translate_val.register(ops.ArrayDistinct)
 def _array_distinct(op, *, arg, **_):
-    null_element = if_(f.countEqual(arg, NULL) > 0, f.array(NULL), f.array())
-    return f.arrayConcat(f.arrayDistinct(arg), null_element)
+    null_element = if_(F.countEqual(arg, NULL) > 0, F.array(NULL), F.array())
+    return F.arrayConcat(F.arrayDistinct(arg), null_element)
 
 
 @translate_val.register(ops.ExtractMicrosecond)
 def _extract_microsecond(op, *, arg, **_):
     dtype = op.dtype
     return cast(
-        f.toUnixTimestamp64Micro(cast(arg, op.arg.dtype.copy(scale=6))) % 1_000_000,
+        F.toUnixTimestamp64Micro(cast(arg, op.arg.dtype.copy(scale=6))) % 1_000_000,
         dtype,
     )
@@ -849,7 +858,7 @@ def _extract_millisecond(op, *, arg, **_):
     dtype = op.dtype
     return cast(
-        f.toUnixTimestamp64Milli(cast(arg, op.arg.dtype.copy(scale=3))) % 1_000, dtype
+        F.toUnixTimestamp64Milli(cast(arg, op.arg.dtype.copy(scale=3))) % 1_000, dtype
     )
@@ -930,67 +939,67 @@ def formatter(op, *, arg, offset, default, **_):
     return formatter
 
 
-shift_like(ops.Lag, f.lagInFrame)
-shift_like(ops.Lead, f.leadInFrame)
+shift_like(ops.Lag, F.lagInFrame)
+shift_like(ops.Lead, F.leadInFrame)
 
 
 @translate_val.register(ops.RowNumber)
 def _row_number(op, **_):
-    return f.row_number()
+    return F.row_number()
 
 
 @translate_val.register(ops.DenseRank)
 def _dense_rank(op, **_):
-    return f.dense_rank()
+    return F.dense_rank()
 
 
 @translate_val.register(ops.MinRank)
 def _rank(op, **_):
-    return f.rank()
+    return F.rank()
 
 
 @translate_val.register(ops.ExtractProtocol)
 def _extract_protocol(op, *, arg, **_):
-    return f.nullIf(f.protocol(arg), "")
+    return F.nullIf(F.protocol(arg), "")
 
 
 @translate_val.register(ops.ExtractAuthority)
 def _extract_authority(op, *, arg, **_):
-    return f.nullIf(f.netloc(arg), "")
+    return F.nullIf(F.netloc(arg), "")
 
 
 @translate_val.register(ops.ExtractHost)
 def _extract_host(op, *, arg, **_):
-    return f.nullIf(f.domain(arg), "")
+    return F.nullIf(F.domain(arg), "")
 
 
 @translate_val.register(ops.ExtractFile)
 def _extract_file(op, *, arg, **_):
-    return f.nullIf(f.cutFragment(f.pathFull(arg)), "")
+    return F.nullIf(F.cutFragment(F.pathFull(arg)), "")
 
 
 @translate_val.register(ops.ExtractPath)
 def _extract_path(op, *, arg, **_):
-    return f.nullIf(f.path(arg), "")
+    return F.nullIf(F.path(arg), "")
 
 
 @translate_val.register(ops.ExtractQuery)
 def _extract_query(op, *, arg, key, **_):
     if key is not None:
-        input = f.extractURLParameter(arg, key)
+        input = F.extractURLParameter(arg, key)
     else:
-        input = f.queryString(arg)
-    return f.nullIf(input, "")
+        input = F.queryString(arg)
+    return F.nullIf(input, "")
 
 
 @translate_val.register(ops.ExtractFragment)
 def _extract_fragment(op, *, arg, **_):
-    return f.nullIf(f.fragment(arg), "")
+    return F.nullIf(F.fragment(arg), "")
 
 
 @translate_val.register(ops.ArrayStringJoin)
 def _array_string_join(op, *, arg, sep, **_):
-    return f.arrayStringConcat(arg, sep)
+    return F.arrayStringConcat(arg, sep)
 
 
 @translate_val.register(ops.Argument)
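The array translators below pass sqlglot `Lambda` nodes to ClickHouse's higher-order array functions. A sketch in the style of `_array_remove` (not part of the patch; the rendered output is approximate):

    import sqlglot as sg

    x = sg.to_identifier("x")
    body = x.neq(sg.exp.Literal(this="3", is_string=False))  # x <> 3
    lam = sg.exp.Lambda(this=body, expressions=[x])
    # roughly: arrayFilter(x -> x <> 3, arr)
    print(sg.func("arrayFilter", lam, sg.column("arr")).sql(dialect="clickhouse"))
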
@@ -1001,52 +1010,52 @@ def _argument(op, *, name, **_):
 
 @translate_val.register(ops.ArrayMap)
 def _array_map(op, *, arg, param, body, **_):
     func = sg.exp.Lambda(this=body, expressions=[param])
-    return f.arrayMap(func, arg)
+    return F.arrayMap(func, arg)
 
 
 @translate_val.register(ops.ArrayFilter)
 def _array_filter(op, *, arg, param, body, **_):
     func = sg.exp.Lambda(this=body, expressions=[param])
-    return f.arrayFilter(func, arg)
+    return F.arrayFilter(func, arg)
 
 
 @translate_val.register(ops.ArrayPosition)
 def _array_position(op, *, arg, other, **_):
-    return f.indexOf(arg, other) - 1
+    return F.indexOf(arg, other) - 1
 
 
 @translate_val.register(ops.ArrayRemove)
 def _array_remove(op, *, arg, other, **_):
     x = sg.to_identifier("x")
     body = x.neq(other)
-    return f.arrayFilter(sg.exp.Lambda(this=body, expressions=[x]), arg)
+    return F.arrayFilter(sg.exp.Lambda(this=body, expressions=[x]), arg)
 
 
 @translate_val.register(ops.ArrayUnion)
 def _array_union(op, *, left, right, **_):
-    arg = f.arrayConcat(left, right)
-    null_element = if_(f.countEqual(arg, NULL) > 0, f.array(NULL), f.array())
-    return f.arrayConcat(f.arrayDistinct(arg), null_element)
+    arg = F.arrayConcat(left, right)
+    null_element = if_(F.countEqual(arg, NULL) > 0, F.array(NULL), F.array())
+    return F.arrayConcat(F.arrayDistinct(arg), null_element)
 
 
 @translate_val.register(ops.ArrayZip)
 def _array_zip(op: ops.ArrayZip, *, arg, **_: Any) -> str:
-    return f.arrayZip(*arg)
+    return F.arrayZip(*arg)
 
 
 @translate_val.register(ops.CountDistinctStar)
 def _count_distinct_star(op: ops.CountDistinctStar, *, where, **_: Any) -> str:
-    columns = f.tuple(*map(sg.column, op.arg.schema.names))
+    columns = F.tuple(*map(sg.column, op.arg.schema.names))
 
     if where is not None:
-        return f.countDistinctIf(columns, where)
+        return F.countDistinctIf(columns, where)
     else:
-        return f.countDistinct(columns)
+        return F.countDistinct(columns)
 
 
 @translate_val.register(ops.ScalarUDF)
 def _scalar_udf(op, **kw) -> str:
-    return f[op.__full_name__](*kw.values())
+    return F[op.__full_name__](*kw.values())
 
 
 @translate_val.register(ops.AggUDF)