diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index a46602551884..a4a81d56582a 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -71,7 +71,7 @@ pub enum Expr { Column(ColumnName), Columns(Arc<[ColumnName]>), DtypeColumn(Vec), - IndexColumn(Arc<[i32]>), + IndexColumn(Arc<[i64]>), Literal(LiteralValue), BinaryExpr { left: Arc, diff --git a/crates/polars-plan/src/dsl/functions/selectors.rs b/crates/polars-plan/src/dsl/functions/selectors.rs index 68a868e93907..3a61dae987a2 100644 --- a/crates/polars-plan/src/dsl/functions/selectors.rs +++ b/crates/polars-plan/src/dsl/functions/selectors.rs @@ -58,7 +58,7 @@ pub fn dtype_cols>(dtype: DT) -> Expr { } /// Select multiple columns by index. -pub fn index_cols>(indices: N) -> Expr { +pub fn index_cols>(indices: N) -> Expr { let indices = indices.as_ref().to_vec(); Expr::IndexColumn(Arc::from(indices)) } diff --git a/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs index f411c449faef..f536a36d128d 100644 --- a/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs +++ b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs @@ -349,7 +349,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta Expr::Nth(i) => AExpr::Nth(i), Expr::IndexColumn(idx) => { if idx.len() == 1 { - AExpr::Nth(idx[0].into()) + AExpr::Nth(idx[0]) } else { panic!("no multi-value `index-columns` expected at this point") } diff --git a/crates/polars-plan/src/logical_plan/expr_expansion.rs b/crates/polars-plan/src/logical_plan/expr_expansion.rs index ba2a1fd84684..f93925a2013d 100644 --- a/crates/polars-plan/src/logical_plan/expr_expansion.rs +++ b/crates/polars-plan/src/logical_plan/expr_expansion.rs @@ -304,10 +304,10 @@ fn expand_indices( expr: &Expr, result: &mut Vec, schema: &Schema, - indices: &[i32], + indices: &[i64], exclude: &PlHashSet>, ) -> PolarsResult<()> { - let n_fields = schema.len() as i32; + let n_fields = schema.len() as i64; for idx in indices { let mut idx = *idx; if idx < 0 { diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 3c19e080ba03..7d1ecd03021a 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -4,7 +4,7 @@ from datetime import timezone from functools import reduce from operator import or_ -from typing import TYPE_CHECKING, Any, Collection, Literal, Mapping, overload +from typing import TYPE_CHECKING, Any, Collection, Literal, Mapping, Sequence, overload from polars import functions as F from polars._utils.deprecation import deprecate_nonkeyword_arguments @@ -596,7 +596,7 @@ def by_dtype( ) -def by_index(*indices: int | range | Collection[int | range]) -> SelectorType: +def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType: """ Select all columns matching the given indices (or range objects). @@ -677,22 +677,12 @@ def by_index(*indices: int | range | Collection[int | range]) -> SelectorType: │ abc ┆ 0.5 ┆ 1.5 ┆ 2.5 ┆ … ┆ 46.5 ┆ 47.5 ┆ 48.5 ┆ 49.5 │ └─────┴─────┴─────┴─────┴───┴──────┴──────┴──────┴──────┘ """ - all_indices = [] + all_indices: list[int] = [] for idx in indices: - if isinstance(idx, int): - all_indices.append(idx) - elif isinstance(idx, range): - all_indices.extend(idx) - elif isinstance(idx, Collection): - for i in idx: - if isinstance(i, int): - all_indices.append(i) - else: - msg = f"invalid index: {i!r}" - raise TypeError(msg) + if isinstance(idx, (range, Sequence)): + all_indices.extend(idx) # type: ignore[arg-type] else: - msg = f"invalid index: {idx!r}" - raise TypeError(msg) + all_indices.append(idx) return _selector_proxy_( F.nth(all_indices), name="by_index", parameters={"*indices": indices} @@ -762,7 +752,8 @@ def by_name(*names: str | Collection[str]) -> SelectorType: raise TypeError(msg) all_names.append(n) else: - TypeError(f"Invalid name: {nm!r}") + msg = f"invalid name: {nm!r}" + raise TypeError(msg) return _selector_proxy_( F.col(all_names), name="by_name", parameters={"*names": all_names} diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 9f211937f826..d7a87b9c635c 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -332,13 +332,13 @@ pub fn dtype_cols(dtypes: Vec>) -> PyResult { } #[pyfunction] -pub fn index_cols(indices: Vec) -> PyResult { - Ok(if indices.len() == 1 { - dsl::nth(indices[0].into()) +pub fn index_cols(indices: Vec) -> PyExpr { + if indices.len() == 1 { + dsl::nth(indices[0]) } else { dsl::index_cols(indices) } - .into()) + .into() } #[pyfunction] diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index d2a892426e0e..9d1a4704f503 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -71,6 +71,10 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None: assert df.select(cs.by_dtype()).schema == {} assert df.select(cs.by_dtype([])).schema == {} + # expected errors + with pytest.raises(TypeError): + df.select(cs.by_dtype(999)) # type: ignore[arg-type] + def test_selector_by_index(df: pl.DataFrame) -> None: # one or more +ve indexes @@ -126,9 +130,13 @@ def test_selector_by_name(df: pl.DataFrame) -> None: assert df.select(cs.by_name()).columns == [] assert df.select(cs.by_name([])).columns == [] + # expected errors with pytest.raises(ColumnNotFoundError): df.select(cs.by_name("stroopwafel")) + with pytest.raises(TypeError): + df.select(cs.by_name(999)) # type: ignore[arg-type] + def test_selector_contains(df: pl.DataFrame) -> None: assert df.select(cs.contains("b")).columns == ["abc", "bbb"] @@ -147,6 +155,10 @@ def test_selector_contains(df: pl.DataFrame) -> None: ] assert df.select(cs.contains(("ee", "x"))).columns == ["eee"] + # expected errors + with pytest.raises(TypeError): + df.select(cs.contains(999)) # type: ignore[arg-type] + def test_selector_datetime(df: pl.DataFrame) -> None: assert df.select(cs.datetime()).schema == {"opp": pl.Datetime("ms")} @@ -233,6 +245,10 @@ def test_selector_datetime(df: pl.DataFrame) -> None: == df.select(~cs.datetime(["ms", "ns"], time_zone="*")).columns ) + # expected errors + with pytest.raises(TypeError): + df.select(cs.datetime(999)) # type: ignore[arg-type] + def test_select_decimal(df: pl.DataFrame) -> None: assert df.select(cs.decimal()).columns == [] @@ -288,6 +304,10 @@ def test_selector_ends_with(df: pl.DataFrame) -> None: "qqR", ] + # expected errors + with pytest.raises(TypeError): + df.select(cs.ends_with(999)) # type: ignore[arg-type] + def test_selector_first_last(df: pl.DataFrame) -> None: assert df.select(cs.first()).columns == ["abc"] @@ -388,6 +408,9 @@ def test_selector_startswith(df: pl.DataFrame) -> None: "opp", "qqR", ] + # expected errors + with pytest.raises(TypeError): + df.select(cs.starts_with(999)) # type: ignore[arg-type] def test_selector_temporal(df: pl.DataFrame) -> None: