Skip to content

Commit

Permalink
Fix polars-sqlalchemy interactions
Browse files Browse the repository at this point in the history
This fixes reading from the database through Polars and SQLAlchemy after
recent fixes to the Polars Python package.
  • Loading branch information
daniel-thom committed Oct 17, 2024
1 parent 87ac971 commit 3b763a1
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 26 deletions.
Empty file.
8 changes: 8 additions & 0 deletions src/chronify/sqlalchemy/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pandas as pd
import polars as pl
from sqlalchemy import Connection, Selectable


def read_database_query(query: Selectable | str, conn: Connection) -> pd.DataFrame:
    """Run a query through Polars and return the result as a pandas DataFrame.

    ``query`` may be a SQLAlchemy selectable or a raw SQL string. It is handed
    to ``pl.read_database`` together with ``conn``, and the resulting Polars
    frame is converted to pandas before being returned.
    """
    polars_frame = pl.read_database(query, connection=conn)
    return polars_frame.to_pandas()
22 changes: 8 additions & 14 deletions src/chronify/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Optional

import pandas as pd
import polars as pl
from loguru import logger
from sqlalchemy import Column, Engine, MetaData, Selectable, Table, create_engine, text

Expand All @@ -16,6 +15,7 @@
TableSchemaBase,
get_sqlalchemy_type_from_duckdb,
)
from chronify.sqlalchemy.functions import read_database_query
from chronify.time_configs import DatetimeRange, IndexTimeRange
from chronify.time_series_checker import TimeSeriesChecker
from chronify.utils.sql import make_temp_view_name
Expand Down Expand Up @@ -143,20 +143,14 @@ def ingest_from_csv(
conn.commit()
self.update_table_schema()

def read_table(self, name: str, query: Optional[Selectable | str] = None) -> pd.DataFrame:
"""Return the table as a pandas DataFrame, optionally applying a query."""
if query is None:
query_ = f"select * from {name}"
elif isinstance(query, Selectable) and self.engine.name == "duckdb":
# TODO: unsafe. Need duckdb_engine support.
# https://github.com/Mause/duckdb_engine/issues/1119
# https://github.com/pola-rs/polars/issues/19221
query_ = str(query.compile(compile_kwargs={"literal_binds": True}))
else:
query_ = query

def read_query(self, query: Selectable | str) -> pd.DataFrame:
"""Return the query result as a pandas DataFrame."""
with self._engine.begin() as conn:
return pl.read_database(query_, connection=conn).to_pandas()
return read_database_query(query, conn)

def read_table(self, name: str) -> pd.DataFrame:
"""Return the table as a pandas DataFrame."""
return self.read_query(f"select * from {name}")

def write_query_to_parquet(self, stmt: Selectable, file_path: Path | str) -> None:
"""Write the result of a query to a Parquet file."""
Expand Down
23 changes: 18 additions & 5 deletions src/chronify/time_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from typing import Any, Optional, Union, Literal
from zoneinfo import ZoneInfo

import pandas as pd
from pydantic import (
BaseModel,
Field,
ValidationInfo,
field_validator,
model_validator,
)
from sqlalchemy import CursorResult
from typing_extensions import Annotated

from chronify.time import (
Expand Down Expand Up @@ -162,9 +163,9 @@ def iter_timestamps(self) -> Generator[Any, None, None]:
Type of the time is dependent on the class.
"""

def convert_database_timestamps(self, cur: CursorResult) -> list[Any]:
@abc.abstractmethod
def convert_database_timestamps(self, df: pd.DataFrame) -> list[Any]:
"""Convert timestamps from the database."""
return [x[0] for x in cur]


class DatetimeRange(TimeBaseModel):
Expand All @@ -185,6 +186,13 @@ class DatetimeRange(TimeBaseModel):
interval_type: TimeIntervalType = TimeIntervalType.PERIOD_ENDING
measurement_type: MeasurementType = MeasurementType.TOTAL

@model_validator(mode="after")
def check_time_columns(self) -> "DatetimeRange":
    """Validate that exactly one time column is configured.

    Returns the model unchanged when ``time_columns`` has a single entry;
    raises ``ValueError`` otherwise.
    """
    # Guard clause: the happy path returns immediately.
    if len(self.time_columns) == 1:
        return self
    msg = f"{self.time_columns=} must have one column"
    raise ValueError(msg)

@field_validator("start")
@classmethod
def fix_time_zone(cls, start: datetime, info: ValidationInfo) -> datetime:
Expand All @@ -197,10 +205,15 @@ def fix_time_zone(cls, start: datetime, info: ValidationInfo) -> datetime:
return start.replace(tzinfo=zone_info)
return start

def convert_database_timestamps(self, cur: CursorResult) -> list[datetime]:
def convert_database_timestamps(self, df: pd.DataFrame) -> list[datetime]:
assert self.time_zone is not None
tzinfo = get_zone_info(self.time_zone)
return [x[0].astimezone(tzinfo) for x in cur]
time_column = self.get_time_column()
return df[time_column].apply(lambda x: x.astimezone(tzinfo)).to_list()

def get_time_column(self) -> str:
    """Return the first entry of ``time_columns``."""
    first_column = self.time_columns[0]
    return first_column

def iter_timestamps(self) -> Generator[datetime, None, None]:
tz_info = self.start.tzinfo
Expand Down
12 changes: 5 additions & 7 deletions src/chronify/time_series_checker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import Connection, Engine, MetaData, Table, text
from sqlalchemy import Connection, Engine, MetaData, Table, select, text

from chronify.exceptions import InvalidTable
from chronify.models import TableSchema
from chronify.sqlalchemy.functions import read_database_query
from chronify.utils.sql import make_temp_view_name


Expand All @@ -13,21 +14,18 @@ def __init__(self, engine: Engine, metadata: MetaData) -> None:
self._metadata = metadata

def check_timestamps(self, schema: TableSchema) -> None:
# TODO: the conn.execute calls here are slow but not vulnerable to sql injection.
# Data extraction should never be large.
# Consider changing when we have a better implementation.
self._check_expected_timestamps(schema)
self._check_expected_timestamps_by_time_array(schema)

def _check_expected_timestamps(self, schema: TableSchema) -> None:
expected = schema.time_config.list_timestamps()
with self._engine.connect() as conn:
table = Table(schema.name, self._metadata)
stmt = table.select().distinct()
stmt = select(*(table.c[x] for x in schema.time_config.time_columns)).distinct()
for col in schema.time_config.time_columns:
stmt = stmt.where(table.c[col].is_not(None))
# This is slow, but the size should never be large.
actual = set(schema.time_config.convert_database_timestamps(conn.execute(stmt)))
df = read_database_query(stmt, conn)
actual = set(schema.time_config.convert_database_timestamps(df))
diff = actual.symmetric_difference(expected)
if diff:
msg = f"Actual timestamps do not match expected timestamps: {diff}"
Expand Down

0 comments on commit 3b763a1

Please sign in to comment.