From 13ec1d62cc81e6771808fb6b94551589d3b37afd Mon Sep 17 00:00:00 2001 From: Peter Fackeldey Date: Wed, 13 Nov 2024 17:13:14 +0100 Subject: [PATCH] fix: restrict `named_axis` inferring to `ak.Arrays/ak.Records/ak.HighLevelContexts` by default (#3304) * fix: use None when ndims can't be inferred from a layout-like obj * restrict _get_named_axis by default to awkward-arrays only * fix boolean condition --- src/awkward/_broadcasting.py | 6 +++--- src/awkward/_namedaxis.py | 13 +++++++++---- src/awkward/operations/ak_mean.py | 2 +- src/awkward/operations/ak_moment.py | 2 +- src/awkward/operations/ak_ptp.py | 2 +- src/awkward/operations/ak_std.py | 2 +- src/awkward/operations/ak_var.py | 2 +- 7 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py index 8dab0af30e..ae8f1f8157 100644 --- a/src/awkward/_broadcasting.py +++ b/src/awkward/_broadcasting.py @@ -463,7 +463,7 @@ def apply_step( named_axis = _add_named_axis(named_axis, depth, ndim) depth_context[NAMED_AXIS_KEY][i] = ( _unify_named_axis(named_axis, seen_named_axis), - ndim + 1, + ndim + 1 if ndim is not None else ndim, ) if o.is_leaf: _export_named_axis_from_depth_to_lateral( @@ -645,7 +645,7 @@ def broadcast_any_list(): # rightbroadcasting adds a new first(!) dimension as depth depth_context[NAMED_AXIS_KEY][i] = ( _add_named_axis(named_axis, depth, ndim), - ndim + 1, + ndim + 1 if ndim is not None else ndim, ) if x.is_leaf: _export_named_axis_from_depth_to_lateral( @@ -734,7 +734,7 @@ def broadcast_any_list(): # leftbroadcasting adds a new last dimension at depth + 1 depth_context[NAMED_AXIS_KEY][i] = ( _add_named_axis(named_axis, depth + 1, ndim), - ndim + 1, + ndim + 1 if ndim is not None else ndim, ) if x.is_leaf: _export_named_axis_from_depth_to_lateral( diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py index 8217c30a8d..8a23ae0573 100644 --- a/src/awkward/_namedaxis.py +++ b/src/awkward/_namedaxis.py @@ -82,7 +82,7 @@ def _prettify(ax: AxisName) -> str: return delimiter.join(items) -def _get_named_axis(ctx: tp.Any) -> AxisMapping: +def _get_named_axis(ctx: tp.Any, allow_any: bool = False) -> AxisMapping: """ Retrieves the named axis from the provided context. @@ -103,9 +103,14 @@ def _get_named_axis(ctx: tp.Any) -> AxisMapping: >>> _get_named_axis({"other_key": "other_value"}) {} """ - if hasattr(ctx, "attrs"): - return _get_named_axis(ctx.attrs) - elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx: + from awkward._layout import HighLevelContext + from awkward.highlevel import Array, Record + + if hasattr(ctx, "attrs") and ( + isinstance(ctx, (HighLevelContext, Array, Record)) or allow_any + ): + return _get_named_axis(ctx.attrs, allow_any=True) + elif allow_any and isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx: return dict(ctx[NAMED_AXIS_KEY]) else: return {} diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index 3b6552c521..a6b1fcb360 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -270,7 +270,7 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): ) return ak.operations.ak_with_named_axis._impl( wrapped, - named_axis=_get_named_axis(attrs_of_obj(out) or {}), + named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True), highlevel=highlevel, behavior=None, attrs=None, diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py index 2c8e29adb1..882da992e2 100644 --- a/src/awkward/operations/ak_moment.py +++ b/src/awkward/operations/ak_moment.py @@ -168,7 +168,7 @@ def _impl( # propagate named axis to output return ak.operations.ak_with_named_axis._impl( wrapped, - named_axis=_get_named_axis(attrs_of_obj(out)), + named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True), highlevel=highlevel, behavior=None, attrs=None, diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py index 6d4beafbd5..460285382a 100644 --- a/src/awkward/operations/ak_ptp.py +++ b/src/awkward/operations/ak_ptp.py @@ -145,7 +145,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): # propagate named axis to output return ak.operations.ak_with_named_axis._impl( wrapped, - named_axis=_get_named_axis(attrs_of_obj(out)), + named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True), highlevel=highlevel, behavior=None, attrs=None, diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py index 7926b341fe..9343e0f1ce 100644 --- a/src/awkward/operations/ak_std.py +++ b/src/awkward/operations/ak_std.py @@ -233,7 +233,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a # propagate named axis to output return ak.operations.ak_with_named_axis._impl( wrapped, - named_axis=_get_named_axis(attrs_of_obj(out)), + named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True), highlevel=highlevel, behavior=None, attrs=None, diff --git a/src/awkward/operations/ak_var.py b/src/awkward/operations/ak_var.py index 759f5edf1c..61139181cb 100644 --- a/src/awkward/operations/ak_var.py +++ b/src/awkward/operations/ak_var.py @@ -285,7 +285,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a # propagate named axis to output return ak.operations.ak_with_named_axis._impl( wrapped, - named_axis=_get_named_axis(attrs_of_obj(out)), + named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True), highlevel=highlevel, behavior=None, attrs=None,