Skip to content

Commit

Permalink
fix: count of null column shouldn't panic in agg context (#15710)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Apr 17, 2024
1 parent a505068 commit 2db0ba6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 36 deletions.
78 changes: 42 additions & 36 deletions crates/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,42 +170,48 @@ impl PhysicalExpr for AggregationExpr {
AggState::NotAggregated(s) => {
let s = s.clone();
let groups = ac.groups();
match groups.as_ref() {
GroupsProxy::Idx(idx) => {
let s = s.rechunk();
let array = &s.chunks()[0];
let validity = array.validity().unwrap();

let out: IdxCa = idx
.iter()
.map(|(_, g)| {
let mut count = 0 as IdxSize;
// Count valid values
g.iter().for_each(|i| {
count += validity.get_bit_unchecked(*i as usize)
as IdxSize;
});
count
})
.collect_ca_trusted_with_dtype(&keep_name, IDX_DTYPE);
AggregatedScalar(out.into_series())
},
GroupsProxy::Slice { groups, .. } => {
// Slice and use computed null count
let out: IdxCa = groups
.iter()
.map(|g| {
let start = g[0];
let len = g[1];
len - s
.slice(start as i64, len as usize)
.null_count()
as IdxSize
})
.collect_ca_trusted_with_dtype(&keep_name, IDX_DTYPE);
AggregatedScalar(out.into_series())
},
}
let out: IdxCa = if matches!(s.dtype(), &DataType::Null) {
IdxCa::full(s.name(), 0, groups.len())
} else {
match groups.as_ref() {
GroupsProxy::Idx(idx) => {
let s = s.rechunk();
let array = &s.chunks()[0];
let validity = array.validity().unwrap();
idx.iter()
.map(|(_, g)| {
let mut count = 0 as IdxSize;
// Count valid values
g.iter().for_each(|i| {
count += validity
.get_bit_unchecked(*i as usize)
as IdxSize;
});
count
})
.collect_ca_trusted_with_dtype(
&keep_name, IDX_DTYPE,
)
},
GroupsProxy::Slice { groups, .. } => {
// Slice and use computed null count
groups
.iter()
.map(|g| {
let start = g[0];
let len = g[1];
len - s
.slice(start as i64, len as usize)
.null_count()
as IdxSize
})
.collect_ca_trusted_with_dtype(
&keep_name, IDX_DTYPE,
)
},
}
};
AggregatedScalar(out.into_series())
},
}
}
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,11 @@ def test_horizontal_mean_in_groupby_15115() -> None:
}
),
)


def test_group_count_over_null_column_15705() -> None:
    """Counting an all-null column inside a group-by aggregation yields 0 per group."""
    # Three groups of two rows each; the aggregated column holds only nulls.
    keys = [1, 1, 2, 2, 3, 3]
    frame = pl.DataFrame({"a": keys, "c": [None] * len(keys)})

    grouped = frame.group_by("a", maintain_order=True)
    counts = grouped.agg(pl.col("c").count())

    # `count` ignores nulls, so every group must report zero valid values.
    assert counts["c"].to_list() == [0, 0, 0]

0 comments on commit 2db0ba6

Please sign in to comment.