From 2db0ba608b223a014bba5e10d7b82505898798ed Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Wed, 17 Apr 2024 22:25:18 +0800 Subject: [PATCH] fix: count of null column shouldn't panic in agg context (#15710) --- .../physical_plan/expressions/aggregation.rs | 78 ++++++++++--------- .../unit/operations/test_aggregations.py | 8 ++ 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index 5393d2a6a4a0..dd2937dc3e57 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -170,42 +170,48 @@ impl PhysicalExpr for AggregationExpr { AggState::NotAggregated(s) => { let s = s.clone(); let groups = ac.groups(); - match groups.as_ref() { - GroupsProxy::Idx(idx) => { - let s = s.rechunk(); - let array = &s.chunks()[0]; - let validity = array.validity().unwrap(); - - let out: IdxCa = idx - .iter() - .map(|(_, g)| { - let mut count = 0 as IdxSize; - // Count valid values - g.iter().for_each(|i| { - count += validity.get_bit_unchecked(*i as usize) - as IdxSize; - }); - count - }) - .collect_ca_trusted_with_dtype(&keep_name, IDX_DTYPE); - AggregatedScalar(out.into_series()) - }, - GroupsProxy::Slice { groups, .. } => { - // Slice and use computed null count - let out: IdxCa = groups - .iter() - .map(|g| { - let start = g[0]; - let len = g[1]; - len - s - .slice(start as i64, len as usize) - .null_count() - as IdxSize - }) - .collect_ca_trusted_with_dtype(&keep_name, IDX_DTYPE); - AggregatedScalar(out.into_series()) - }, - } + let out: IdxCa = if matches!(s.dtype(), &DataType::Null) { + IdxCa::full(s.name(), 0, groups.len()) + } else { + match groups.as_ref() { + GroupsProxy::Idx(idx) => { + let s = s.rechunk(); + let array = &s.chunks()[0]; + let validity = array.validity().unwrap(); + idx.iter() + .map(|(_, g)| { + let mut count = 0 as IdxSize; + // Count valid values + g.iter().for_each(|i| { + count += validity + .get_bit_unchecked(*i as usize) + as IdxSize; + }); + count + }) + .collect_ca_trusted_with_dtype( + &keep_name, IDX_DTYPE, + ) + }, + GroupsProxy::Slice { groups, .. } => { + // Slice and use computed null count + groups + .iter() + .map(|g| { + let start = g[0]; + let len = g[1]; + len - s + .slice(start as i64, len as usize) + .null_count() + as IdxSize + }) + .collect_ca_trusted_with_dtype( + &keep_name, IDX_DTYPE, + ) + }, + } + }; + AggregatedScalar(out.into_series()) }, } } diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 64ea9551dffe..ca3153cdfc1b 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -530,3 +530,11 @@ def test_horizontal_mean_in_groupby_15115() -> None: } ), ) + + +def test_group_count_over_null_column_15705() -> None: + df = pl.DataFrame( + {"a": [1, 1, 2, 2, 3, 3], "c": [None, None, None, None, None, None]} + ) + out = df.group_by("a", maintain_order=True).agg(pl.col("c").count()) + assert out["c"].to_list() == [0, 0, 0]