Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new index/range based selector cs.by_index, allow multiple indices for nth #16217

Merged
merged 6 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,8 @@ fn create_physical_expr_inner(
Wildcard => {
polars_bail!(ComputeError: "wildcard column selection not supported at this point")
},
Nth(_) => {
polars_bail!(ComputeError: "nth column selection not supported at this point")
Nth(n) => {
polars_bail!(ComputeError: "nth column selection not supported at this point (n={})", n)
},
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub enum Expr {
Column(ColumnName),
Columns(Arc<[ColumnName]>),
DtypeColumn(Vec<DataType>),
IndexColumn(Arc<[i64]>),
Literal(LiteralValue),
BinaryExpr {
left: Arc<Expr>,
Expand Down Expand Up @@ -172,6 +173,7 @@ impl Hash for Expr {
Expr::Column(name) => name.hash(state),
Expr::Columns(names) => names.hash(state),
Expr::DtypeColumn(dtypes) => dtypes.hash(state),
Expr::IndexColumn(indices) => indices.hash(state),
Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
Expr::Selector(s) => s.hash(state),
Expr::Nth(v) => v.hash(state),
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/functions/selectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ pub fn dtype_cols<DT: AsRef<[DataType]>>(dtype: DT) -> Expr {
let dtypes = dtype.as_ref().to_vec();
Expr::DtypeColumn(dtypes)
}

/// Select multiple columns by index.
pub fn index_cols<N: AsRef<[i64]>>(indices: N) -> Expr {
let indices = indices.as_ref().to_vec();
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
Expr::IndexColumn(Arc::from(indices))
}
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/meta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ impl MetaNameSpace {
pub fn has_multiple_outputs(&self) -> bool {
self.0.into_iter().any(|e| match e {
Expr::Selector(_) | Expr::Wildcard | Expr::Columns(_) | Expr::DtypeColumn(_) => true,
Expr::IndexColumn(idxs) => idxs.len() > 1,
Expr::Column(name) => is_regex_projection(name),
_ => false,
})
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ impl AExpr {
Wildcard => {
polars_bail!(ComputeError: "wildcard column selection not supported at this point")
},
Nth(_) => {
polars_bail!(ComputeError: "nth column selection not supported at this point")
Nth(n) => {
polars_bail!(ComputeError: "nth column selection not supported at this point (n={})", n)
},
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ fn expand_filter(predicate: Expr, input: Node, lp_arena: &Arena<IR>) -> PolarsRe
| Expr::RenameAlias { .. }
| Expr::Columns(_)
| Expr::DtypeColumn(_)
| Expr::IndexColumn(_)
| Expr::Nth(_) => true,
_ => false,
}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,15 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
AExpr::Len
},
Expr::Nth(i) => AExpr::Nth(i),
Expr::IndexColumn(idx) => {
if idx.len() == 1 {
AExpr::Nth(idx[0])
} else {
panic!("no multi-value `index-columns` expected at this point")
}
},
Expr::Wildcard => AExpr::Wildcard,
Expr::SubPlan { .. } => panic!("no SQLSubquery expected at this point"),
Expr::SubPlan { .. } => panic!("no SQL subquery expected at this point"),
Expr::KeepName(_) => panic!("no `name.keep` expected at this point"),
Expr::Exclude(_, _) => panic!("no `exclude` expected at this point"),
Expr::RenameAlias { .. } => panic!("no `rename_alias` expected at this point"),
Expand Down
83 changes: 66 additions & 17 deletions crates/polars-plan/src/logical_plan/expr_expansion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ fn replace_nth(expr: Expr, schema: &Schema) -> Expr {
if let Expr::Nth(i) = e {
match i.negative_to_usize(schema.len()) {
None => {
let name = if i == 0 { "first" } else { "last" };
let name = match i {
0 => "first",
-1 => "last",
_ => "nth",
};
Expr::Column(ColumnName::from(name))
},
Some(idx) => {
Expand Down Expand Up @@ -184,16 +188,6 @@ fn expand_columns(
Ok(())
}

/// This replaces the dtypes Expr with a Column Expr. It also removes the Exclude Expr from the
/// expression chain.
fn replace_dtype_with_column(expr: Expr, column_name: Arc<str>) -> Expr {
expr.map_expr(|e| match e {
Expr::DtypeColumn(_) => Expr::Column(column_name.clone()),
Expr::Exclude(input, _) => Arc::unwrap_or_clone(input),
e => e,
})
}

#[cfg(feature = "dtype-struct")]
fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
expr.try_map_expr(|e| match e {
Expand Down Expand Up @@ -228,6 +222,21 @@ fn struct_index_to_field(expr: Expr, schema: &Schema) -> PolarsResult<Expr> {
})
}

/// This replaces the dtype or index expanded Expr with a Column Expr.
/// ()It also removes the Exclude Expr from the expression chain).
fn replace_dtype_or_index_with_column(
expr: Expr,
column_name: &ColumnName,
replace_dtype: bool,
) -> Expr {
expr.map_expr(|e| match e {
Expr::DtypeColumn(_) if replace_dtype => Expr::Column(column_name.clone()),
Expr::IndexColumn(_) if !replace_dtype => Expr::Column(column_name.clone()),
Expr::Exclude(input, _) => Arc::unwrap_or_clone(input),
e => e,
})
}

/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the
/// expression chain.
pub(super) fn replace_columns_with_column(
Expand Down Expand Up @@ -282,13 +291,47 @@ fn expand_dtypes(
}) {
let name = field.name();
let new_expr = expr.clone();
let new_expr = replace_dtype_with_column(new_expr, ColumnName::from(name.as_str()));
let new_expr =
replace_dtype_or_index_with_column(new_expr, &ColumnName::from(name.as_str()), true);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr)
}
Ok(())
}

/// replace `IndexColumn` with `col("foo")..col("bar")`
fn expand_indices(
expr: &Expr,
result: &mut Vec<Expr>,
schema: &Schema,
indices: &[i64],
exclude: &PlHashSet<Arc<str>>,
) -> PolarsResult<()> {
let n_fields = schema.len() as i64;
for idx in indices {
let mut idx = *idx;
if idx < 0 {
idx += n_fields;
if idx < 0 {
polars_bail!(ComputeError: "invalid column index {}", idx)
}
}
if let Some((name, _)) = schema.get_at_index(idx as usize) {
if !exclude.contains(name.as_str()) {
let new_expr = expr.clone();
let new_expr = replace_dtype_or_index_with_column(
new_expr,
&ColumnName::from(name.as_str()),
false,
);
let new_expr = rewrite_special_aliases(new_expr)?;
result.push(new_expr);
}
}
}
Ok(())
}

// schema is not used if regex not activated
#[allow(unused_variables)]
fn prepare_excluded(
Expand Down Expand Up @@ -400,6 +443,7 @@ fn find_flags(expr: &Expr) -> ExpansionFlags {
for expr in expr {
match expr {
Expr::Columns(_) | Expr::DtypeColumn(_) => multiple_columns = true,
Expr::IndexColumn(idx) => multiple_columns = idx.len() > 1,
Expr::Nth(_) => has_nth = true,
Expr::Wildcard => has_wildcard = true,
Expr::Selector(_) => has_selector = true,
Expand Down Expand Up @@ -474,20 +518,25 @@ fn replace_and_add_to_results(
// has multiple column names
// the expanded columns are added to the result
if flags.multiple_columns {
if let Some(e) = expr
.into_iter()
.find(|e| matches!(e, Expr::Columns(_) | Expr::DtypeColumn(_)))
{
if let Some(e) = expr.into_iter().find(|e| {
matches!(
e,
Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_)
)
}) {
match &e {
Expr::Columns(names) => {
let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?;
expand_columns(&expr, result, names, schema, &exclude)?;
},
Expr::DtypeColumn(dtypes) => {
// keep track of column excluded from the dtypes
let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?;
expand_dtypes(&expr, result, schema, dtypes, &exclude)?
},
Expr::IndexColumn(indices) => {
let exclude = prepare_excluded(&expr, schema, keys, flags.has_exclude)?;
expand_indices(&expr, result, schema, indices, &exclude)?
},
_ => {},
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ impl Debug for Expr {
RenameAlias { expr, .. } => write!(f, ".rename_alias({expr:?})"),
Columns(names) => write!(f, "cols({names:?})"),
DtypeColumn(dt) => write!(f, "dtype_columns({dt:?})"),
IndexColumn(idxs) => write!(f, "index_columns({idxs:?})"),
Selector(_) => write!(f, "SELECTOR"),
}
}
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ macro_rules! push_expr {
($current_expr:expr, $c:ident, $push:ident, $push_owned:ident, $iter:ident) => {{
use Expr::*;
match $current_expr {
Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) | Len => {},
Nth(_) | Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_)
| IndexColumn(_) | Len => {},
Alias(e, _) => $push($c, e),
BinaryExpr { left, op: _, right } => {
// reverse order so that left is popped first
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/logical_plan/visitor/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ impl TreeWalker for Expr {
Column(_) => self,
Columns(_) => self,
DtypeColumn(_) => self,
IndexColumn(_) => self,
Literal(_) => self,
BinaryExpr { left, op, right } => {
BinaryExpr { left: am(left, &mut f)? , op, right: am(right, f)?}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ pub fn expr_output_name(expr: &Expr) -> PolarsResult<Arc<str>> {
ComputeError:
"cannot determine output column without a context for this expression"
),
Expr::Columns(_) | Expr::DtypeColumn(_) => polars_bail!(
Expr::Columns(_) | Expr::DtypeColumn(_) | Expr::IndexColumn(_) => polars_bail!(
ComputeError:
"this expression may produce multiple output names"
),
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-sql/tests/functions_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ fn test_array_to_string() {
context.register("df", df.clone().lazy());

let sql = r#"
SELECT b, ARRAY_TO_STRING(a,', ') AS a2s,
SELECT b, ARRAY_TO_STRING("a",', ') AS a2s,
FROM (
SELECT b, ARRAY_AGG(a)
SELECT b, ARRAY_AGG(a) AS "a"
FROM df
GROUP BY b
) tbl
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2716,7 +2716,7 @@ def sort_by(
)

def gather(
self, indices: int | list[int] | Expr | Series | np.ndarray[Any, Any]
self, indices: int | Sequence[int] | Expr | Series | np.ndarray[Any, Any]
) -> Self:
"""
Take values by index.
Expand Down
6 changes: 6 additions & 0 deletions py-polars/polars/functions/col.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ def __new__( # type: ignore[misc]
Additional names or datatypes of columns to represent,
specified as positional arguments.

See Also
--------
first
last
nth

Examples
--------
Pass a single column name to represent that column.
Expand Down
43 changes: 34 additions & 9 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,9 @@ def last(*columns: str) -> Expr:
return F.col(*columns).last()


def nth(n: int, *columns: str) -> Expr:
def nth(n: int | Sequence[int], *columns: str) -> Expr:
Copy link
Member

@stinodego stinodego May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pity we have this *columns API, because passing nth(1,2) would feel really good. I think we should deprecate this API (in another PR) and make this a keyword argument, so we can pass nth(1,2, from=['x', 'y', 'z'])

Copy link
Collaborator Author

@alexander-beedie alexander-beedie May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking the same thing 👌 Can follow-up with a separate improvement / deprecation; I think there are a few other functions with a similar set of params, so we should check and do all at once, if so.

"""
Get the nth column or value.
Get the nth column(s) or value(s).

This function has different behavior depending on the presence of `columns`
values. If none given (the default), returns an expression that takes the nth
Expand All @@ -657,11 +657,11 @@ def nth(n: int, *columns: str) -> Expr:
Parameters
----------
n
Index of the column (or value) to get.
One or more indices representing the columns/values to retrieve.
*columns
One or more column names. If omitted (the default), returns an
expression that takes the nth column of the context. Otherwise,
returns takes the nth value of the given column(s).
expression that takes the nth column of the context; otherwise,
takes the nth value of the given column(s).

Examples
--------
Expand All @@ -673,7 +673,7 @@ def nth(n: int, *columns: str) -> Expr:
... }
... )

Return the "nth" column:
Return the "nth" column(s):

>>> df.select(pl.nth(1))
shape: (3, 1)
Expand All @@ -687,7 +687,19 @@ def nth(n: int, *columns: str) -> Expr:
│ 2 │
└─────┘

Return the "nth" value for the given columns:
>>> df.select(pl.nth([2, 0]))
shape: (3, 2)
┌─────┬─────┐
│ c ┆ a │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ foo ┆ 1 │
│ bar ┆ 8 │
│ baz ┆ 3 │
└─────┴─────┘

Return the "nth" value(s) for the given columns:

>>> df.select(pl.nth(-2, "b", "c"))
shape: (1, 2)
Expand All @@ -698,11 +710,24 @@ def nth(n: int, *columns: str) -> Expr:
╞═════╪═════╡
│ 5 ┆ bar │
└─────┴─────┘

>>> df.select(pl.nth([0, 2], "c", "a"))
shape: (2, 2)
┌─────┬─────┐
│ c ┆ a │
│ --- ┆ --- │
│ str ┆ i64 │
╞═════╪═════╡
│ foo ┆ 1 │
│ baz ┆ 3 │
└─────┴─────┘
"""
indices = [n] if isinstance(n, int) else n
if not columns:
return wrap_expr(plr.nth(n))
return wrap_expr(plr.index_cols(indices))

return F.col(*columns).get(n)
cols = F.col(*columns)
return cols.get(indices[0]) if len(indices) == 1 else cols.gather(indices)


def head(column: str, n: int = 10) -> Expr:
Expand Down
Loading