Skip to content

Commit

Permalink
fix: correct handling of keepdims and mask_identity for weighted mean (
Browse files Browse the repository at this point in the history
…#3291)

* fix + tests

* style: pre-commit fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
nj-vs-vh and pre-commit-ci[bot] authored Nov 6, 2024
1 parent 8a0bcb5 commit 5145bce
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/awkward/operations/ak_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
sumw = ak.operations.ak_sum._impl(
x * 0 + weight,
axis,
keepdims,
mask_identity,
keepdims=True,
mask_identity=True,
highlevel=True,
behavior=ctx.behavior,
attrs=ctx.attrs,
Expand Down
65 changes: 65 additions & 0 deletions tests/test_3285_ak_mean_weighted_row_wise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import math

import pytest

import awkward as ak


@pytest.mark.parametrize(
"keepdims, expected_result",
[
pytest.param(False, ak.Array([2.25, 6.5])),
pytest.param(True, ak.Array([[2.25], [6.5]])),
],
)
def test_keepdims(keepdims: bool, expected_result: ak.Array):
data = ak.Array(
[
[1, 2, 3],
[4, 7],
]
)
weight = ak.Array(
[
[1, 1, 2],
[1, 5],
]
)
assert ak.all(
ak.mean(data, weight=weight, axis=1, keepdims=keepdims) == expected_result
)


@pytest.mark.parametrize(
"mask_identity, expected_result",
[
pytest.param(False, ak.Array([1.5, math.nan, 8])),
pytest.param(True, ak.Array([1.5, None, 8])),
],
)
def test_mask_identity(mask_identity: bool, expected_result: ak.Array):
data = ak.Array(
[
[1, 2],
[],
[6, 9],
]
)
weight = ak.Array(
[
[1, 1],
[],
[1, 2],
]
)
result = ak.mean(data, weight=weight, axis=1, mask_identity=mask_identity)
assert result[0] == expected_result[0]
if mask_identity:
assert result[1] is None
else:
assert math.isnan(result[1]) # NaN is not equal to itself per IEEE!
assert result[2] == expected_result[2]

0 comments on commit 5145bce

Please sign in to comment.