refactor(clickhouse): use more sqlglot constructs
cpcloud committed Oct 2, 2023
1 parent 5d300a9 commit c7ca7cd
Showing 3 changed files with 230 additions and 179 deletions.
19 changes: 19 additions & 0 deletions ibis/backends/base/sqlglot/__init__.py
@@ -52,11 +52,30 @@ def exists(self, query):
def concat(self, *args):
return sg.exp.Concat.from_arg_list(list(map(_to_sqlglot, args)))

def map(self, keys, values):
return sg.exp.Map(keys=keys, values=values)


class ColGen:
__slots__ = ()

def __getattr__(self, name: str) -> sg.exp.Column:
return sg.column(name)

def __getitem__(self, key: str) -> sg.exp.Column:
return sg.column(key)


def lit(val):
return sg.exp.Literal(this=str(val), is_string=isinstance(val, str))


def interval(value, *, unit):
return sg.exp.Interval(this=_to_sqlglot(value), unit=sg.exp.var(unit))


F = FuncGen()
C = ColGen()
NULL = sg.exp.NULL
FALSE = sg.exp.FALSE
TRUE = sg.exp.TRUE
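For illustration only (not part of this commit): a minimal sketch of how the helpers above compose, assuming C and lit are imported from this module as the backend file below does. The rendered SQL is approximate.

import sqlglot as sg

from ibis.backends.base.sqlglot import C, lit

# SELECT name FROM system.tables WHERE database = 'default' OR is_temporary
query = (
    sg.select(C.name)
    .from_(sg.table("tables", db="system"))
    .where(C.database.eq(lit("default")).or_(C.is_temporary))
)
print(query.sql(dialect="clickhouse"))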
105 changes: 64 additions & 41 deletions ibis/backends/clickhouse/__init__.py
@@ -23,6 +23,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.base.sqlglot import STAR, C, F, lit
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import ClickhouseType

@@ -176,12 +177,14 @@ def version(self) -> str:

@property
def current_database(self) -> str:
with closing(self.raw_sql("SELECT currentDatabase()")) as result:
with closing(self.raw_sql(sg.select(F.currentDatabase()))) as result:
[(db,)] = result.result_rows
return db

def list_databases(self, like: str | None = None) -> list[str]:
with closing(self.raw_sql("SELECT name FROM system.databases")) as result:
with closing(
self.raw_sql(sg.select(C.name).from_(sg.table("databases", db="system")))
) as result:
results = result.result_columns

if results:
@@ -193,14 +196,14 @@ def list_databases(self, like: str | None = None) -> list[str]:
def list_tables(
self, like: str | None = None, database: str | None = None
) -> list[str]:
query = "SELECT name FROM system.tables WHERE"
query = sg.select(C.name).from_(sg.table("tables", db="system"))

if database is None:
database = "currentDatabase()"
database = F.currentDatabase()
else:
database = f"'{database}'"
database = lit(database)

query += f" database = {database} OR is_temporary"
query = query.where(C.database.eq(database).or_(C.is_temporary))

with closing(self.raw_sql(query)) as result:
results = result.result_columns
@@ -377,7 +380,10 @@ def execute(
# in single column conversion and whole table conversion
return expr.__pandas_result__(table.__pandas_result__(df))

def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any):
def _to_sqlglot(
self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
):
"""Compile an Ibis expression to a sqlglot object."""
table_expr = expr.as_table()

if limit == "default":
@@ -392,13 +398,21 @@ def compile(self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
assert not isinstance(sql, sg.exp.Subquery)

if isinstance(sql, sg.exp.Table):
sql = sg.select("*").from_(sql)
sql = sg.select(STAR).from_(sql)

assert not isinstance(sql, sg.exp.Subquery)
return sql.sql(dialect="clickhouse", pretty=True)
return sql

def compile(
self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any
):
"""Compile an Ibis expression to a ClickHouse SQL string."""
return self._to_sqlglot(expr, limit=limit, params=params, **kwargs).sql(
dialect=self.name, pretty=True
)

def _to_sql(self, expr: ir.Expr, **kwargs) -> str:
return str(self.compile(expr, **kwargs))
return self.compile(expr, **kwargs)
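For context (not part of the diff): _to_sqlglot now returns a sqlglot expression tree and compile renders that tree to a ClickHouse SQL string. A hypothetical sketch, assuming con is a connected ClickHouse backend and t is one of its tables:

expr = t.limit(5)
sql_obj = con._to_sqlglot(expr)  # sqlglot expression object
sql_str = con.compile(expr)      # the same tree rendered as a string
assert sql_str == sql_obj.sql(dialect="clickhouse", pretty=True)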

def table(self, name: str, database: str | None = None) -> ir.Table:
"""Construct a table expression.
@@ -444,7 +458,7 @@ def insert(

def raw_sql(
self,
query: str,
query: str | sg.exp.Expression,
external_tables: Mapping[str, pd.DataFrame] | None = None,
**kwargs,
) -> Any:
@@ -467,6 +481,8 @@ def raw_sql(
"""
external_tables = toolz.valmap(_to_memtable, external_tables or {})
external_data = self._normalize_external_tables(external_tables)
with suppress(AttributeError):
query = query.sql(dialect=self.name, pretty=True)
self._log(query)
return self.con.query(query, external_data=external_data, **kwargs)
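Illustrative only: with the widened query parameter, raw_sql now accepts either form. A plain string has no .sql attribute, so suppress(AttributeError) leaves it untouched, while a sqlglot expression is rendered first (con and the imported F are the hypothetical names used in the sketches above):

con.raw_sql("SELECT 1")                      # string passes through unchanged
con.raw_sql(sg.select(F.currentDatabase()))  # expression rendered via .sql(...)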

@@ -501,8 +517,7 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema
sch.Schema
Ibis schema
"""
qualified_name = self._fully_qualified_name(table_name, database)
query = f"DESCRIBE {qualified_name}"
query = sg.exp.Describe(this=sg.table(table_name, db=database))
with closing(self.raw_sql(query)) as results:
names, types, *_ = results.result_columns
return sch.Schema(dict(zip(names, map(ClickhouseType.from_string, types))))
@@ -528,34 +543,39 @@ def has_operation(cls, operation: type[ops.Value]) -> bool:
def create_database(
self, name: str, *, force: bool = False, engine: str = "Atomic"
) -> None:
if_not_exists = "IF NOT EXISTS " * force
with closing(
self.raw_sql(f"CREATE DATABASE {if_not_exists}{name} ENGINE = {engine}")
):
src = sg.exp.Create(
this=sg.to_identifier(name),
kind="DATABASE",
exists=force,
properties=sg.exp.Properties(
expressions=[sg.exp.EngineProperty(this=sg.to_identifier(engine))]
),
)
with closing(self.raw_sql(src)):
pass
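A rough sketch (not in the commit) of the DDL the Create node above produces; exact rendering may vary with the sqlglot version, and the database name here is invented:

src = sg.exp.Create(
    this=sg.to_identifier("scratch"),
    kind="DATABASE",
    exists=True,
    properties=sg.exp.Properties(
        expressions=[sg.exp.EngineProperty(this=sg.to_identifier("Atomic"))]
    ),
)
print(src.sql(dialect="clickhouse"))
# roughly: CREATE DATABASE IF NOT EXISTS scratch ENGINE=Atomic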

def drop_database(self, name: str, *, force: bool = False) -> None:
if_exists = "IF EXISTS " * force
with closing(self.raw_sql(f"DROP DATABASE {if_exists}{name}")):
src = sg.exp.Drop(this=sg.to_identifier(name), kind="DATABASE", exists=force)
with closing(self.raw_sql(src)):
pass

def truncate_table(self, name: str, database: str | None = None) -> None:
ident = self._fully_qualified_name(name, database)
ident = sg.table(name, db=database).sql(self.name)
with closing(self.raw_sql(f"TRUNCATE TABLE {ident}")):
pass

def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
ident = self._fully_qualified_name(name, database)
with closing(self.raw_sql(f"DROP TABLE {'IF EXISTS ' * force}{ident}")):
src = sg.exp.Drop(this=sg.table(name, db=database), kind="TABLE", exists=force)
with closing(self.raw_sql(src)):
pass
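Similarly, the Drop nodes render to the expected statements; a hypothetical example (table name invented, formatting approximate):

print(sg.exp.Drop(this=sg.table("t", db="db"), kind="TABLE", exists=True).sql("clickhouse"))
# roughly: DROP TABLE IF EXISTS db.t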

def read_parquet(
self,
path: str | Path,
table_name: str | None = None,
engine: str = "File(Native)",
engine: str = "MergeTree",
**kwargs: Any,
) -> ir.Table:
import pyarrow.dataset as ds
@@ -583,7 +603,7 @@ def read_csv(
self,
path: str | Path,
table_name: str | None = None,
engine: str = "File(Native)",
engine: str = "MergeTree",
**kwargs: Any,
) -> ir.Table:
import pyarrow.dataset as ds
@@ -611,7 +631,7 @@ def create_table(
temp: bool = False,
overwrite: bool = False,
# backend specific arguments
engine: str = "File(Native)",
engine: str = "MergeTree",
order_by: Iterable[str] | None = None,
partition_by: Iterable[str] | None = None,
sample_by: str | None = None,
@@ -636,7 +656,9 @@
Whether to overwrite the table
engine
The table engine to use. See [ClickHouse's `CREATE TABLE` documentation](https://clickhouse.com/docs/en/sql-reference/statements/create/table)
for specifics.
for specifics. Defaults to [`MergeTree`](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree)
with `ORDER BY tuple()` because `MergeTree` is the most
feature-complete engine.
order_by
String column names to order by. Required for some table engines like `MergeTree`.
partition_by
@@ -681,6 +703,10 @@

if order_by is not None:
code += f" ORDER BY {', '.join(util.promote_list(order_by))}"
elif engine == "MergeTree":
# empty tuple to indicate no specific order when engine is
# MergeTree
code += " ORDER BY tuple()"

if partition_by is not None:
code += f" PARTITION BY {', '.join(util.promote_list(partition_by))}"
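To make the docstring's new default concrete: a hypothetical call (names invented) that relies on the MergeTree default engine and the ORDER BY tuple() fallback added above:

import ibis
# con is a connected ClickHouse backend, as in the earlier sketches
con.create_table("events", schema=ibis.schema(dict(id="int64")))
# emits DDL along the lines of: CREATE TABLE events (...) ENGINE = MergeTree ORDER BY tuple()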
@@ -713,21 +739,22 @@ def create_view(
database: str | None = None,
overwrite: bool = False,
) -> ir.Table:
qualname = self._fully_qualified_name(name, database)
replace = "OR REPLACE " * overwrite
query = self.compile(obj)
code = f"CREATE {replace}VIEW {qualname} AS {query}"
src = sg.exp.Create(
this=sg.table(name, db=database),
kind="VIEW",
replace=overwrite,
expression=self._to_sqlglot(obj),
)
external_tables = self._collect_in_memory_tables(obj)
with closing(self.raw_sql(code, external_tables=external_tables)):
with closing(self.raw_sql(src, external_tables=external_tables)):
pass
return self.table(name, database=database)

def drop_view(
self, name: str, *, database: str | None = None, force: bool = False
) -> None:
name = self._fully_qualified_name(name, database)
if_exists = "IF EXISTS " * force
with closing(self.raw_sql(f"DROP VIEW {if_exists}{name}")):
src = sg.exp.Drop(this=sg.table(name, db=database), kind="VIEW", exists=force)
with closing(self.raw_sql(src)):
pass

def _load_into_cache(self, name, expr):
@@ -742,12 +769,9 @@ def _create_temp_view(self, table_name, source):
f"{table_name} already exists as a non-temporary table or view"
)
src = sg.exp.Create(
this=sg.table(table_name), # CREATE ... 'table_name'
kind="VIEW", # VIEW
replace=True, # OR REPLACE
expression=source, # AS ...
this=sg.table(table_name), kind="VIEW", replace=True, expression=source
)
self.raw_sql(src.sql(dialect=self.name, pretty=True))
self.raw_sql(src)
self._temp_views.add(table_name)
self._register_temp_view_cleanup(table_name)

@@ -756,6 +780,5 @@ def drop(self, name: str, query: str):
self.raw_sql(query)
self._temp_views.discard(name)

src = sg.exp.Drop(this=sg.table(name), kind="VIEW", exists=True)
query = src.sql(self.name, pretty=True)
query = sg.exp.Drop(this=sg.table(name), kind="VIEW", exists=True)
atexit.register(drop, self, name=name, query=query)