From 5145bcee480ba748e22fc198ec5971db7c5800be Mon Sep 17 00:00:00 2001 From: Igor Vaiman Date: Wed, 6 Nov 2024 17:42:27 +0100 Subject: [PATCH] fix: correct handling of keepdims and mask_identity for weighted mean (#3291) * fix + tests * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/operations/ak_mean.py | 4 +- tests/test_3285_ak_mean_weighted_row_wise.py | 65 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/test_3285_ak_mean_weighted_row_wise.py diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index a9b38ce1f0..3b6552c521 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -225,8 +225,8 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): sumw = ak.operations.ak_sum._impl( x * 0 + weight, axis, - keepdims, - mask_identity, + keepdims=True, + mask_identity=True, highlevel=True, behavior=ctx.behavior, attrs=ctx.attrs, diff --git a/tests/test_3285_ak_mean_weighted_row_wise.py b/tests/test_3285_ak_mean_weighted_row_wise.py new file mode 100644 index 0000000000..2c1bc1db9b --- /dev/null +++ b/tests/test_3285_ak_mean_weighted_row_wise.py @@ -0,0 +1,65 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import math + +import pytest + +import awkward as ak + + +@pytest.mark.parametrize( + "keepdims, expected_result", + [ + pytest.param(False, ak.Array([2.25, 6.5])), + pytest.param(True, ak.Array([[2.25], [6.5]])), + ], +) +def test_keepdims(keepdims: bool, expected_result: ak.Array): + data = ak.Array( + [ + [1, 2, 3], + [4, 7], + ] + ) + weight = ak.Array( + [ + [1, 1, 2], + [1, 5], + ] + ) + assert ak.all( + ak.mean(data, weight=weight, axis=1, keepdims=keepdims) == expected_result + ) + + +@pytest.mark.parametrize( + "mask_identity, expected_result", + [ + pytest.param(False, ak.Array([1.5, math.nan, 8])), + pytest.param(True, ak.Array([1.5, None, 8])), + ], +) +def test_mask_identity(mask_identity: bool, expected_result: ak.Array): + data = ak.Array( + [ + [1, 2], + [], + [6, 9], + ] + ) + weight = ak.Array( + [ + [1, 1], + [], + [1, 2], + ] + ) + result = ak.mean(data, weight=weight, axis=1, mask_identity=mask_identity) + assert result[0] == expected_result[0] + if mask_identity: + assert result[1] is None + else: + assert math.isnan(result[1]) # NaN is not equal to itself per IEEE! + assert result[2] == expected_result[2]