From 18c56016e3dc42437cd175be49780073bc91df19 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Mon, 20 May 2024 02:39:32 +0400 Subject: [PATCH] feat(python): Add `to_jax` methods to support Jax Array export from `DataFrame` and `Series` (#16294) --- .../src/chunked_array/ops/aggregate/mod.rs | 2 +- .../source/reference/dataframe/export.rst | 1 + .../docs/source/reference/series/export.rst | 2 + py-polars/polars/dataframe/frame.py | 283 ++++++++++++++- py-polars/polars/series/series.py | 71 +++- py-polars/polars/type_aliases.py | 1 + py-polars/pyproject.toml | 4 +- py-polars/requirements-ci.txt | 2 + py-polars/tests/docs/run_doctest.py | 1 + .../tests/unit/dataframe/test_to_torch.py | 290 ---------------- py-polars/tests/unit/ml/__init__.py | 0 py-polars/tests/unit/ml/test_to_jax.py | 156 +++++++++ py-polars/tests/unit/ml/test_to_torch.py | 325 ++++++++++++++++++ 13 files changed, 827 insertions(+), 311 deletions(-) delete mode 100644 py-polars/tests/unit/dataframe/test_to_torch.py create mode 100644 py-polars/tests/unit/ml/__init__.py create mode 100644 py-polars/tests/unit/ml/test_to_jax.py create mode 100644 py-polars/tests/unit/ml/test_to_torch.py diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index 7fe43517b365..5e4b09e5d18a 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -195,7 +195,7 @@ where } } -/// Booleans are casted to 1 or 0. +/// Booleans are cast to 1 or 0. impl BooleanChunked { pub fn sum(&self) -> Option { Some(if self.is_empty() { diff --git a/py-polars/docs/source/reference/dataframe/export.rst b/py-polars/docs/source/reference/dataframe/export.rst index 12cb378dc6ef..0347b7429da0 100644 --- a/py-polars/docs/source/reference/dataframe/export.rst +++ b/py-polars/docs/source/reference/dataframe/export.rst @@ -13,6 +13,7 @@ Export DataFrame data to other formats: DataFrame.to_dict DataFrame.to_dicts DataFrame.to_init_repr + DataFrame.to_jax DataFrame.to_numpy DataFrame.to_pandas DataFrame.to_struct diff --git a/py-polars/docs/source/reference/series/export.rst b/py-polars/docs/source/reference/series/export.rst index 6e19c4efa4f7..c1c7bacf8086 100644 --- a/py-polars/docs/source/reference/series/export.rst +++ b/py-polars/docs/source/reference/series/export.rst @@ -10,7 +10,9 @@ Export Series data to other formats: Series.to_arrow Series.to_frame + Series.to_jax Series.to_list Series.to_numpy Series.to_pandas Series.to_init_repr + Series.to_torch diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index f5649d28f6dc..085ed0d259e0 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -72,6 +72,7 @@ INTEGER_DTYPES, N_INFER_DEFAULT, Boolean, + Float32, Float64, Int32, Int64, @@ -102,7 +103,7 @@ from polars.functions import col, lit from polars.selectors import _expand_selector_dicts, _expand_selectors from polars.slice import PolarsSlice -from polars.type_aliases import DbWriteMode, TorchExportType +from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import dtype_str_repr as _dtype_str_repr @@ -115,6 +116,7 @@ from typing import Literal import deltalake + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars from xlsxwriter import Workbook @@ -1516,7 +1518,7 @@ def to_numpy( However, the C-like order might be more appropriate to use for downstream applications to prevent cloning data, e.g. when reshaping into a one-dimensional array. Note that this option only takes effect if - `structured` is set to `False` and the DataFrame dtypes allow for a + `structured` is set to `False` and the DataFrame dtypes allow a global dtype for all columns. allow_copy Allow memory to be copied to perform the conversion. If set to `False`, @@ -1609,6 +1611,208 @@ def raise_on_copy(msg: str) -> None: return out + @overload + def to_jax( + self, + return_type: Literal["array"] = ..., + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> jax.Array: ... + + @overload + def to_jax( + self, + return_type: Literal["dict"], + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> dict[str, jax.Array]: ... + + @unstable() + def to_jax( + self, + return_type: JaxExportType = "array", + *, + device: jax.Device | str | None = None, + label: str | Expr | Sequence[str | Expr] | None = None, + features: str | Expr | Sequence[str | Expr] | None = None, + dtype: PolarsDataType | None = None, + order: IndexOrder = "fortran", + ) -> jax.Array | dict[str, jax.Array]: + """ + Convert DataFrame to a Jax Array, or dict of Jax Arrays. + + .. versionadded:: 0.20.27 + + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + + Parameters + ---------- + return_type : {"array", "dict"} + Set return type; a Jax Array, or dict of Jax Arrays. + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + label + One or more column names, expressions, or selectors that label the feature + data; results in a `{"label": ..., "features": ...}` dict being returned + when `return_type` is "dict" instead of a `{"col": array, }` dict. + features + One or more column names, expressions, or selectors that contain the feature + data; if omitted, all columns that are not designated as part of the label + are used. Only applies when `return_type` is "dict". + dtype + Unify the dtype of all returned arrays; this casts any column that is + not already of the required dtype before converting to Array. Note that + export will be single-precision (32bit) unless the Jax config/environment + directs otherwise (eg: "jax_enable_x64" was set True in the config object + at startup, or "JAX_ENABLE_X64" is set to "1" in the environment). + order : {"c", "fortran"} + The index order of the returned Jax array, either C-like (row-major) or + Fortran-like (column-major). + + See Also + -------- + to_dummies + to_numpy + to_torch + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "lbl": [0, 1, 2, 3], + ... "feat1": [1, 0, 0, 1], + ... "feat2": [1.5, -0.5, 0.0, -2.25], + ... } + ... ) + + Standard return type (2D Array), on the standard device: + + >>> df.to_jax() + Array([[ 0. , 1. , 1.5 ], + [ 1. , 0. , -0.5 ], + [ 2. , 0. , 0. ], + [ 3. , 1. , -2.25]], dtype=float32) + + Create the Array on the default GPU device: + + >>> a = df.to_jax(device="gpu") # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=0, process_index=0) + + Create the Array on a specific GPU device: + + >>> gpu_device = jax.devices("gpu")[1]) # doctest: +SKIP + >>> a = df.to_jax(device=gpu_device) # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=1, process_index=0) + + As a dictionary of individual Arrays: + + >>> df.to_jax("dict") + {'lbl': Array([0, 1, 2, 3], dtype=int32), + 'feat1': Array([1, 0, 0, 1], dtype=int32), + 'feat2': Array([ 1.5 , -0.5 , 0. , -2.25], dtype=float32)} + + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_jax("dict", label="lbl") + {'label': Array([[0], + [1], + [2], + [3]], dtype=int32), + 'features': Array([[ 1. , 1.5 ], + [ 0. , -0.5 ], + [ 0. , 0. ], + [ 1. , -2.25]], dtype=float32)} + + As a "label" and "features" dictionary where each is designated using + a col or selector expression (which can also be used to cast the data + if the label and features are better-represented with different dtypes): + + >>> import polars.selectors as cs + >>> df.to_jax( + ... return_type="dict", + ... features=cs.float(), + ... label=pl.col("lbl").cast(pl.UInt8), + ... ) + {'label': Array([[0], + [1], + [2], + [3]], dtype=uint8), + 'features': Array([[ 1.5 ], + [-0.5 ], + [ 0. ], + [-2.25]], dtype=float32)} + """ + if return_type != "dict" and (label is not None or features is not None): + msg = "`label` and `features` only apply when `return_type` is 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'" + raise ValueError(msg) + + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + enabled_double_precision = jx.config.jax_enable_x64 or bool( + int(os.environ.get("JAX_ENABLE_X64", "0")) + ) + if dtype: + frame = self.cast(dtype) + elif not enabled_double_precision: + # enforce single-precision unless environment/config directs otherwise + frame = self.cast({Float64: Float32, Int64: Int32, UInt64: UInt32}) + else: + frame = self + + if isinstance(device, str): + device = jx.devices(device)[0] + + with contextlib.nullcontext() if device is None else jx.default_device(device): + if return_type == "array": + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=frame.to_numpy(writable=False, use_pyarrow=False, order=order), + order="K", + ) + elif return_type == "dict": + if label is not None: + # return a {"label": array(s), "features": array(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_jax(), + "features": features_frame.to_jax(), + } + else: + # return a {"col": array} dict + return {srs.name: srs.to_jax() for srs in frame} + else: + valid_jax_types = ", ".join(get_args(JaxExportType)) + msg = f"invalid `return_type`: {return_type!r}\nExpected one of: {valid_jax_types}" + raise ValueError(msg) + @overload def to_torch( self, @@ -1639,6 +1843,7 @@ def to_torch( dtype: PolarsDataType | None = ..., ) -> dict[str, torch.Tensor]: ... + @unstable() def to_torch( self, return_type: TorchExportType = "tensor", @@ -1648,31 +1853,39 @@ def to_torch( dtype: PolarsDataType | None = None, ) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset: """ - Convert DataFrame to a 2D PyTorch tensor, Dataset, or dict of Tensors. + Convert DataFrame to a PyTorch Tensor, Dataset, or dict of Tensors. .. versionadded:: 0.20.23 + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + Parameters ---------- return_type : {"tensor", "dataset", "dict"} - Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized + Set return type; a PyTorch Tensor, PolarsDataset (a frame-specialized TensorDataset), or dict of Tensors. label One or more column names, expressions, or selectors that label the feature data; when `return_type` is "dataset", the PolarsDataset will return `(features, label)` tensor tuples for each row. Otherwise, it returns - `(features,)` tensor tuples where the feature contains all the row data; - note that setting this parameter with any other result type will raise an - informative error. + `(features,)` tensor tuples where the feature contains all the row data. features One or more column names, expressions, or selectors that contain the feature data; if omitted, all columns that are not designated as part of the label - are used. This parameter is a no-op for return-types other than "dataset". + are used. dtype - Unify the dtype of all returned tensors; this casts any frame Series - that are not of the required dtype before converting to tensor. This - includes the label column *unless* the label is an expression (such - as `pl.col("label_column").cast(pl.Int16)`). + Unify the dtype of all returned tensors; this casts any column that is + not of the required dtype before converting to Tensor. This includes + the label column *unless* the label is an expression (such as + `pl.col("label_column").cast(pl.Int16)`). + + See Also + -------- + to_dummies + to_jax + to_numpy Examples -------- @@ -1699,6 +1912,19 @@ def to_torch( 'feat1': tensor([1, 0, 0, 1]), 'feat2': tensor([ 1.5000, -0.5000, 0.0000, -2.2500], dtype=torch.float64)} + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_torch("dict", label="lbl", dtype=pl.Float32) + {'label': tensor([[0.], + [1.], + [2.], + [3.]]), + 'features': tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000], + [ 0.0000, 0.0000], + [ 1.0000, -2.2500]])} + As a PolarsDataset, with f64 supertype: >>> ds = df.to_torch("dataset", dtype=pl.Float64) @@ -1711,7 +1937,7 @@ def to_torch( (tensor([[ 0.0000, 1.0000, 1.5000], [ 3.0000, 1.0000, -2.2500]], dtype=torch.float64),) - As a convenience the PolarsDataset can opt-in to half-precision data + As a convenience the PolarsDataset can opt in to half-precision data for experimentation (usually this would be set on the model/pipeline): >>> list(ds.half()) @@ -1735,7 +1961,7 @@ def to_torch( supported). >>> ds = df.to_torch( - ... "dataset", + ... return_type="dataset", ... dtype=pl.Float32, ... label=pl.col("lbl").cast(pl.Int16), ... ) @@ -1760,8 +1986,13 @@ def to_torch( ... batch_size=64, ... ) # doctest: +SKIP """ - if return_type != "dataset" and (label is not None or features is not None): - msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`" + if return_type not in ("dataset", "dict") and ( + label is not None or features is not None + ): + msg = "`label` and `features` only apply when `return_type` is 'dataset' or 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'" raise ValueError(msg) torch = import_optional("torch") @@ -1774,10 +2005,28 @@ def to_torch( frame = self.cast(to_dtype) # type: ignore[arg-type] if return_type == "tensor": + # note: torch tensors are not immutable, so we must consider them writable return torch.from_numpy(frame.to_numpy(writable=True, use_pyarrow=False)) + elif return_type == "dict": - return {srs.name: srs.to_torch() for srs in frame} + if label is not None: + # return a {"label": tensor(s), "features": tensor(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_torch(), + "features": features_frame.to_torch(), + } + else: + # return a {"col": tensor} dict + return {srs.name: srs.to_torch() for srs in frame} + elif return_type == "dataset": + # return a torch Dataset object from polars.ml.torch import PolarsDataset return PolarsDataset(frame, label=label, features=features) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a3aaa3d001cb..66f150957584 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2,6 +2,8 @@ import contextlib import math +import os +from contextlib import nullcontext from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal from typing import ( @@ -66,6 +68,7 @@ Decimal, Duration, Enum, + Float32, Float64, Int8, Int16, @@ -117,6 +120,7 @@ if TYPE_CHECKING: import sys + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars @@ -4472,15 +4476,80 @@ def to_numpy( return self._s.to_numpy(allow_copy=allow_copy, writable=writable) + @unstable() + def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: + """ + Convert this Series to a Jax Array. + + .. versionadded:: 0.20.27 + + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + + Parameters + ---------- + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + + Examples + -------- + >>> s = pl.Series("x", [10.5, 0.0, -10.0, 5.5]) + >>> s.to_jax() + Array([ 10.5, 0. , -10. , 5.5], dtype=float32) + """ + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + if isinstance(device, str): + device = jx.devices(device)[0] + if ( + jx.config.jax_enable_x64 + or bool(int(os.environ.get("JAX_ENABLE_X64", "0"))) + or self.dtype not in {Float64, Int64, UInt64} + ): + srs = self + else: + single_precision = {Float64: Float32, Int64: Int32, UInt64: UInt32} + srs = self.cast(single_precision[self.dtype]) # type: ignore[index] + + with nullcontext() if device is None else jx.default_device(device): + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=srs.to_numpy(writable=False, use_pyarrow=False), + order="K", + ) + + @unstable() def to_torch(self) -> torch.Tensor: """ - Convert this Series to a PyTorch tensor. + Convert this Series to a PyTorch Tensor. + + .. versionadded:: 0.20.23 + + .. warning:: + This functionality is currently considered **unstable**. It may be + changed at any point without it being considered a breaking change. + + Notes + ----- + PyTorch tensors do not support UInt16, UInt32, or UInt64; these dtypes + will be automatically cast to Int32, Int64, and Int64, respectively. Examples -------- >>> s = pl.Series("x", [1, 0, 1, 2, 0], dtype=pl.UInt8) >>> s.to_torch() tensor([1, 0, 1, 2, 0], dtype=torch.uint8) + >>> s = pl.Series("x", [5.5, -10.0, 2.5], dtype=pl.Float32) + >>> s.to_torch() + tensor([ 5.5000, -10.0000, 2.5000]) """ torch = import_optional("torch") diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index b57dcee1f5a3..92daf0b2dd0c 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -163,6 +163,7 @@ DbWriteEngine: TypeAlias = Literal["sqlalchemy", "adbc"] DbWriteMode: TypeAlias = Literal["replace", "append", "fail"] EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"] +JaxExportType: TypeAlias = Literal["array", "dict"] Orientation: TypeAlias = Literal["col", "row"] SearchSortedSide: TypeAlias = Literal["any", "left", "right"] TorchExportType: TypeAlias = Literal["tensor", "dataset", "dict"] diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 75ee3f46ff93..9ac0df0cd121 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -58,11 +58,10 @@ pydantic = ["pydantic"] pyxlsb = ["pyxlsb >= 1.0"] sqlalchemy = ["sqlalchemy", "pandas"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] -torch = ["torch"] xlsx2csv = ["xlsx2csv >= 0.8.0"] xlsxwriter = ["xlsxwriter"] all = [ - "polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,iceberg,sqlalchemy,timezone,torch,xlsx2csv,xlsxwriter]", + "polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,iceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]", ] [tool.maturin] @@ -92,6 +91,7 @@ module = [ "fsspec.*", "gevent", "hvplot.*", + "jax.*", "kuzu", "matplotlib.*", "moto.server", diff --git a/py-polars/requirements-ci.txt b/py-polars/requirements-ci.txt index 3086002307dd..fbb39463fced 100644 --- a/py-polars/requirements-ci.txt +++ b/py-polars/requirements-ci.txt @@ -4,4 +4,6 @@ # ------------------------------------------------------- --extra-index-url https://download.pytorch.org/whl/cpu torch +jax +jaxlib pyiceberg>=0.5.0 diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 7da0150e3347..39b95a548ddc 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -58,6 +58,7 @@ # if the module is found in the environment those doctests will # run; if the module is not found, their doctests are skipped. OPTIONAL_MODULES_AND_METHODS: dict[str, set[str]] = { + "jax": {"to_jax"}, "torch": {"to_torch"}, } OPTIONAL_MODULES: set[str] = set() diff --git a/py-polars/tests/unit/dataframe/test_to_torch.py b/py-polars/tests/unit/dataframe/test_to_torch.py deleted file mode 100644 index be8de2d1f2d1..000000000000 --- a/py-polars/tests/unit/dataframe/test_to_torch.py +++ /dev/null @@ -1,290 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -import polars as pl -import polars.selectors as cs -from polars.dependencies import _lazy_import - -# don't import torch until an actual test is triggered (the decorator already -# ensures the tests aren't run locally, this will skip premature local import) -torch, _ = _lazy_import("torch") - - -@pytest.fixture() -def df() -> pl.DataFrame: - return pl.DataFrame( - { - "x": [1, 2, 2, 3], - "y": [True, False, True, False], - "z": [1.5, -0.5, 0.0, -2.0], - }, - schema_overrides={"x": pl.Int8, "z": pl.Float32}, - ) - - -@pytest.mark.ci_only() -class TestTorchIntegration: - """Test coverage for `to_torch` conversions and `polars.ml.torch` classes.""" - - def assert_tensor(self, actual: Any, expected: Any) -> None: - torch.testing.assert_close(actual, expected) - - def test_to_torch_series( - self, - ) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) - t = s.to_torch() - - assert list(t.shape) == [4] - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) - - # note: torch doesn't natively support uint16/32/64. - # confirm that we export to a suitable signed integer type - s = s.cast(pl.UInt16) - t = s.to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) - - for dtype in (pl.UInt32, pl.UInt64): - t = s.cast(dtype).to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) - - def test_to_torch_tensor(self, df: pl.DataFrame) -> None: - t1 = df.to_torch() - t2 = df.to_torch("tensor") - - assert list(t1.shape) == [4, 3] - assert (t1 == t2).all().item() is True - - def test_to_torch_dict(self, df: pl.DataFrame) -> None: - td = df.to_torch("dict") - - assert list(td.keys()) == ["x", "y", "z"] - - self.assert_tensor(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) - self.assert_tensor( - td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) - ) - self.assert_tensor( - td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) - ) - - def test_to_torch_dataset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", dtype=pl.Float64) - - assert len(ds) == 4 - assert isinstance(ds, torch.utils.data.Dataset) - assert repr(ds).startswith(" None: - ds = df.to_torch("dataset", label="x", features=["z", "y"]) - self.assert_tensor( - torch.tensor( - [ - [1.5000, 1.0000], - [-0.5000, 0.0000], - [0.0000, 1.0000], - [-2.0000, 0.0000], - ] - ), - ds.features, - ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - - def test_to_torch_dataset_feature_subset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x", features=["z"]) - self.assert_tensor( - torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), - ds.features, - ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - - def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[1:3] - - expected = ( - torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]), - ) - self.assert_tensor(expected, ts) - - ts = ds[::2] - expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) - self.assert_tensor(expected, ts) - - @pytest.mark.parametrize( - "index", - [ - [0, 3], - range(0, 4, 3), - slice(0, 4, 3), - ], - ) - def test_to_torch_dataset_index_multi(self, index: Any, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[index] - - expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) - self.assert_tensor(expected, ts) - assert ds.schema == {"features": torch.float32, "labels": None} - - def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[range(3, 0, -1)] - - expected = ( - torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]), - ) - self.assert_tensor(expected, ts) - - def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x") - assert ds.schema == {"features": torch.float32, "labels": torch.int8} - - dsf16 = ds.half() - assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} - - # half precision across all data - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor(expected, ts) - - # only apply half precision to the feature data - dsf16 = ds.half(labels=False) - assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} - - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1, 2], dtype=torch.int8), - ) - self.assert_tensor(expected, ts) - - # only apply half precision to the label data - dsf16 = ds.half(features=False) - assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} - - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor(expected, ts) - - # no labels - dsf16 = df.to_torch("dataset").half() - assert dsf16.schema == {"features": torch.float16, "labels": None} - - ts = dsf16[:3:2] - expected = ( # type: ignore[assignment] - torch.tensor( - data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], - dtype=torch.float16, - ), - ) - self.assert_tensor(expected, ts) - - @pytest.mark.parametrize( - ("label", "features"), - [ - ("x", None), - ("x", ["y", "z"]), - (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), - ], - ) - def test_to_torch_labelled_dataset( - self, label: Any, features: Any, df: pl.DataFrame - ) -> None: - ds = df.to_torch("dataset", label=label, features=features) - ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) - - expected = [ - torch.tensor([[1.0, 1.5], [0.0, -0.5]]), - torch.tensor([1, 2], dtype=torch.int8), - ] - assert len(ts) == len(expected) - for actual, exp in zip(ts, expected): - self.assert_tensor(exp, actual) - - def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: - ds = df.to_torch( - "dataset", - dtype=pl.Float64, - label=(pl.col("x") * 8).cast(pl.Int16), - ) - dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) - for data in (tuple(ds[:2]), tuple(next(iter(dl)))): - expected = ( - torch.tensor( - [[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64 - ), - torch.tensor([8, 16], dtype=torch.int16), - ) - assert len(data) == len(expected) - for actual, exp in zip(data, expected): - self.assert_tensor(exp, actual) - - def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label=["x", "y"]) - dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) - ts = list(dl) - - expected = [ - [ - torch.tensor([[1.5000], [-0.5000], [0.0000]]), - torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), - ], - [ - torch.tensor([[-2.0]]), - torch.tensor([[3, 0]], dtype=torch.int8), - ], - ] - assert len(ts) == len(expected) - - for actual, exp in zip(ts, expected): - assert len(actual) == len(exp) - for a, e in zip(actual, exp): - self.assert_tensor(e, a) - - def test_misc_errors(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - - with pytest.raises( - ValueError, - match="invalid `return_type`: 'stroopwafel'", - ): - _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] - - with pytest.raises( - ValueError, - match="does not support u16, u32, or u64 dtypes", - ): - _res1 = df.to_torch(dtype=pl.UInt16) - - with pytest.raises( - IndexError, - match="tensors used as indices must be long, int", - ): - _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] - - with pytest.raises( - ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", - ): - _res3 = df.to_torch(label="stroopwafel") - - with pytest.raises( - ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", - ): - _res4 = df.to_torch("dict", features=cs.float()) diff --git a/py-polars/tests/unit/ml/__init__.py b/py-polars/tests/unit/ml/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py new file mode 100644 index 000000000000..5dc0c172f084 --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import jax until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +jx, _ = _lazy_import("jax") +jxn, _ = _lazy_import("jax.numpy") + +pytestmark = pytest.mark.ci_only + +if TYPE_CHECKING: + from polars.datatypes import PolarsDataType + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [1, 0, 1, 0], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +def assert_array_equal(actual: Any, expected: Any, nans_equal: bool = True) -> None: + assert isinstance(actual, jx.Array) + jxn.array_equal(actual, expected, equal_nan=nans_equal) + + +@pytest.mark.parametrize( + ("dtype", "expected_jax_dtype"), + [ + (pl.Int8, "int8"), + (pl.Int16, "int16"), + (pl.Int32, "int32"), + (pl.Int64, "int32"), + (pl.UInt8, "uint8"), + (pl.UInt16, "uint16"), + (pl.UInt32, "uint32"), + (pl.UInt64, "uint32"), + ], +) +def test_to_jax_from_series( + dtype: PolarsDataType, + expected_jax_dtype: str, +) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) + for dvc in (None, "cpu", jx.devices("cpu")[0]): + assert_array_equal( + s.to_jax(device=dvc), + jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), + ) + + +def test_to_jax_array(df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) + + expected = jxn.array( + [ + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], + ], + dtype=jxn.float32, + ) + for a in (a1, a2, a3, a4): + assert_array_equal(a, expected) + + +def test_to_jax_dict(df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + assert list(arr_dict.keys()) == ["x", "y", "z"] + + assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) + assert_array_equal( + arr_dict["z"], + jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), + ) + + arr_dict = df.to_jax("dict", dtype=pl.Float32) + for a, expected_data in zip( + arr_dict.values(), + ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), + ): + assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="jax.numpy.bool requires Python >= 3.9", +) +def test_to_jax_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), + ) + assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + +def test_misc_errors(df: pl.DataFrame) -> None: + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py new file mode 100644 index 000000000000..7f1a4711c8ac --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import torch until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +torch, _ = _lazy_import("torch") + +pytestmark = pytest.mark.ci_only + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [True, False, True, False], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +def assert_tensor_equal(actual: Any, expected: Any) -> None: + torch.testing.assert_close(actual, expected) + + +def test_to_torch_from_series() -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + t = s.to_torch() + + assert list(t.shape) == [4] + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + + # note: torch doesn't natively support uint16/32/64. + # confirm that we export to a suitable signed integer type + s = s.cast(pl.UInt16) + t = s.to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + + for dtype in (pl.UInt32, pl.UInt64): + t = s.cast(dtype).to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + + +def test_to_torch_tensor(df: pl.DataFrame) -> None: + t1 = df.to_torch() + t2 = df.to_torch("tensor") + + assert list(t1.shape) == [4, 3] + assert (t1 == t2).all().item() is True + + +def test_to_torch_dict(df: pl.DataFrame) -> None: + td = df.to_torch("dict") + + assert list(td.keys()) == ["x", "y", "z"] + + assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + assert_tensor_equal( + td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) + ) + assert_tensor_equal( + td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) + ) + + +def test_to_torch_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + }, + schema_overrides={"age": pl.Int32, "income": pl.Int32}, + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_tensor_equal( + lbl_feat_dict["label"], + torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), + ) + assert_tensor_equal( + lbl_feat_dict["features"], + torch.tensor( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=torch.int32, + ), + ) + + +def test_to_torch_dataset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", dtype=pl.Float64) + + assert len(ds) == 4 + assert isinstance(ds, torch.utils.data.Dataset) + assert repr(ds).startswith(" None: + ds = df.to_torch("dataset", label="x", features=["z", "y"]) + assert_tensor_equal( + torch.tensor( + [ + [1.5000, 1.0000], + [-0.5000, 0.0000], + [0.0000, 1.0000], + [-2.0000, 0.0000], + ] + ), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + + +def test_to_torch_dataset_feature_subset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x", features=["z"]) + assert_tensor_equal( + torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + + +def test_to_torch_dataset_index_slice(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[1:3] + + expected = (torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]),) + assert_tensor_equal(expected, ts) + + ts = ds[::2] + expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + "index", + [ + [0, 3], + range(0, 4, 3), + slice(0, 4, 3), + ], +) +def test_to_torch_dataset_index_multi(index: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[index] + + expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) + assert_tensor_equal(expected, ts) + assert ds.schema == {"features": torch.float32, "labels": None} + + +def test_to_torch_dataset_index_range(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[range(3, 0, -1)] + + expected = (torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]),) + assert_tensor_equal(expected, ts) + + +def test_to_dataset_half_precision(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x") + assert ds.schema == {"features": torch.float32, "labels": torch.int8} + + dsf16 = ds.half() + assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + + # half precision across all data + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the feature data + dsf16 = ds.half(labels=False) + assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1, 2], dtype=torch.int8), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the label data + dsf16 = ds.half(features=False) + assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # no labels + dsf16 = df.to_torch("dataset").half() + assert dsf16.schema == {"features": torch.float16, "labels": None} + + ts = dsf16[:3:2] + expected = ( # type: ignore[assignment] + torch.tensor( + data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], + dtype=torch.float16, + ), + ) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + ("label", "features"), + [ + ("x", None), + ("x", ["y", "z"]), + (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + ], +) +def test_to_torch_labelled_dataset(label: Any, features: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=label, features=features) + ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) + + expected = [ + torch.tensor([[1.0, 1.5], [0.0, -0.5]]), + torch.tensor([1, 2], dtype=torch.int8), + ] + assert len(ts) == len(expected) + for actual, exp in zip(ts, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_expr(df: pl.DataFrame) -> None: + ds = df.to_torch( + "dataset", + dtype=pl.Float64, + label=(pl.col("x") * 8).cast(pl.Int16), + ) + dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) + for data in (tuple(ds[:2]), tuple(next(iter(dl)))): + expected = ( + torch.tensor([[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64), + torch.tensor([8, 16], dtype=torch.int16), + ) + assert len(data) == len(expected) + for actual, exp in zip(data, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_multi(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=["x", "y"]) + dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) + ts = list(dl) + + expected = [ + [ + torch.tensor([[1.5000], [-0.5000], [0.0000]]), + torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), + ], + [ + torch.tensor([[-2.0]]), + torch.tensor([[3, 0]], dtype=torch.int8), + ], + ] + assert len(ts) == len(expected) + + for actual, exp in zip(ts, expected): + assert len(actual) == len(exp) + for a, e in zip(actual, exp): + assert_tensor_equal(e, a) + + +def test_misc_errors(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="does not support u16, u32, or u64 dtypes", + ): + _res1 = df.to_torch(dtype=pl.UInt16) + + with pytest.raises( + IndexError, + match="tensors used as indices must be long, int", + ): + _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", + ): + _res3 = df.to_torch(label="stroopwafel") + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res4 = df.to_torch("dict", features=cs.float())