Skip to content

Commit

Permalink
address misc comments 👍
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexander-beedie committed May 15, 2024
1 parent 40a90fa commit 8fb7aa4
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 26 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub enum Expr {
Column(ColumnName),
Columns(Arc<[ColumnName]>),
DtypeColumn(Vec<DataType>),
- IndexColumn(Arc<[i32]>),
+ IndexColumn(Arc<[i64]>),
Literal(LiteralValue),
BinaryExpr {
left: Arc<Expr>,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/selectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub fn dtype_cols<DT: AsRef<[DataType]>>(dtype: DT) -> Expr {
}

/// Select multiple columns by index.
- pub fn index_cols<N: AsRef<[i32]>>(indices: N) -> Expr {
+ pub fn index_cols<N: AsRef<[i64]>>(indices: N) -> Expr {
let indices = indices.as_ref().to_vec();
Expr::IndexColumn(Arc::from(indices))
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
Expr::Nth(i) => AExpr::Nth(i),
Expr::IndexColumn(idx) => {
if idx.len() == 1 {
- AExpr::Nth(idx[0].into())
+ AExpr::Nth(idx[0])
} else {
panic!("no multi-value `index-columns` expected at this point")
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/logical_plan/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ fn expand_indices(
expr: &Expr,
result: &mut Vec<Expr>,
schema: &Schema,
- indices: &[i32],
+ indices: &[i64],
exclude: &PlHashSet<Arc<str>>,
) -> PolarsResult<()> {
- let n_fields = schema.len() as i32;
+ let n_fields = schema.len() as i64;
for idx in indices {
let mut idx = *idx;
if idx < 0 {
Expand Down
25 changes: 8 additions & 17 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import timezone
from functools import reduce
from operator import or_
- from typing import TYPE_CHECKING, Any, Collection, Literal, Mapping, overload
+ from typing import TYPE_CHECKING, Any, Collection, Literal, Mapping, Sequence, overload

from polars import functions as F
from polars._utils.deprecation import deprecate_nonkeyword_arguments
Expand Down Expand Up @@ -596,7 +596,7 @@ def by_dtype(
)


- def by_index(*indices: int | range | Collection[int | range]) -> SelectorType:
+ def by_index(*indices: int | range | Sequence[int | range]) -> SelectorType:
"""
Select all columns matching the given indices (or range objects).
Expand Down Expand Up @@ -677,22 +677,12 @@ def by_index(*indices: int | range | Collection[int | range]) -> SelectorType:
│ abc ┆ 0.5 ┆ 1.5 ┆ 2.5 ┆ … ┆ 46.5 ┆ 47.5 ┆ 48.5 ┆ 49.5 │
└─────┴─────┴─────┴─────┴───┴──────┴──────┴──────┴──────┘
"""
- all_indices = []
+ all_indices: list[int] = []
for idx in indices:
- if isinstance(idx, int):
- all_indices.append(idx)
- elif isinstance(idx, range):
- all_indices.extend(idx)
- elif isinstance(idx, Collection):
- for i in idx:
- if isinstance(i, int):
- all_indices.append(i)
- else:
- msg = f"invalid index: {i!r}"
- raise TypeError(msg)
+ if isinstance(idx, (range, Sequence)):
+ all_indices.extend(idx) # type: ignore[arg-type]
else:
- msg = f"invalid index: {idx!r}"
- raise TypeError(msg)
+ all_indices.append(idx)

return _selector_proxy_(
F.nth(all_indices), name="by_index", parameters={"*indices": indices}
Expand Down Expand Up @@ -762,7 +752,8 @@ def by_name(*names: str | Collection[str]) -> SelectorType:
raise TypeError(msg)
all_names.append(n)
else:
- TypeError(f"Invalid name: {nm!r}")
+ msg = f"invalid name: {nm!r}"
+ raise TypeError(msg)

return _selector_proxy_(
F.col(all_names), name="by_name", parameters={"*names": all_names}
Expand Down
8 changes: 4 additions & 4 deletions py-polars/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,13 @@ pub fn dtype_cols(dtypes: Vec<Wrap<DataType>>) -> PyResult<PyExpr> {
}

#[pyfunction]
- pub fn index_cols(indices: Vec<i32>) -> PyResult<PyExpr> {
- Ok(if indices.len() == 1 {
- dsl::nth(indices[0].into())
+ pub fn index_cols(indices: Vec<i64>) -> PyExpr {
+ if indices.len() == 1 {
+ dsl::nth(indices[0])
} else {
dsl::index_cols(indices)
}
- .into())
+ .into()
}

#[pyfunction]
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None:
assert df.select(cs.by_dtype()).schema == {}
assert df.select(cs.by_dtype([])).schema == {}

+ # expected errors
+ with pytest.raises(TypeError):
+ df.select(cs.by_dtype(999)) # type: ignore[arg-type]


def test_selector_by_index(df: pl.DataFrame) -> None:
# one or more +ve indexes
Expand Down Expand Up @@ -126,9 +130,13 @@ def test_selector_by_name(df: pl.DataFrame) -> None:
assert df.select(cs.by_name()).columns == []
assert df.select(cs.by_name([])).columns == []

+ # expected errors
with pytest.raises(ColumnNotFoundError):
df.select(cs.by_name("stroopwafel"))

+ with pytest.raises(TypeError):
+ df.select(cs.by_name(999)) # type: ignore[arg-type]


def test_selector_contains(df: pl.DataFrame) -> None:
assert df.select(cs.contains("b")).columns == ["abc", "bbb"]
Expand All @@ -147,6 +155,10 @@ def test_selector_contains(df: pl.DataFrame) -> None:
]
assert df.select(cs.contains(("ee", "x"))).columns == ["eee"]

+ # expected errors
+ with pytest.raises(TypeError):
+ df.select(cs.contains(999)) # type: ignore[arg-type]


def test_selector_datetime(df: pl.DataFrame) -> None:
assert df.select(cs.datetime()).schema == {"opp": pl.Datetime("ms")}
Expand Down Expand Up @@ -233,6 +245,10 @@ def test_selector_datetime(df: pl.DataFrame) -> None:
== df.select(~cs.datetime(["ms", "ns"], time_zone="*")).columns
)

+ # expected errors
+ with pytest.raises(TypeError):
+ df.select(cs.datetime(999)) # type: ignore[arg-type]


def test_select_decimal(df: pl.DataFrame) -> None:
assert df.select(cs.decimal()).columns == []
Expand Down Expand Up @@ -288,6 +304,10 @@ def test_selector_ends_with(df: pl.DataFrame) -> None:
"qqR",
]

+ # expected errors
+ with pytest.raises(TypeError):
+ df.select(cs.ends_with(999)) # type: ignore[arg-type]


def test_selector_first_last(df: pl.DataFrame) -> None:
assert df.select(cs.first()).columns == ["abc"]
Expand Down Expand Up @@ -388,6 +408,9 @@ def test_selector_startswith(df: pl.DataFrame) -> None:
"opp",
"qqR",
]
+ # expected errors
+ with pytest.raises(TypeError):
+ df.select(cs.starts_with(999)) # type: ignore[arg-type]


def test_selector_temporal(df: pl.DataFrame) -> None:
Expand Down

0 comments on commit 8fb7aa4

Please sign in to comment.