# Module: src/chronify/sqlalchemy/functions.py
import pandas as pd
import polars as pl
from sqlalchemy import Connection, Selectable


def read_database_query(query: Selectable | str, conn: Connection) -> pd.DataFrame:
    """Run ``query`` on ``conn`` and return the result as a pandas DataFrame.

    ``query`` may be a SQLAlchemy selectable or a raw SQL string. It is handed
    unchanged to ``polars.read_database`` together with the open connection,
    and the resulting polars frame is converted with ``to_pandas()``.
    """
    polars_frame = pl.read_database(query, connection=conn)
    return polars_frame.to_pandas()
# Methods added to the store class in src/chronify/store.py
# (the class header itself is outside the visible hunk).
def read_query(self, query: Selectable | str) -> pd.DataFrame:
    """Execute ``query`` and return its result as a pandas DataFrame.

    ``query`` may be a SQLAlchemy selectable or a raw SQL string. It is passed
    unchanged to ``read_database_query`` along with a connection obtained from
    ``self._engine.begin()``.
    """
    with self._engine.begin() as connection:
        frame = read_database_query(query, connection)
    return frame


def read_table(self, name: str) -> pd.DataFrame:
    """Return every row of the table ``name`` as a pandas DataFrame."""
    # NOTE(review): ``name`` is interpolated into the SQL text verbatim, with no
    # quoting or escaping, so callers must pass a trusted table identifier.
    select_all = f"select * from {name}"
    return self.read_query(select_all)
# From src/chronify/time_configs.py.

# Abstract hook on the time-config base class
# (presumably TimeBaseModel; its header is outside the visible hunk).
@abc.abstractmethod
def convert_database_timestamps(self, df: pd.DataFrame) -> list[Any]:
    """Convert timestamps from the database."""


# Methods of DatetimeRange(TimeBaseModel).
@model_validator(mode="after")
def check_time_columns(self) -> "DatetimeRange":
    """Validate that exactly one time column is configured.

    Returns the model unchanged when ``time_columns`` has length one and
    raises ``ValueError`` otherwise.
    """
    if len(self.time_columns) == 1:
        return self
    msg = f"{self.time_columns=} must have one column"
    raise ValueError(msg)


def convert_database_timestamps(self, df: pd.DataFrame) -> list[datetime]:
    """Return the time column of ``df`` converted to this range's time zone.

    ``self.time_zone`` must be set (enforced by an ``assert``). Each value of
    the time column is converted with ``astimezone`` using the zone returned
    by ``get_zone_info`` and the results are returned as a list.
    """
    assert self.time_zone is not None
    zone = get_zone_info(self.time_zone)
    column = df[self.get_time_column()]
    converted = column.apply(lambda stamp: stamp.astimezone(zone))
    return converted.to_list()


def get_time_column(self) -> str:
    """Return the first entry of ``time_columns``."""
    columns = self.time_columns
    return columns[0]
# Method of the checker class in src/chronify/time_series_checker.py
# (presumably TimeSeriesChecker; the class header is outside the visible hunk).
def check_timestamps(self, schema: TableSchema) -> None:
    """Run the timestamp checks for the table described by ``schema``.

    Calls ``_check_expected_timestamps`` first and then
    ``_check_expected_timestamps_by_time_array``; an exception raised by
    either check propagates to the caller.
    """
    # The order is kept: the overall comparison runs before the per-array one.
    self._check_expected_timestamps(schema)
    self._check_expected_timestamps_by_time_array(schema)