Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allow index=None in pivot() #15855

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/frame/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl PhysicalAggExpr for PivotExpr {

pub fn pivot<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
index: I0,
index: Option<I0>,
columns: I1,
values: Option<I2>,
sort_columns: bool,
Expand Down Expand Up @@ -67,7 +67,7 @@ where

pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
index: I0,
index: Option<I0>,
columns: I1,
values: Option<I2>,
sort_columns: bool,
Expand Down
47 changes: 29 additions & 18 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
/// If you have a relatively large table, consider using a group_by over a pivot.
pub fn pivot<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
index: I0,
index: Option<I0>,
columns: I1,
values: Option<I2>,
sort_columns: bool,
Expand All @@ -99,10 +99,10 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let index = match index {
Some(i) => i.into_iter().map(|s| s.as_ref().to_string()).collect(),
None => Vec::new(),
};
let columns = columns
.into_iter()
.map(|s| s.as_ref().to_string())
Expand All @@ -127,7 +127,7 @@ where
/// If you have a relatively large table, consider using a group_by over a pivot.
pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
index: I0,
index: Option<I0>,
columns: I1,
values: Option<I2>,
sort_columns: bool,
Expand All @@ -142,10 +142,10 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let index = match index {
Some(i) => i.into_iter().map(|s| s.as_ref().to_string()).collect(),
None => Vec::new(),
};
let columns = columns
.into_iter()
.map(|s| s.as_ref().to_string())
Expand Down Expand Up @@ -205,7 +205,6 @@ fn pivot_impl(
// used as separator/delimiter in generated column names.
separator: Option<&str>,
) -> PolarsResult<DataFrame> {
polars_ensure!(!index.is_empty(), ComputeError: "index cannot be zero length");
polars_ensure!(!columns.is_empty(), ComputeError: "columns cannot be zero length");
if !stable {
println!("unstable pivot not yet supported, using stable pivot");
Expand Down Expand Up @@ -262,12 +261,24 @@ fn pivot_impl_single_column(

let groups = pivot_df.group_by_stable(group_by)?.take_groups();

let (col, row) = POOL.join(
|| positioning::compute_col_idx(pivot_df, column, &groups),
|| positioning::compute_row_idx(pivot_df, index, &groups, count),
);
let (col_locations, column_agg) = col?;
let (row_locations, n_rows, mut row_index) = row?;
let (col, row) = match index.len() {
0 => {
let col = POOL.install(
|| positioning::compute_col_idx(pivot_df, column, &groups)
)?;
let row = (vec![0; col.0.len()], 1, None);
(col, row)
},
_ => {
let (col, row) = POOL.join(
|| positioning::compute_col_idx(pivot_df, column, &groups),
|| positioning::compute_row_idx(pivot_df, index, &groups, count),
);
(col?, row?)
},
};
let (col_locations, column_agg) = col;
let (row_locations, n_rows, mut row_index) = row;

for value_col_name in values {
let value_col = pivot_df.column(value_col_name)?;
Expand Down Expand Up @@ -347,7 +358,7 @@ fn pivot_impl_single_column(
}

let cols = if count == 0 {
let mut final_cols = row_index.take().unwrap();
let mut final_cols = row_index.take().unwrap_or_else(Vec::new);
final_cols.extend(cols);
final_cols
} else {
Expand Down
24 changes: 12 additions & 12 deletions crates/polars/tests/it/core/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
// Test with date as the `columns` input
let out = pivot(
&df,
["index"],
Some(["index"]),
["values1"],
Some(["values2"]),
true,
Expand All @@ -33,7 +33,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
// Test with date as the `values` input.
let mut out = pivot_stable(
&df,
["index"],
Some(["index"]),
["values2"],
Some(["values1"]),
true,
Expand Down Expand Up @@ -63,7 +63,7 @@ fn test_pivot_old() {

let pvt = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand All @@ -78,7 +78,7 @@ fn test_pivot_old() {
);
let pvt = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand All @@ -92,7 +92,7 @@ fn test_pivot_old() {
);
let pvt = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand All @@ -106,7 +106,7 @@ fn test_pivot_old() {
);
let pvt = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand All @@ -120,7 +120,7 @@ fn test_pivot_old() {
);
let pvt = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand Down Expand Up @@ -148,7 +148,7 @@ fn test_pivot_categorical() -> PolarsResult<()> {

let out = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
true,
Expand All @@ -173,7 +173,7 @@ fn test_pivot_new() -> PolarsResult<()> {

let out = (pivot_stable(
&df,
["index1", "index2"],
Some(["index1", "index2"]),
["cols1"],
Some(["values1"]),
true,
Expand All @@ -190,7 +190,7 @@ fn test_pivot_new() -> PolarsResult<()> {

let out = pivot_stable(
&df,
["index1", "index2"],
Some(["index1", "index2"]),
["cols1", "cols2"],
Some(["values1"]),
true,
Expand Down Expand Up @@ -221,7 +221,7 @@ fn test_pivot_2() -> PolarsResult<()> {

let out = pivot_stable(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand Down Expand Up @@ -254,7 +254,7 @@ fn test_pivot_datetime() -> PolarsResult<()> {

let out = pivot(
&df,
["index"],
Some(["index"]),
["columns"],
Some(["values"]),
false,
Expand Down
4 changes: 2 additions & 2 deletions docs/src/rust/user-guide/transformations/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// --8<-- [end:df]

// --8<-- [start:eager]
let out = pivot(&df, ["foo"], ["bar"], Some(["N"]), false, None, None)?;
let out = pivot(&df, Some(["foo"]), ["bar"], Some(["N"]), false, None, None)?;
println!("{}", &out);
// --8<-- [end:eager]

// --8<-- [start:lazy]
let q = df.lazy();
let q2 = pivot(
&q.collect()?,
["foo"],
Some(["foo"]),
["bar"],
Some(["N"]),
false,
Expand Down
31 changes: 28 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7017,7 +7017,7 @@ def pivot(
self,
values: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
columns: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
columns: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
aggregate_function: PivotAgg | Expr | None = None,
*,
maintain_order: bool = True,
Expand All @@ -7037,7 +7037,7 @@ def pivot(
arguments contains multiple columns as well. If None, all remaining columns
will be used.
index
One or multiple keys to group by.
One or multiple keys to group by. If None, a single output row is produced.
columns
Name of the column(s) whose values will be used as the header of the output
DataFrame.
Expand Down Expand Up @@ -7128,6 +7128,30 @@ def pivot(
│ b ┆ 0.964028 ┆ 0.999954 │
└──────┴──────────┴──────────┘

Set the index to None to output a single row.

>>> df = pl.DataFrame(
... {
... "col1": ["a", "a", "a", "b", "b", "b"],
... "col2": ["x", "x", "x", "x", "y", "y"],
... "col3": [6, 7, 3, 2, 5, 7],
... }
... )
>>> df.pivot(
... index=None,
... columns="col2",
... values="col3",
... aggregate_function=pl.element().tanh().mean(),
... )
shape: (1, 2)
┌──────────┬──────────┐
│ x ┆ y │
│ --- ┆ --- │
│ f64 ┆ f64 │
╞══════════╪══════════╡
│ 0.989767 ┆ 0.999954 │
└──────────┴──────────┘

Note that `pivot` is only available in eager mode. If you know the unique
column values in advance, you can use :meth:`polars.LazyFrame.groupby` to
get the same result as above in lazy mode:
Expand All @@ -7151,7 +7175,8 @@ def pivot(
│ b ┆ 0.964028 ┆ 0.999954 │
└──────┴──────────┴──────────┘
""" # noqa: W505
index = _expand_selectors(self, index)
if index is not None:
index = _expand_selectors(self, index)
columns = _expand_selectors(self, columns)
if values is not None:
values = _expand_selectors(self, values)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl PyDataFrame {
#[pyo3(signature = (index, columns, values, maintain_order, sort_columns, aggregate_expr, separator))]
pub fn pivot_expr(
&self,
index: Vec<String>,
index: Option<Vec<String>>,
columns: Vec<String>,
values: Option<Vec<String>>,
maintain_order: bool,
Expand Down
25 changes: 25 additions & 0 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,31 @@ def test_pivot_no_values() -> None:
assert_frame_equal(result, expected)


def test_pivot_no_index() -> None:
    """Check that pivoting with ``index=None`` yields a single output row.

    Each value column ("N", "M") is spread across the distinct entries of
    "foo", so the result is expected to hold one row and six columns.
    """
    frame = pl.DataFrame(
        {
            "foo": ["A", "B", "C"],
            "N": [1, 2, 3],
            "M": [4, 5, 6],
        }
    )

    # No index columns: nothing to group by on the row axis.
    pivoted = frame.pivot(
        index=None,
        columns="foo",
        values=["N", "M"],
        aggregate_function=None,
    )

    single_row = pl.DataFrame(
        {
            "N_foo_A": [1],
            "N_foo_B": [2],
            "N_foo_C": [3],
            "M_foo_A": [4],
            "M_foo_B": [5],
            "M_foo_C": [6],
        }
    )
    assert_frame_equal(pivoted, single_row)


def test_pivot_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

Expand Down
Loading