diff --git a/.gitignore b/.gitignore
index d28e4e3ea8..fd5b1bf8cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
studies/**/sample-*
+studies/named_axis.*
docs/demos/countries.geojson
docs/demos/test-program
docs/demos/test-program.cpp
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 2f4ff6663f..12c406ab15 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -4,12 +4,11 @@ title: "Awkward Array"
defaults:
titlesonly: True
-
subtrees:
- entries:
- file: getting-started/index
subtrees:
- - entries:
+ - entries:
- file: getting-started/what-is-an-awkward-array
- file: getting-started/10-minutes-to-awkward-array
- file: getting-started/uproot-awkward-columnar-hats
@@ -18,7 +17,7 @@ subtrees:
- file: getting-started/papers-and-talks
- file: user-guide/index
subtrees:
- - entries:
+ - entries:
- file: user-guide/how-to-convert
title: "Converting arrays"
subtrees:
@@ -74,6 +73,13 @@ subtrees:
- file: user-guide/how-to-examine-checking-validity
title: "Checking validity"
+ - file: user-guide/how-to-array-properties
+ title: "Array properties"
+ subtrees:
+ - entries:
+ - file: user-guide/how-to-array-properties-named-axis
+ title: "Named axes"
+
- file: user-guide/how-to-math
title: "Numerical math"
subtrees:
diff --git a/docs/conf.py b/docs/conf.py
index 37faac9e39..f6ea6f5e64 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -140,7 +140,7 @@
html_js_files = ["js/awkward.js"]
# MyST settings
-myst_enable_extensions = ["colon_fence"]
+myst_enable_extensions = ["colon_fence", "deflist"]
nb_execution_mode = "cache"
nb_execution_raise_on_error = True
diff --git a/docs/user-guide/how-to-array-properties-named-axis.md b/docs/user-guide/how-to-array-properties-named-axis.md
new file mode 100644
index 0000000000..9d5321b67f
--- /dev/null
+++ b/docs/user-guide/how-to-array-properties-named-axis.md
@@ -0,0 +1,304 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.14.1
+kernelspec:
+ display_name: Python 3 (ipykernel)
+ language: python
+ name: python3
+---
+
+Named axes
+==========
+
+Named axes are a feature in Awkward Array that allows you to give names to the axes of an array.
+This can be useful for documentation, debugging, and for writing code that is more robust to changes in the structure of the data.
+As argumented at [PyHEP.dev 2023](https://indico.cern.ch/event/1234156/) and by the Harvard NLP group in their ["Tensor Considered Harmful"](https://nlp.seas.harvard.edu/NamedTensor.html) write-up, named axes can be a powerful tool to make code more readable and less error-prone.
+
+Awkward array ensures that named axes are properly propagated to the result.
+All highlevel, indexing, and broadcasting operations in awkward array support named axes.
+
+Other libraries that support named axes include:
+- [hist](https://hist.readthedocs.io/en/latest/)
+- [haliax](https://github.com/stanford-crfm/haliax)
+- [Tensor Considered Harmful](https://nlp.seas.harvard.edu/NamedTensor.html)
+- [PyTorch Named Tensors](https://pytorch.org/docs/stable/name_inference.html#name-inference-reference-doc)
+- [Penzai Named Axis](https://penzai.readthedocs.io/en/stable/notebooks/named_axes.html)
+- [xarray Named Axis](https://docs.xarray.dev/en/stable/user-guide/indexing.html#)
+
+Named axes in Awkward Array are inspired primarily by `hist` and `PyTorch Named Tensors`.
+
++++
+
+How to (de-)attach named axes?
+-------------------------
+
+Named axes can be attached to an array using the high-level {func}`ak.with_named_axis` function.
+Awkward Array allows strings as named axes and integers as positional axes.
+
+The `named_axis` argument of {func}`ak.with_named_axis` accepts either a `tuple` or `dict`:
+- `tuple`:
+ - `named axis`: item
+ - `positional axis`: index of the item
+ - _additional_: `None` represents a wildcard for not specifying a name, e.g.: `("x", None)` means that the first axis is named "x" and the second is not named.
+- `dict`:
+ - `named axis`: key
+ - `positional axis`: value
+ - _additional_: not specifying a name is not allowed, e.g.: `{"x": 0}` means that the first axis is named "x", all other existing dimensions are unnamed. The `dict` option also allows for renaming negative axes, e.g.: `{"x": -1}` means that the last axis is named "x".
+
+
+```{code-cell}
+import awkward as ak
+import numpy as np
+```
+
+The axis names of an array can be attached through the constructor:
+```{code-cell}
+named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y"))
+# or
+named_array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1})
+```
+
+... or through `ak.with_named_axis`:
+```{code-cell}
+array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+# or
+named_array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1})
+```
+
+After attaching named axes, you can see the named axes comma-separated in the arrays representation and in `.show(named_axis=True)`:
+
+```{code-cell}
+ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y"))
+```
+
+```{code-cell}
+ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y")).show(named_axis=True)
+```
+
+Accessing the named axis mapping to positional axis can be done using the `named_axis` and `positional_axis` properties:
+
+```{code-cell}
+named_array.named_axis
+```
+
+```{code-cell}
+named_array.positional_axis
+```
+
+If you want to remove the named axes from an array, you can use the {func}`ak.without_named_axis` function:
+
+```{code-cell}
+array = ak.without_named_axis(named_array)
+array.named_axis
+```
+
+
+Indexing with Named Axes
+------------------------
+
+Named axes can be used for indexing operations.
+This is enabled throuhg a special syntax that allows you to index with a dictionary where keys refer to named (or positional) axes and the values to the slice or index.
+
+Simple examples:
+
+```{code-cell}
+array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y", "z"))
+
+# named axes
+named_array[{"x": 0}] # array[0, :, :]
+named_array[{"z": 0}] # array[:, :, 0]
+
+named_array[{"x": 0, "y": 0}] # array[0, 0, :]
+named_array[{"x": slice(0, 1), "y": 0}] # array[0:1, 0, :]
+
+named_array[named_array > 3] # array[array > 3]
+
+
+# positional axes
+named_array[{0: 0}] # array[0, :, :]
+named_array[{2: 0}] # array[:, :, 0]
+
+named_array[{-3: 0}] # array[0, :, :]
+named_array[{-1: 0}] # array[:, :, 0]
+None
+```
+
+If multiple keys that point to the same positional axis are used, the last key will be used and all others will be ignored:
+
+```{code-cell}
+array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y", "z"))
+
+assert ak.all(named_array[{0: 0, "x": slice(0, 2)}] == named_array[0:2])
+assert ak.all(named_array[{"x": slice(0, 2), 0: 0}] == named_array[0])
+```
+
+
+More detailed example:
+
+```{code-cell}
+# create a Record Array that represents four events with a variable number of jets
+events = ak.zip({
+ "event_no": np.arange(4),
+ "jetpt": ak.Array([[50, 60], [45], [], [80, 30, 50]]),
+})
+named_events = ak.with_named_axis(events, ("events", "jets"))
+
+print("classic indexing:", named_events[0, 0:1])
+print("named indexing :", named_events[{"events": 0, "jets": slice(0, 1)}])
+```
+
+For syntatic suger, use `np.s_` to define slices more easily:
+
+```{code-cell}
+array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y", "z"))
+
+assert ak.all(named_array[{"x": np.s_[0:2]}] == named_array[{"x": slice(0, 2)}])
+```
+
+Highlevel Operations with Named Axes
+------------------------------------
+
+Named axes can be used for specifying the axis of a highlevel operation given that the operation is performed on an array that supports this named axis.
+
+For example, the `ak.sum` operation can be performed on an array with named axes:
+
+```{code-cell}
+array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y", "z"))
+
+print("Sum over axis 'x':", ak.sum(named_array, axis="x")) # ak.sum(array, axis=0)
+print("Sum over axis 'y':", ak.sum(named_array, axis="y")) # ak.sum(array, axis=1)
+print("Sum over axis 'z':", ak.sum(named_array, axis="z")) # ak.sum(array, axis=2)
+```
+
+
+Named Axes Propagation Strategies
+---------------------------------
+
+
+Named axes are propagated through all operations in Awkward Array.
+For this, specific strategies are defined for each operation to ensure that the named axes are properly propagated to the result.
+
+The possible strategies are:
+- `keep all`: keep all named axes
+- `keep one`: keep one named axis
+- `keep up to`: keep all named axes up to a certain positional axis
+- `remove all`: remove all named axis
+- `remove one`: remove one named axis
+- `add one`: add a new axis
+- `unify`: unify named axes of two arrays. The named axes are unifiable if the have the same name (or `None`) and point to the same positional axis.
+
+Indexing operations
+: The following table shows the strategy for indexing operations:
+
+| Operation | Strategy |
+|----------------------|--------------|
+| `array[:]` | `keep all` |
+| `array[...]` | `keep all` |
+| `array[()]` | `keep all` |
+| `array[0:1]` | `keep all` |
+| `array[[0, 1]]` | `keep all` |
+| `array[array % 2]` | `keep all` |
+| `array[0]` | `remove one` |
+| `array[np.array(0)]` | `remove one` |
+| `array[None]` | `add one` |
+| `array[np.newaxis]` | `add one` |
+
+Universal functions (`ufuncs`)
+: `ufuncs` with single argument signatures (i.e. unary operations, such as `__abs__`, `__neg__`, `__invert__`, ...) do not modify named axes (strategy: `keep all`).
+: `ufuncs` with two argument signatures (i.e. binary operations, such as `__add__`, `__sub__`, `__mul__`, ...) try to merge named axis of the given arrays (strategy: `unify`).
+ This means that the named axes of the two arrays are merged if they have the same name (or either is `None`) and point to the same positional axis.
+ If there's a mismatch of named axes, e.g., the same named axis has different names or point to different positional axes, an exception is raised.
+
+```{code-cell}
+array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+# unary operations with named axes
+assert (-named_array).named_axis == {"x": 0, "y": 1}
+assert (+named_array).named_axis == {"x": 0, "y": 1}
+assert (~named_array).named_axis == {"x": 0, "y": 1}
+assert abs(named_array).named_axis == {"x": 0, "y": 1}
+
+# binary operations with named axes
+named_array1 = ak.with_named_axis(array, named_axis=(None, "y"))
+named_array2 = ak.with_named_axis(array, named_axis=("x", None))
+named_array3 = ak.with_named_axis(array, named_axis=("x", "y"))
+
+assert (array + array).named_axis == {}
+assert (named_array1 + array).named_axis == {"y": 1}
+assert (named_array2 + array).named_axis == {"x": 0}
+assert (named_array3 + array).named_axis == {"x": 0, "y": 1}
+
+assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1}
+assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1}
+```
+
+Reducers (`ak.sum`, `ak.any`, ...)
+: If `axis=int` and `keepdims=False` (typical use-case) removes the named axis that is reduced (strategy: `remove one`).
+: If `keepdims=True` is set, the named axis is kept (strategy: `keep all`).
+: If `axis=None` is set, all named axes are removed (strategy: `remove all`).
+
+```{code-cell}
+array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+named_array = ak.with_named_axis(array, ("x", "y"))
+
+assert ak.sum(named_array, axis="x", keepdims=False).named_axis == {"y": 0}
+assert ak.sum(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1}
+```
+
+---
+A full list of operations and their strategies can be found in the following table.
+If an operation is not listed, the strategy is either `keep all` or automatically inferred from the below listed operations.
+
+
+| Operation | Strategy |
+|-----------------------------------------------------|--------------------|
+| `ak.all(..., axis=None)` | `remove all` |
+| `ak.all(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.all(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.any(..., axis=None)` | `remove all` |
+| `ak.any(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.any(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.[arg]cartesian` | `unify` |
+| `ak.[arg]combinations` | `keep all` |
+| `ak.[arg]max(..., axis=None)` | `remove all` |
+| `ak.[arg]max(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.[arg]max(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.[arg]min(..., axis=None)` | `remove all` |
+| `ak.[arg]min(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.[arg]min(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.[arg]sort` | `keep all` |
+| `ak.broadcast_arrays` | `unify`, `add one` |
+| `ak.broadcast_fields` | `unify`, `add one` |
+| `ak.categories` | `remove all` |
+| `ak.concatenate` | `unify` |
+| `ak.count[_nonzero](..., axis=None)` | `remove all` |
+| `ak.count[_nonzero](..., axis=int, keepdims=False)` | `remove one` |
+| `ak.count[_nonzero](..., axis=int, keepdims=True)` | `keep all` |
+| `ak.firsts` | `remove one` |
+| `ak.flatten(..., axis=None)` | `remove all` |
+| `ak.flatten(..., axis=0)` | `keep all` |
+| `ak.flatten(..., axis=(!=0), keepdims=True)` | `remove one` |
+| `ak.local_index` | `keep up to` |
+| `ak.num` | `keep one` |
+| `ak.prod(..., axis=None)` | `remove all` |
+| `ak.prod(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.prod(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.ravel` | `remove all` |
+| `ak.singletons` | `add one` |
+| `ak.sum(..., axis=None)` | `remove all` |
+| `ak.sum(..., axis=int, keepdims=False)` | `remove one` |
+| `ak.sum(..., axis=int, keepdims=True)` | `keep all` |
+| `ak.unflatten` | `remove all` |
+| `ak.where` | `unify`, `add one` |
+| `ak.with_field` | `unify`, `add one` |
+| `ak.zip` | `unify`, `add one` |
diff --git a/docs/user-guide/how-to-array-properties.md b/docs/user-guide/how-to-array-properties.md
new file mode 100644
index 0000000000..be811e888e
--- /dev/null
+++ b/docs/user-guide/how-to-array-properties.md
@@ -0,0 +1,23 @@
+---
+jupytext:
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.10.3
+kernelspec:
+ display_name: Python 3
+ language: python
+ name: python3
+---
+
+Array properties
+================
+
+The user guide is a collection of "how to..." guides for common tasks. See the left side-bar (or bring it into view by clicking on the upper-left `≡`) to access the guides, grouped by topic.
+
+If you're looking for documentation on a specific function, see the API reference instead.
+
+You can test any examples in a new window/tab by clicking on [![Try It! ⭷](https://img.shields.io/badge/-Try%20It%21%20%E2%86%97-orange?style=for-the-badge)](https://awkward-array.org/doc/main/_static/try-it.html).
+
+
diff --git a/pyproject.toml b/pyproject.toml
index f5d1424359..f10e745671 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -232,6 +232,7 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
'awkward._nplikes.*',
+ 'awkward._namedaxis',
'awkward._behavior.*',
'awkward._backends.*',
'awkward._meta.*',
diff --git a/src/awkward/__init__.py b/src/awkward/__init__.py
index c82e83777f..c84b655ccb 100644
--- a/src/awkward/__init__.py
+++ b/src/awkward/__init__.py
@@ -24,6 +24,7 @@
import awkward._errors
import awkward._lookup
import awkward._ext # strictly for unpickling from Awkward 1
+import awkward._namedaxis
# third-party connectors
import awkward._connect.numpy
diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py
index 7eb2300372..8dab0af30e 100644
--- a/src/awkward/_broadcasting.py
+++ b/src/awkward/_broadcasting.py
@@ -11,6 +11,11 @@
import awkward as ak
from awkward._backends.backend import Backend
from awkward._backends.dispatch import backend_of
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ _add_named_axis,
+ _unify_named_axis,
+)
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
@@ -319,10 +324,18 @@ def is_string_like(obj) -> bool:
}
-def left_broadcast_to(content: Content, depth: int) -> Content:
- for _ in range(content.purelist_depth, depth):
- content = RegularArray(content, 1, content.length)
- return content
+def _export_named_axis_from_depth_to_lateral(
+ idx: int,
+ depth_context: dict[str, Any],
+ lateral_context: dict[str, Any],
+) -> None:
+ # set adjusted named axes to lateral (inplace)
+ named_axis, ndim = depth_context[NAMED_AXIS_KEY][idx]
+ seen_named_axis, _ = lateral_context[NAMED_AXIS_KEY][idx]
+ lateral_context[NAMED_AXIS_KEY][idx] = (
+ _unify_named_axis(named_axis, seen_named_axis),
+ ndim,
+ )
def broadcast_regular_dim_size(contents: Sequence[ak.contents.Content]) -> ShapeItem:
@@ -433,10 +446,32 @@ def apply_step(
max_depth = max(x.purelist_depth for x in contents)
if max_depth > 0 and all(x.purelist_isregular for x in contents):
- nextinputs = [
- left_broadcast_to(o, max_depth) if isinstance(o, Content) else o
- for o in inputs
- ]
+ nextinputs = []
+
+ named_axes_with_ndims = depth_context[NAMED_AXIS_KEY]
+ seen_named_axes = lateral_context[NAMED_AXIS_KEY]
+ for i, ((named_axis, ndim), o) in enumerate(
+ zip(named_axes_with_ndims, inputs)
+ ):
+ if isinstance(o, Content):
+ # rightbroadcast
+ for _ in range(o.purelist_depth, max_depth):
+ o = RegularArray(o, 1, o.length)
+ # track new dimensions for named axis
+ # rightbroadcasting adds a new first(!) dimension at depth
+ seen_named_axis, seen_ndim = seen_named_axes[i]
+ named_axis = _add_named_axis(named_axis, depth, ndim)
+ depth_context[NAMED_AXIS_KEY][i] = (
+ _unify_named_axis(named_axis, seen_named_axis),
+ ndim + 1,
+ )
+ if o.is_leaf:
+ _export_named_axis_from_depth_to_lateral(
+ i, depth_context, lateral_context
+ )
+ nextinputs.append(o)
+ else:
+ nextinputs.append(o)
# Did a broadcast take place?
if any(x is not y for x, y in zip(inputs, nextinputs)):
return apply_step(
@@ -538,6 +573,7 @@ def broadcast_any_list():
# Under the category of "is_list", we have both strings and non-strings
# The strings should behave like non-lists within these routines.
+ named_axes_with_ndims = depth_context[NAMED_AXIS_KEY]
# Are the non-string list types exclusively regular?
if all(x.is_regular or (is_string_like(x) or not x.is_list) for x in contents):
# Compute the expected dim size
@@ -586,7 +622,9 @@ def broadcast_any_list():
# we don't left-broadcast
nextinputs = []
nextparameters = []
- for x, x_is_string in zip(inputs, inputs_are_strings):
+ for i, ((named_axis, ndim), x, x_is_string) in enumerate(
+ zip(named_axes_with_ndims, inputs, inputs_are_strings)
+ ):
if isinstance(x, RegularArray) and not x_is_string:
content_size_maybe_one = (
x.size is not unknown_length and x.size == 1
@@ -603,6 +641,16 @@ def broadcast_any_list():
)
)
nextparameters.append(x._parameters)
+ # track new dimensions for named axis
+ # rightbroadcasting adds a new first(!) dimension as depth
+ depth_context[NAMED_AXIS_KEY][i] = (
+ _add_named_axis(named_axis, depth, ndim),
+ ndim + 1,
+ )
+ if x.is_leaf:
+ _export_named_axis_from_depth_to_lateral(
+ i, depth_context, lateral_context
+ )
# Any unknown values or sizes are assumed to be correct as-is
elif (
dim_size is unknown_length
@@ -667,7 +715,9 @@ def broadcast_any_list():
nextinputs = []
nextparameters = []
- for x, x_is_string in zip(inputs, input_is_string):
+ for i, ((named_axis, ndim), x, x_is_string) in enumerate(
+ zip(named_axes_with_ndims, inputs, input_is_string)
+ ):
if isinstance(x, listtypes) and not x_is_string:
next_content = broadcast_to_offsets_avoiding_carry(x, offsets)
nextinputs.append(next_content)
@@ -680,6 +730,16 @@ def broadcast_any_list():
.content
)
nextparameters.append(NO_PARAMETERS)
+ # track new dimensions for named axis
+ # leftbroadcasting adds a new last dimension at depth + 1
+ depth_context[NAMED_AXIS_KEY][i] = (
+ _add_named_axis(named_axis, depth + 1, ndim),
+ ndim + 1,
+ )
+ if x.is_leaf:
+ _export_named_axis_from_depth_to_lateral(
+ i, depth_context, lateral_context
+ )
else:
nextinputs.append(x)
nextparameters.append(NO_PARAMETERS)
@@ -889,7 +949,7 @@ def action_logical_or(inputs, backend, **kwargs):
(xy_mask, cond_mask),
action_logical_or,
0,
- None,
+ depth_context,
lateral_context,
simple_options,
)[0]
@@ -917,7 +977,7 @@ def apply_mask_action(inputs, backend, **kwargs):
(xy_unmasked, mask),
apply_mask_action,
0,
- None,
+ depth_context,
lateral_context,
simple_options,
)
diff --git a/src/awkward/_connect/numexpr.py b/src/awkward/_connect/numexpr.py
index 85ab566c8c..50c3e0485f 100644
--- a/src/awkward/_connect/numexpr.py
+++ b/src/awkward/_connect/numexpr.py
@@ -4,12 +4,15 @@
import sys
import warnings
+from functools import reduce
from packaging.version import parse as parse_version
import awkward as ak
-from awkward._behavior import behavior_of
+from awkward._attrs import attrs_of_obj
+from awkward._behavior import behavior_of, behavior_of_obj
from awkward._layout import wrap_layout
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
_has_checked_version = False
@@ -110,9 +113,26 @@ def action(inputs, **ignore):
return None
behavior = behavior_of(*arrays)
- out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments)
+ out = ak._broadcasting.broadcast_and_apply(
+ arrays,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ allow_records=False,
+ )
assert isinstance(out, tuple) and len(out) == 1
- return wrap_layout(out[0], behavior)
+ wrapped = wrap_layout(out[0], behavior)
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=True,
+ behavior=behavior_of_obj(wrapped),
+ attrs=attrs_of_obj(wrapped),
+ )
evaluate.evaluate = evaluate
@@ -148,6 +168,24 @@ def action(inputs, **ignore):
return None
behavior = behavior_of(*arrays)
- out = ak._broadcasting.broadcast_and_apply(arrays, action, allow_records=False)
+
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arguments)
+ out = ak._broadcasting.broadcast_and_apply(
+ arrays,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ allow_records=False,
+ )
assert isinstance(out, tuple) and len(out) == 1
- return wrap_layout(out[0], behavior)
+ wrapped = wrap_layout(out[0], behavior)
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=True,
+ behavior=behavior_of_obj(wrapped),
+ attrs=attrs_of_obj(wrapped),
+ )
diff --git a/src/awkward/_connect/numpy.py b/src/awkward/_connect/numpy.py
index f17ee98b36..7f5a7cdb08 100644
--- a/src/awkward/_connect/numpy.py
+++ b/src/awkward/_connect/numpy.py
@@ -22,6 +22,7 @@
)
from awkward._categorical import as_hashable
from awkward._layout import wrap_layout
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes import to_nplike
from awkward._parameters import parameters_intersect
from awkward._regularize import is_non_string_like_iterable
@@ -363,6 +364,8 @@ def array_ufunc(ufunc, method: str, inputs, kwargs: dict[str, Any]):
attrs = attrs_of(*inputs)
backend = backend_of(*inputs, coerce_to_common=True)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(inputs)
+
inputs = _array_ufunc_custom_cast(inputs, behavior, backend)
def action(inputs, **ignore):
@@ -464,13 +467,40 @@ def action(inputs, **ignore):
return None
out = ak._broadcasting.broadcast_and_apply(
- inputs, action, allow_records=False, function_name=ufunc.__name__
+ inputs,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ allow_records=False,
+ function_name=ufunc.__name__,
)
+ out_named_axis = functools.reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
if len(out) == 1:
- return wrap_layout(out[0], behavior=behavior, attrs=attrs)
+ wrapped = wrap_layout(out[0], behavior=behavior, attrs=attrs)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=True,
+ behavior=None,
+ attrs=None,
+ )
else:
- return tuple(wrap_layout(o, behavior=behavior, attrs=attrs) for o in out)
+ wrapped_out = []
+ for o in out:
+ wrapped = wrap_layout(o, behavior=behavior, attrs=attrs)
+ wrapped_out.append(
+ ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=True,
+ behavior=None,
+ attrs=None,
+ )
+ )
+ return tuple(wrapped_out)
def action_for_matmul(inputs):
diff --git a/src/awkward/_layout.py b/src/awkward/_layout.py
index 11cb4bcbe5..f642dd6ed4 100644
--- a/src/awkward/_layout.py
+++ b/src/awkward/_layout.py
@@ -56,9 +56,7 @@ def merge_mappings(
class HighLevelContext:
- def __init__(
- self, behavior: Mapping | None = None, attrs: Mapping[str, Any] | None = None
- ):
+ def __init__(self, behavior: Mapping | None = None, attrs: Mapping | None = None):
self._behavior = behavior
self._attrs = attrs
self._is_finalized = False
@@ -66,6 +64,22 @@ def __init__(
self._attrs_from_objects = []
self._behavior_from_objects = []
+ def with_attr(self, key, value) -> Self:
+ self._ensure_finalized()
+ return type(self)(
+ behavior=self.behavior,
+ attrs={**self.attrs, key: value},
+ ).finalize()
+
+ def without_attr(self, key) -> Self:
+ self._ensure_finalized()
+ attrs = dict(self.attrs)
+ attrs.pop(key, None)
+ return type(self)(
+ behavior=self.behavior,
+ attrs=attrs,
+ ).finalize()
+
def __enter__(self):
return self
@@ -81,8 +95,10 @@ def _ensure_not_finalized(self):
raise RuntimeError("HighLevelContext has already been finalized")
@property
- def attrs(self) -> Mapping[str, Any] | None:
+ def attrs(self) -> Mapping:
self._ensure_finalized()
+ if self._attrs is None:
+ self._attrs = {}
return self._attrs
@property
@@ -154,7 +170,11 @@ def unwrap(
)
def wrap(
- self, obj: Any, *, highlevel: bool = True, allow_other: bool = False
+ self,
+ obj: Any,
+ *,
+ highlevel: bool = True,
+ allow_other: bool = False,
) -> Any:
self._ensure_finalized()
@@ -230,7 +250,7 @@ def maybe_highlevel_to_lowlevel(obj):
Args:
obj: an object
- Calls #ak.to_layout and returns the result iff. the object is a high-level
+ Calls #ak.to_layout and returns the result if the object is a high-level
Awkward object, otherwise the object is returned as-is.
This function should be removed once scalars are properly handled by `to_layout`.
@@ -372,6 +392,34 @@ def attach(x):
return layout
+def _neg2pos_axis(
+ axis: int,
+ total: int,
+) -> int:
+ """
+ Converts a negative axis index to a positive one.
+
+ This function takes a negative axis index and the total number of axes and returns the corresponding positive axis index.
+ If the input axis index is already positive, it is returned as is.
+
+ Args:
+ axis (int): The axis index to convert. Can be negative.
+ total (int): The total number of axes.
+
+ Returns:
+ int: The positive axis index corresponding to the input axis index.
+
+ Examples:
+ >>> _neg2pos_axis(-1, 3)
+ 2
+ >>> _neg2pos_axis(1, 3)
+ 1
+ """
+ if axis < 0:
+ return total + axis
+ return axis
+
+
def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None:
from awkward.record import Record
@@ -386,6 +434,6 @@ def maybe_posaxis(layout: Content, axis: int, depth: int) -> int | None:
else:
is_branching, additional_depth = layout.branch_depth
if not is_branching:
- return axis + depth + additional_depth - 1
+ return _neg2pos_axis(axis, additional_depth) + depth - 1
else:
return None
diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py
new file mode 100644
index 0000000000..9f0b8d36cc
--- /dev/null
+++ b/src/awkward/_namedaxis.py
@@ -0,0 +1,764 @@
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+
+import awkward._typing as tp
+from awkward._layout import _neg2pos_axis
+from awkward._regularize import is_integer
+
+# axis names are hashables, mostly strings,
+# except for integers, which are reserved for positional axis.
+AxisName: tp.TypeAlias = tp.Hashable
+
+# e.g.: {"x": 0, "y": 1, "z": 2}
+AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int]
+
+# e.g.: ("x", "y", None) where None is a wildcard
+AxisTuple: tp.TypeAlias = tp.Tuple[AxisName, ...]
+
+
+NAMED_AXIS_KEY: tp.Literal["__named_axis__"] = (
+ "__named_axis__" # reserved for named axis
+)
+
+
+# just a class for inplace mutation
+class NamedAxis:
+ mapping: AxisMapping
+
+
+NamedAxis.mapping = {}
+
+
+def _prettify_named_axes(
+ named_axis: AxisMapping,
+ delimiter: str = ", ",
+ maxlen: None | int = None,
+) -> str:
+ """
+ This function takes a named axis mapping and returns a string representation of the mapping.
+ The axis names are sorted in ascending order of their corresponding integer values.
+ If the axis name is a valid Python identifier, it is represented as is.
+ Otherwise, it is represented as a JSON string.
+
+ Args:
+ named_axis (AxisMapping): The named axis mapping to prettify.
+ delimiter (str, optional): The delimiter to use between items in the output string. Defaults to ", ".
+ maxlen (None | int, optional): The maximum length of the output string. If the string exceeds this length, it is truncated and ends with "...". Defaults to None.
+
+ Returns:
+ str: The prettified string representation of the named axis mapping.
+
+ Examples:
+ >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2})
+ 'x:0, y:1, z:2'
+ >>> _prettify_named_axes({"x": 0, "y": 1, "$": 2})
+ 'x:0, y:1, "$":2'
+ >>> _prettify_named_axes({"x": 0, "y": 1, "z": 2}, delimiter="; ")
+ 'x:0; y:1; z:2'
+ >>> _prettify_named_axes({"foo": 0, "bar": 1, "baz": 2}, maxlen=17)
+ 'foo:0, bar:1, ...'
+ """
+
+ def _prettify(ax: AxisName) -> str:
+ repr_ax = str(ax)
+ if re.match("[A-Za-z_][A-Za-z_0-9]*", repr_ax):
+ return repr_ax
+ return json.dumps(repr_ax)
+
+ sorted_named_axis = sorted(named_axis.items(), key=lambda x: x[1])
+ items = [
+ f"{_prettify(named_ax)}:{pos_ax}" for named_ax, pos_ax in sorted_named_axis
+ ]
+ if maxlen is not None:
+ if len(delimiter.join(items)) > maxlen:
+ while (
+ len(delimiter.join(items)) > maxlen - len(delimiter + "...")
+ ) and items:
+ items.pop(-1)
+ items.append("...")
+ return delimiter.join(items)
+
+
+def _get_named_axis(ctx: tp.Any) -> AxisMapping:
+ """
+ Retrieves the named axis from the provided context.
+
+ Args:
+ ctx (Any): The context from which the named axis is to be retrieved.
+
+ Returns:
+ AxisMapping: The named axis retrieved from the context. If the context does not include a named axis,
+ an empty dictionary is returned.
+
+ Examples:
+ >>> _get_named_axis(ak.Array([1, 2, 3], named_axis={"x": 0}))
+ {"x": 0}
+ >>> _get_named_axis(np.array([1, 2, 3]))
+ {}
+ >>> _get_named_axis({NAMED_AXIS_KEY: {"x": 0, "y": 1, "z": 2}})
+ {"x": 0, "y": 1, "z": 2}
+ >>> _get_named_axis({"other_key": "other_value"})
+ {}
+ """
+ if hasattr(ctx, "attrs"):
+ return _get_named_axis(ctx.attrs)
+ elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
+ return dict(ctx[NAMED_AXIS_KEY])
+ else:
+ return {}
+
+
+def _make_positional_axis_tuple(n: int) -> tuple[int, ...]:
+ """
+ Generates a positional axis tuple of length n.
+
+ Args:
+ n (int): The length of the positional axis tuple to generate.
+
+ Returns:
+ tuple[int, ...]: The generated positional axis tuple.
+
+ Examples:
+ >>> _make_positional_axis_tuple(3)
+ (0, 1, 2)
+ """
+ return tuple(range(n))
+
+
+def _is_valid_named_axis(axis: AxisName) -> bool:
+ """
+ Checks if the given axis is a valid named axis. A valid named axis is a hashable object that is not an integer or None. Currently it is restricted to strings.
+
+ Args:
+ axis (AxisName): The axis to check.
+
+ Returns:
+ bool: True if the axis is a valid named axis, False otherwise.
+
+ Examples:
+ >>> _is_valid_named_axis("x")
+ True
+ >>> _is_valid_named_axis(1)
+ False
+ """
+ return (
+ # axis must be hashable
+ isinstance(axis, AxisName)
+ # ... but not an integer, otherwise we would confuse it with positional axis
+ and not is_integer(axis)
+ # we also prohibit None, which is reserved for wildcard
+ and axis is not None
+ # Let's only allow strings for now, in the future we can open up to more types
+ # by removing the isinstance(axis, str) check.
+ and isinstance(axis, str)
+ )
+
+
+def _check_valid_axis(axis: AxisName) -> AxisName:
+ """
+ Checks if the given axis is a valid named axis. If not, raises a ValueError.
+
+ Args:
+ axis (AxisName): The axis to check.
+
+ Returns:
+ AxisName: The axis if it is a valid named axis.
+
+ Raises:
+ ValueError: If the axis is not a valid named axis.
+
+ Examples:
+ >>> _check_valid_axis("x")
+ "x"
+ >>> _check_valid_axis(1)
+ Traceback (most recent call last):
+ ...
+ ValueError: Axis names must be hashable and not int, got 1 [type(axis)=]
+ """
+ if not _is_valid_named_axis(axis):
+ raise ValueError(
+ f"Axis names must be hashable and not int, got {axis!r} [{type(axis)=}]"
+ )
+ return axis
+
+
+def _check_valid_named_axis_mapping(named_axis: AxisMapping) -> AxisMapping:
+ """
+ Checks if the given named axis mapping is valid. A valid named axis mapping is a dictionary where the keys are valid named axes
+ (hashable objects that are not integers) and the values are integers.
+
+ Args:
+ named_axis (AxisMapping): The named axis mapping to check.
+
+ Raises:
+ ValueError: If any of the keys in the named axis mapping is not a valid named axis or if any of the values is not an integer.
+
+ Examples:
+ >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": 2}) # No exception is raised
+ >>> _check_valid_named_axis_mapping({"x": 0, "y": 1, "z": "2"})
+ Traceback (most recent call last):
+ ...
+ ValueError: Named axis mapping values must be integers, got '2' [type(axis)=]
+ >>> _check_valid_named_axis_mapping({"x": 0, 1: 1, "z": 2})
+ Traceback (most recent call last):
+ ...
+ ValueError: Axis names must be hashable and not int, got 1 [type(axis)=]
+ """
+ for name, axis in named_axis.items():
+ _check_valid_axis(name)
+ if not is_integer(axis):
+ raise ValueError(
+ f"Named axis mapping values must be integers, got {axis!r} [{type(axis)=}]"
+ )
+ return named_axis
+
+
+def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping:
+ """
+ Converts a tuple of axis names to a dictionary mapping axis names to their positions.
+
+ Args:
+ axis_tuple (AxisTuple): A tuple of axis names. Can include None as a wildcard.
+
+ Returns:
+ AxisMapping: A dictionary mapping axis names to their positions.
+
+ Examples:
+ >>> _axis_tuple_to_mapping(("x", None, "y"))
+ {"x": 0, "y": 2}
+ """
+ return {axis: i for i, axis in enumerate(axis_tuple) if axis is not None}
+
+
+def _prepare_named_axis_for_attrs(
+ named_axis: AxisMapping | AxisTuple,
+ ndim: int,
+) -> AxisMapping:
+ """
+ Prepares the named axis for attribute assignment.
+
+ This function takes a named axis, which can either be a mapping or a tuple, and returns a dictionary mapping axis names to their positions.
+ The function checks if the named axis is valid and if the positional axes match the number of dimensions. If not, an error is raised.
+
+ Args:
+ named_axis (AxisMapping | AxisTuple): The named axis to prepare. Can either be a mapping or a tuple.
+ ndim (int): The number of dimensions.
+
+ Returns:
+ AxisMapping: The prepared named axis.
+
+ Raises:
+ TypeError: If the named axis is not a mapping or a tuple.
+ ValueError: If the named axes do not point to positional axes matching the number of dimensions.
+
+ Examples:
+ >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 3)
+ {"x": 0, "y": 1, "z": 2}
+ >>> _prepare_named_axis_for_attrs(("x", "y", "z"), 3)
+ {"x": 0, "y": 1, "z": 2}
+ >>> _prepare_named_axis_for_attrs({"x": 0, "y": 1, "z": 2}, 2)
+ Traceback (most recent call last):
+ ...
+ ValueError: Named axes must point to positional axes matching 2 dimensions, got named_axis={"x": 0, "y": 1, "z": 2}, ndim=2
+ """
+ if isinstance(named_axis, tuple):
+ _named_axis = _axis_tuple_to_mapping(named_axis)
+ elif isinstance(named_axis, dict):
+ _named_axis = named_axis
+ else:
+ raise TypeError(
+ f"named_axis must be a mapping or a tuple, got {named_axis=} [{type(named_axis)=}]"
+ )
+ _check_valid_named_axis_mapping(_named_axis)
+ pos_axes = set(_named_axis.values())
+ if max(pos_axes, default=0) >= ndim or min(pos_axes, default=0) < -ndim:
+ raise ValueError(
+ f"Named axes must point to positional axes matching {ndim} dimensions, got {named_axis=}, {ndim=}"
+ )
+ return _named_axis
+
+
+def _make_named_int_class(name: tp.Any) -> type[int]:
+ class NamedInt(int):
+ def __repr__(self):
+ value_repr = super().__repr__()
+ return f"{name!r} (named axis) -> {value_repr} (pos. axis)"
+
+ __str__ = __repr__
+
+ return NamedInt
+
+
+def _named_axis_to_positional_axis(
+ named_axis: AxisMapping,
+ axis: AxisName,
+) -> int | None:
+ """
+ Converts a single named axis to a positional axis.
+
+ Args:
+ axis (AxisName): The named axis to convert.
+ named_axis (AxisMapping): The mapping from named axes to positional axes.
+
+ Returns:
+ int | None: The positional axis corresponding to the given named axis. If the named axis is not found in the mapping, returns None.
+
+ Raises:
+ ValueError: If the named axis is not found in the named axis mapping.
+
+ Examples:
+ >>> _named_axis_to_positional_axis({"x": 0, "y": 1, "z": 2}, "x")
+ 0
+ """
+ if _is_valid_named_axis(axis):
+ if axis not in named_axis:
+ raise ValueError(f"{axis=} not found in {named_axis=} mapping.")
+
+ # we wrap it to preserve the original name in its __repr__ and __str__
+ # in order to properly display it in error messages. This is useful for cases
+ # where the positional axis is pointing to a non-existing axis. The error message
+ # will then show the original (named) axis together with the positional axis.
+ cls = _make_named_int_class(axis)
+ return cls(named_axis[axis])
+
+ if is_integer(axis):
+ # TODO: is_integer is an external helper function that doesn't specify types
+ return int(tp.cast(tp.Any, axis))
+ elif axis is None:
+ return None
+ else:
+ raise ValueError(f"Invalid {axis=} [{type(axis)=}]")
+
+
+# These are the strategies to handle named axis for the
+# output array when performing operations along an axis.
+# See studies/named_axis.md#named-axis-in-high-level-functions and
+# https://pytorch.org/docs/stable/name_inference.html.
+#
+# The possible strategies are:
+# - "keep all" (_keep_named_axis(..., None)): Keep all named axes in the output array, e.g.: `ak.drop_none`
+# - "keep one" (_keep_named_axis(..., int)): Keep one named axes in the output array, e.g.: `ak.firsts`
+# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes up to a certain positional axis in the output array, e.g.: `ak.local_index`
+# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories`
+# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum`
+# - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate`
+# - "unify" (_unify_named_axis): Unify the named axis in the output array given two input arrays, e.g.: `ak.broadcast_arrays`
+
+
+def _keep_named_axis(
+ named_axis: AxisMapping,
+ axis: int | None = None,
+) -> AxisMapping:
+ """
+ Determines the new named axis after keeping the specified axis. This function is useful when an operation
+ is applied that retains only one axis.
+
+ Args:
+ named_axis (AxisMapping): The current named axis.
+ axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None.
+
+ Returns:
+ AxisMapping: The new named axis after keeping the specified axis.
+
+ Examples:
+ >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, 1)
+ {"y": 0}
+ >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, None)
+ {"x": 0, "y": 1, "z": 2}
+ """
+ if axis is None:
+ return named_axis
+ return {k: 0 for k, v in named_axis.items() if v == axis}
+
+
+def _keep_named_axis_up_to(
+ named_axis: AxisMapping,
+ axis: int,
+ total: int,
+) -> AxisMapping:
+ """
+ Determines the new named axis after keeping all axes up to the specified axis. This function is useful when an operation
+ is applied that retains all axes up to a certain axis.
+
+ Args:
+ named_axis (AxisMapping): The current named axis.
+ axis (int): The index of the axis up to which to keep.
+ total (int): The total number of axes.
+
+ Returns:
+ AxisMapping: The new named axis after keeping all axes up to the specified axis.
+
+ Examples:
+ >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 1, 3)
+ {"x": 0, "y": 1}
+ >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, -1, 3)
+ {"x": 0, "y": 1, "z": 2}
+ >>> _keep_named_axis_up_to({"x": 0, "y": 1, "z": 2}, 0, 3)
+ {"x": 0}
+ """
+ axis = _neg2pos_axis(axis, total)
+ out = {}
+ for k, v in named_axis.items():
+ if v >= 0 and v <= axis:
+ out[k] = v
+ elif v < 0 and v >= -axis - 1:
+ out[k] = v
+ return out
+
+
+def _remove_all_named_axis(
+ named_axis: AxisMapping,
+) -> AxisMapping:
+ """
+ Returns an empty named axis mapping after removing all axes from the given named axis mapping.
+ This function is typically used when an operation that eliminates all axes is applied.
+
+ Args:
+ named_axis (AxisMapping): The current named axis mapping.
+
+ Returns:
+ AxisMapping: An empty named axis mapping.
+
+ Examples:
+ >>> _remove_all_named_axis({"x": 0, "y": 1, "z": 2})
+ {}
+ """
+ return _remove_named_axis(named_axis=named_axis, axis=None)
+
+
+def _remove_named_axis(
+ named_axis: AxisMapping,
+ axis: int | None = None,
+ total: int | None = None,
+) -> AxisMapping:
+ """
+ Determines the new named axis after removing the specified axis. This is useful, for example,
+ when applying an operation that removes one axis.
+
+ Args:
+ named_axis (AxisMapping): The current named axis.
+ axis (int | None, optional): The index of the axis to remove. If None, no axes are removed. Default is None.
+ total (int | None, optional): The total number of axes. If None, it is calculated as the length of the named axis. Default is None.
+
+ Returns:
+ AxisMapping: The new named axis after removing the specified axis.
+
+ Examples:
+ >>> _remove_named_axis({"x": 0, "y": 1}, None)
+ {}
+ >>> _remove_named_axis({"x": 0, "y": 1}, 0)
+ {"y": 0}
+ >>> _remove_named_axis({"x": 0, "y": 1, "z": 2}, 1)
+ {"x": 0, "z": 1}
+ >>> _remove_named_axis({"x": 0, "y": 1, "z": -1}, 1)
+ {"x": 0, "z": -1}
+ >>> _remove_named_axis({"x": 0, "y": 1, "z": -3}, 1)
+ {"x": 0, "z": -2}
+ """
+ if axis is None:
+ return {}
+
+ if total is None:
+ total = len(named_axis)
+
+ # remove the specified axis
+ out = {
+ ax: pos
+ for ax, pos in named_axis.items()
+ if _neg2pos_axis(pos, total) != _neg2pos_axis(axis, total)
+ }
+
+ return _adjust_pos_axis(out, axis, total, direction=-1)
+
+
+def _adjust_pos_axis(
+ named_axis: AxisMapping,
+ axis: int,
+ total: int,
+ direction: int,
+) -> AxisMapping:
+ """
+ Adjusts the positions of the axes in the named axis mapping after an axis has been removed or added.
+
+ Args:
+ named_axis (AxisMapping): The current named axis mapping.
+ axis (int): The position of the removed/added axis.
+ total (int): The total number of axes.
+ direction (int): The direction of the adjustment. -1 means axis is removed; +1 means axis is added. Default is +1.
+
+ Returns:
+ AxisMapping: The adjusted named axis mapping.
+
+ Examples:
+ # axis=1 removed
+ >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, -1)
+ {"x": 0, "z": 1}
+ # axis=1 added
+ >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3, +1)
+ {"x": 0, "z": 3}
+ # axis=1 removed
+ >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, -1)
+ {"x": 0, "z": -1}
+ # axis=1 added
+ >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3, +1)
+ {"x": 0, "z": -1}
+ """
+ assert direction in (-1, +1), f"Invalid direction: {direction}"
+
+ def _adjust(pos: int, axis: int, direction: int) -> int:
+ # positive axis
+ if axis >= 0:
+ # positive axis and position greater than or equal to the removed/added (positive) axis
+ # -> change position by direction
+ if pos >= axis:
+ return pos + direction
+ # positive axis and negative position
+ # -> change position by direction
+ elif pos < 0 and pos + total < axis:
+ return pos - direction
+ # positive axis and position smaller than the removed/added (positive) axis, but greater than 0
+ # -> keep position
+ else:
+ return pos
+ # negative axis
+ else:
+ # negative axis and position smaller than the removed/added (negative) axis
+ # -> change position by inverse direction
+ if pos <= axis:
+ return pos - direction
+ # negative axis and positive position
+ # -> change position by inverse direction
+ elif pos > 0 and pos > axis + total:
+ return pos + direction
+ # negative axis and position greater than the removed/added (negative) axis, but smaller than 0
+ # -> keep position
+ else:
+ return pos
+
+ return {k: _adjust(v, axis, direction) for k, v in named_axis.items()}
+
+
+def _add_named_axis(
+ named_axis: AxisMapping,
+ axis: int,
+ total: int | None = None,
+) -> AxisMapping:
+ """
+ Adds a new axis to the named_axis at the specified position.
+
+ Args:
+ named_axis (AxisMapping): The current named axis mapping.
+ axis (int): The position at which to add the new axis.
+ total (int | None): The total number of axes.
+
+ Returns:
+ AxisMapping: The updated named axis mapping after adding the new axis.
+
+ Examples:
+ >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 0)
+ {"x": 1, "y": 2, "z": 3}
+ >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 1)
+ {"x": 0, "y": 2, "z": 3}
+ """
+ if total is None:
+ total = len(named_axis)
+
+ return _adjust_pos_axis(named_axis, axis, total, direction=+1)
+
+
+def _unify_named_axis(
+ named_axis1: AxisMapping,
+ named_axis2: AxisMapping,
+) -> AxisMapping:
+ """
+ Unifies two named axes into a single named axis. The function iterates over all positional axes present in either of the input named axes.
+ For each positional axis, it checks the corresponding axis names in both input named axes. If the axis names are the same or if one of them is None,
+ the unified axis will be the non-None axis. If the axis names are different and neither of them is None, a ValueError is raised.
+
+ Args:
+ named_axis1 (AxisMapping): The first named axis to unify.
+ named_axis2 (AxisMapping): The second named axis to unify.
+
+ Returns:
+ AxisMapping: The unified named axis.
+
+ Raises:
+ ValueError: If the axes are different and neither of them is None.
+
+ Examples:
+ >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 2})
+ {"x": 0, "y": 1, "z": 2}
+
+ >>> _unify_named_axis({"x": 0, "y": 1}, {"x": 0, "y": 1, "z": 2})
+ {"x": 0, "y": 1, "z": 2}
+
+ >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"a": 0, "b": 1, "c": 2})
+ Traceback (most recent call last):
+ ...
+ ValueError: The named axes are different. Got: 'x' and 'a' for positional axis 0
+
+ >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 3})
+ {"x": 0, "y": 1, "z": 2}
+
+ >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {})
+ {"x": 0, "y": 1, "z": 2}
+
+ >>> _unify_named_axis({}, {"x": 0, "y": 1, "z": 2})
+ {"x": 0, "y": 1, "z": 2}
+
+ >>> _unify_named_axis({}, {})
+ {}
+ """
+
+ def _get_axis_name(
+ axis_mapping: AxisMapping, positional_axis: int
+ ) -> AxisName | None:
+ for name, position in axis_mapping.items():
+ if position == positional_axis:
+ return name
+ return None
+
+ unified_named_axis = {}
+ all_positional_axes = set(named_axis1.values()) | set(named_axis2.values())
+ for position in all_positional_axes:
+ axis_name1 = _get_axis_name(named_axis1, position)
+ axis_name2 = _get_axis_name(named_axis2, position)
+ if axis_name1 is not None and axis_name2 is not None:
+ if axis_name1 != axis_name2:
+ raise ValueError(
+ f"The named axes are incompatible. Got: {axis_name1} and {axis_name2} for positional axis {position}"
+ )
+ unified_named_axis[axis_name1] = position
+ elif axis_name1 is not None: # axis_name2 is None
+ unified_named_axis[axis_name1] = position
+ elif axis_name2 is not None: # axis_name1 is None
+ unified_named_axis[axis_name2] = position
+ return unified_named_axis
+
+
+@dataclass
+class NamedAxesWithDims:
+ """
+ A dataclass that stores the named axis and their corresponding dimensions.
+ This is a helper class to store the named axis mapping and the number of
+ dimensions of each named axis, which is useful for broadcasting.
+
+ Attributes:
+ named_axis (AxisMapping): The named axis mapping.
+ ndims (Tuple[int]): The number of dimensions of the named axis.
+ """
+
+ named_axis: list[AxisMapping]
+ ndims: list[int]
+
+ def __post_init__(self):
+ if len(self.named_axis) != len(self.ndims):
+ raise ValueError(
+ "The number of dimensions must match the number of named axis mappings."
+ )
+
+ def __iter__(self) -> tp.Iterator[tuple[AxisMapping, int]]:
+ yield from zip(self.named_axis, self.ndims)
+
+ @classmethod
+ def prepare_contexts(
+ cls, arrays: tp.Sequence, unwrap_kwargs: dict | None = None
+ ) -> tuple[dict, dict]:
+ from awkward._layout import HighLevelContext
+ from awkward._typetracer import MaybeNone
+
+ # unwrap options
+ arrays = [x.content if isinstance(x, MaybeNone) else x for x in arrays]
+
+ _unwrap_kwargs = {"allow_unknown": True}
+ if unwrap_kwargs is not None:
+ _unwrap_kwargs.update(unwrap_kwargs)
+
+ _named_axes = []
+ _ndims = []
+ for array in arrays:
+ with HighLevelContext() as ctx:
+ layout = ctx.unwrap(array, **_unwrap_kwargs)
+ _named_axes.append(_get_named_axis(array))
+ _ndims.append(layout.minmax_depth[1])
+
+ depth_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)}
+ lateral_context = {NAMED_AXIS_KEY: cls(_named_axes, _ndims)}
+ return depth_context, lateral_context
+
+ def __setitem__(self, index: int, named_axis_with_ndim: tuple[AxisMapping, int]):
+ named_axis, ndim = named_axis_with_ndim
+ self.named_axis[index] = named_axis
+ self.ndims[index] = ndim
+
+ def __getitem__(self, index: int) -> tuple[AxisMapping, int]:
+ return self.named_axis[index], self.ndims[index]
+
+ def __len__(self) -> int:
+ return len(self.named_axis)
+
+
+# Define a type alias for a slice or int (can be a single axis or a sequence of axes)
+AxisSlice: tp.TypeAlias = tp.Union[tuple, slice, int, tp.EllipsisType, None]
+NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice]
+
+
+def _normalize_named_slice(
+ named_axis: AxisMapping,
+ where: AxisSlice | NamedAxisSlice,
+ total: int,
+) -> AxisSlice:
+ """
+ Normalizes a named slice into a positional slice.
+
+ This function takes a named slice (a dictionary mapping axis names to slices) and converts it into a positional slice
+ (a tuple of slices). The positional slice can then be used to index an array.
+
+ Args:
+ named_axis (AxisMapping): The current named axis mapping.
+ where (AxisSlice | NamedAxisSlice): The slice to normalize. Can be a single slice, a tuple of slices, or a dictionary mapping axis names to slices.
+ total (int): The total number of axes.
+
+ Returns:
+ AxisSlice: The normalized slice.
+
+ Raises:
+ ValueError: If an invalid axis name is provided in the slice.
+
+ Examples:
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {0: 0}, 3)
+ (0, slice(None), slice(None))
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {-1: 0}, 3)
+ (slice(None), slice(None), 0)
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0}, 3)
+ (0, slice(None), slice(None))
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1}, 3)
+ (0, 1, slice(None))
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": ...}, 3)
+ (0, 1, ...)
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": slice(0, 1)}, 3)
+ (0, 1, slice(0, 1))
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": (0, 1)}, 3)
+ ((0, 1), slice(None), slice(None))
+ >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": [0, 1]}, 3)
+ ([0, 1], slice(None), slice(None))
+ """
+ if isinstance(where, dict):
+ out_where: list[AxisSlice] = [slice(None)] * total
+ for ax_name, ax_where in where.items():
+ slice_ = ax_where if ax_where is not ... else slice(None)
+ if is_integer(ax_name):
+ # it's an integer, pyright doesn't get this
+ idx = tp.cast(int, ax_name)
+ out_where[idx] = slice_
+ elif _is_valid_named_axis(ax_name):
+ # it's an integer, pyright doesn't get this
+ idx = tp.cast(int, _named_axis_to_positional_axis(named_axis, ax_name))
+ out_where[idx] = slice_
+ else:
+ raise ValueError(f"Invalid axis name: {ax_name} in slice {where}")
+ where = tuple(out_where)
+ return where
diff --git a/src/awkward/_nplikes/array_like.py b/src/awkward/_nplikes/array_like.py
index d82611ae5d..d75fe6cfcb 100644
--- a/src/awkward/_nplikes/array_like.py
+++ b/src/awkward/_nplikes/array_like.py
@@ -8,6 +8,7 @@
from awkward._typing import (
TYPE_CHECKING,
DType,
+ EllipsisType,
Protocol,
Self,
SupportsIndex,
@@ -15,8 +16,6 @@
)
if TYPE_CHECKING:
- from types import EllipsisType
-
from numpy.typing import DTypeLike
diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py
index 64b99abf01..548071c0ee 100644
--- a/src/awkward/_nplikes/typetracer.py
+++ b/src/awkward/_nplikes/typetracer.py
@@ -26,6 +26,7 @@
TYPE_CHECKING,
Any,
DType,
+ EllipsisType,
Final,
Literal,
Self,
@@ -36,8 +37,6 @@
)
if TYPE_CHECKING:
- from types import EllipsisType
-
from numpy.typing import DTypeLike
from awkward.contents.content import Content
diff --git a/src/awkward/_operators.py b/src/awkward/_operators.py
index 2c58330492..f5a2cf90da 100644
--- a/src/awkward/_operators.py
+++ b/src/awkward/_operators.py
@@ -50,6 +50,7 @@ def _binary_method(ufunc, name):
def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
+
return ufunc(self, other)
func.__name__ = f"__{name}__"
diff --git a/src/awkward/_regularize.py b/src/awkward/_regularize.py
index 663d9eb01a..6f78a18409 100644
--- a/src/awkward/_regularize.py
+++ b/src/awkward/_regularize.py
@@ -7,7 +7,7 @@
from collections.abc import Iterable, Sequence, Sized
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._typing import AxisMaybeNone, SupportsInt
+from awkward._typing import Any
np = NumpyMetadata.instance()
@@ -51,8 +51,19 @@ def is_non_string_like_sequence(obj) -> bool:
return not isinstance(obj, (str, bytes)) and isinstance(obj, Sequence)
-def regularize_axis(axis: SupportsInt | None) -> AxisMaybeNone:
- if axis is None:
- return None
+def regularize_axis(axis: Any, none_allowed: bool = True) -> int | None:
+ """
+ This function's main purpose is to convert [np,cp,...].array(0) to 0.
+ """
+ if is_integer_like(axis):
+ regularized_axis = int(axis)
else:
- return int(axis)
+ regularized_axis = axis
+ cond = is_integer(regularized_axis)
+ msg = f"'axis' must be an integer, not {axis!r}"
+ if none_allowed:
+ cond = cond or regularized_axis is None
+ msg = f"'axis' must be an integer or None, not {axis!r}"
+ if not cond:
+ raise TypeError(msg)
+ return regularized_axis
diff --git a/src/awkward/_typing.py b/src/awkward/_typing.py
index 0e987b4399..be474a37d9 100644
--- a/src/awkward/_typing.py
+++ b/src/awkward/_typing.py
@@ -26,6 +26,7 @@
"Literal",
"SupportsIndex",
"ParamSpec",
+ "EllipsisType",
*typing.__all__,
}
)
@@ -46,7 +47,10 @@
TypeGuard,
Unpack,
)
+
+ EllipsisType = type(...)
else:
+ from types import EllipsisType
from typing import (
ClassVar,
Final,
diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py
index d0169ee2eb..f324d9dac5 100644
--- a/src/awkward/contents/content.py
+++ b/src/awkward/contents/content.py
@@ -16,8 +16,14 @@
)
from awkward._behavior import get_array_class, get_record_class
from awkward._kernels import KernelError
-from awkward._layout import wrap_layout
+from awkward._layout import maybe_posaxis, wrap_layout
from awkward._meta.meta import Meta
+from awkward._namedaxis import (
+ NamedAxis,
+ _add_named_axis,
+ _keep_named_axis,
+ _remove_named_axis,
+)
from awkward._nplikes import to_nplike
from awkward._nplikes.dispatch import nplike_of_obj
from awkward._nplikes.numpy import Numpy
@@ -27,7 +33,12 @@
parameters_are_equal,
type_parameters_equal,
)
-from awkward._regularize import is_integer_like, is_sized_iterable
+from awkward._regularize import (
+ is_array_like,
+ is_integer,
+ is_integer_like,
+ is_sized_iterable,
+)
from awkward._slicing import normalize_slice
from awkward._typing import (
TYPE_CHECKING,
@@ -38,6 +49,7 @@
Protocol,
Self,
SupportsIndex,
+ Type,
TypeAlias,
TypedDict,
)
@@ -509,10 +521,14 @@ def _getitem_next_missing(
)
def __getitem__(self, where):
- return self._getitem(where)
+ return self._getitem(where, NamedAxis)
- def _getitem(self, where):
+ def _getitem(self, where, named_axis: Type[NamedAxis] = NamedAxis):
if is_integer_like(where):
+ # propagate named_axis to output
+ named_axis.mapping = _remove_named_axis(
+ named_axis.mapping, where, self.purelist_depth
+ )
return self._getitem_at(ak._slicing.normalize_integer_like(where))
elif isinstance(where, slice) and where.step is None:
@@ -523,21 +539,35 @@ def _getitem(self, where):
return self._getitem_range(start, stop)
elif isinstance(where, slice):
- return self._getitem((where,))
+ return self._getitem((where,), named_axis)
elif isinstance(where, str):
return self._getitem_field(where)
elif where is np.newaxis:
- return self._getitem((where,))
+ return self._getitem((where,), named_axis)
elif where is Ellipsis:
- return self._getitem((where,))
+ return self._getitem((where,), named_axis)
elif isinstance(where, tuple):
if len(where) == 0:
return self
+ # count number of ellipsis
+ # Need to use a little trick here:
+ # where.count(Ellipsis) does not work, because it will do a == comparison against Ellipsis,
+ # and this will fail in the case of typetracers where == is dispatched to np.equal ufunc.
+ # In this dispatch we encounter an assertion that the type of the Ellipsis is not allowed.
+ # ...but luckily we can use the fact that Ellipsis is a singleton and use the 'is' operator
+ n_ellipsis = 0
+ for w in where:
+ if w is ...:
+ n_ellipsis += 1
+
+ if n_ellipsis > 1:
+ raise IndexError("an index can only have a single ellipsis ('...')")
+
# Backend may change if index contains typetracers
backend = backend_of(self, *where, coerce_to_common=True)
this = self.to_backend(backend)
@@ -547,6 +577,62 @@ def _getitem(self, where):
# Prepare items for advanced indexing (e.g. via broadcasting)
nextwhere = ak._slicing.prepare_advanced_indexing(items, backend)
+ # Handle named axis
+ # first expand the ellipsis to colons in nextwhere,
+ # copy nextwhere to not pollute the original
+ _nextwhere = tuple(nextwhere)
+ if n_ellipsis == 1:
+ # collect the ellipsis index
+ # same little trick as above for `nextwhere.index(...)`
+ (ellipsis_at,) = tuple(i for i, x in enumerate(nextwhere) if x is ...)
+ # calculate how many slice(None) we need to add
+ # same little trick as above for `nextwhere.count(None)`
+ n_newaxis = 0
+ for x in nextwhere:
+ if x is np.newaxis or x is None:
+ n_newaxis += 1
+ n_total = self.minmax_depth[1]
+ n_slice_none = n_total - (len(_nextwhere) - n_newaxis - 1)
+ # expand `[...]` to `[:]*n_slice_none`
+ _nextwhere = (
+ _nextwhere[:ellipsis_at]
+ + (slice(None),) * n_slice_none
+ + _nextwhere[ellipsis_at + 1 :]
+ )
+
+ # now propagate named axis
+ _named_axis = _keep_named_axis(named_axis.mapping, None)
+ _adjust_dim = 0
+ # this loop does the following:
+ # - remove a named axis for integer indices, e.g. `a[1, 2]`
+ # - add a named axis for None (or np.newaxis) indices, e.g. `a[..., None]`
+ # - keep named axis for any other index, e.g. `a[:]`, `a[0:1]`, or `a[a>0]`
+ # (these may only remove elements, but not dimensions)
+ for dim, nw in enumerate(_nextwhere):
+ dim_adjusted = dim + _adjust_dim
+ total_adjusted = self.minmax_depth[1] + _adjust_dim
+ for _, pos in _named_axis.items():
+ if maybe_posaxis(self, pos, 0) == dim_adjusted:
+ break
+
+ if is_integer(nw) or (is_array_like(nw) and nw.ndim == 0):
+ _named_axis = _remove_named_axis(
+ named_axis=_named_axis,
+ axis=dim_adjusted,
+ total=total_adjusted,
+ )
+ _adjust_dim -= 1
+ elif nw is None:
+ _named_axis = _add_named_axis(
+ named_axis=_named_axis,
+ axis=dim_adjusted,
+ total=total_adjusted,
+ )
+ _adjust_dim += 1
+
+ # set propagated named axis
+ named_axis.mapping = _named_axis
+
next = ak.contents.RegularArray(
this,
this.length,
@@ -562,7 +648,7 @@ def _getitem(self, where):
return out._getitem_at(0)
elif isinstance(where, ak.highlevel.Array):
- return self._getitem(where.layout)
+ return self._getitem(where.layout, named_axis)
# Convert between nplikes of different backends
elif (
@@ -570,7 +656,9 @@ def _getitem(self, where):
and where.backend is not self._backend
):
backend = backend_of(self, where, coerce_to_common=True)
- return self.to_backend(backend)._getitem(where.to_backend(backend))
+ return self.to_backend(backend)._getitem(
+ where.to_backend(backend), named_axis
+ )
elif isinstance(where, ak.contents.NumpyArray):
data_as_index = to_nplike(
@@ -602,7 +690,7 @@ def _getitem(self, where):
allow_lazy = "copied" # True, but also can be modified in-place
else:
wheres = self._backend.index_nplike.nonzero(data_as_index)
- return self._getitem(wheres)
+ return self._getitem(wheres, named_axis)
else:
raise TypeError(
"array slice must be an array of integers or booleans, not\n\n {}".format(
@@ -621,9 +709,9 @@ def _getitem(self, where):
elif isinstance(where, ak.contents.RegularArray):
maybe_numpy = where.maybe_to_NumpyArray()
if maybe_numpy is None:
- return self._getitem((where,))
+ return self._getitem((where,), named_axis)
else:
- return self._getitem(maybe_numpy)
+ return self._getitem(maybe_numpy, named_axis)
# Awkward Array of strings
elif (
@@ -637,7 +725,7 @@ def _getitem(self, where):
return where.to_NumpyArray(np.int64)
elif isinstance(where, Content):
- return self._getitem((where,))
+ return self._getitem((where,), named_axis)
elif is_sized_iterable(where):
# Do we have an array
@@ -654,7 +742,7 @@ def _getitem(self, where):
primitive_policy="error",
string_policy="as-characters",
)
- return self._getitem(layout)
+ return self._getitem(layout, named_axis)
elif len(where) == 0:
return self._carry(
@@ -682,7 +770,7 @@ def _getitem(self, where):
),
self._backend,
)
- return self._getitem(layout)
+ return self._getitem(layout, named_axis)
else:
raise TypeError(
diff --git a/src/awkward/contents/numpyarray.py b/src/awkward/contents/numpyarray.py
index 315d9383b7..5c90ca0141 100644
--- a/src/awkward/contents/numpyarray.py
+++ b/src/awkward/contents/numpyarray.py
@@ -175,7 +175,11 @@ def shape(self) -> tuple[ShapeItem, ...]:
@property
def inner_shape(self) -> tuple[ShapeItem, ...]:
- return self._data.shape[1:]
+ if hasattr(self._data, "inner_shape"):
+ inner_shape = self._data.inner_shape
+ else:
+ inner_shape = self._data.shape[1:]
+ return inner_shape
@property
def strides(self) -> tuple[ShapeItem, ...]:
@@ -189,14 +193,9 @@ def _raw(self, nplike=None):
return to_nplike(self.data, nplike, from_nplike=self._backend.nplike)
def _form_with_key(self, getkey: Callable[[Content], str | None]) -> NumpyForm:
- if hasattr(self._data, "inner_shape"):
- inner_shape = self._data.inner_shape
- else:
- inner_shape = self._data.shape[1:]
-
return self.form_cls(
ak.types.numpytype.dtype_to_primitive(self._data.dtype),
- inner_shape,
+ self.inner_shape,
parameters=self._parameters,
form_key=getkey(self),
)
diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py
index f315945511..6d1d6649aa 100644
--- a/src/awkward/highlevel.py
+++ b/src/awkward/highlevel.py
@@ -23,6 +23,16 @@
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._layout import wrap_layout
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ AxisMapping,
+ NamedAxis,
+ _get_named_axis,
+ _make_positional_axis_tuple,
+ _normalize_named_slice,
+ _prepare_named_axis_for_attrs,
+ _prettify_named_axes,
+)
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._operators import NDArrayOperatorsMixin
@@ -32,7 +42,7 @@
unpickle_record_schema_1,
)
from awkward._regularize import is_non_string_like_iterable
-from awkward._typing import Any, TypeVar
+from awkward._typing import Any, MutableMapping, TypeVar
from awkward._util import STDOUT
from awkward.prettyprint import Formatter
from awkward.prettyprint import valuestr as prettyprint_valuestr
@@ -278,6 +288,7 @@ def __init__(
check_valid=False,
backend=None,
attrs=None,
+ named_axis=None,
):
self._cpp_type = None
if isinstance(data, ak.contents.Content):
@@ -326,9 +337,20 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or a mapping")
- if attrs is not None and not isinstance(attrs, Mapping):
+ if attrs is not None and not isinstance(attrs, MutableMapping):
raise TypeError("attrs must be None or a mapping")
+ if named_axis:
+ _named_axis = _prepare_named_axis_for_attrs(
+ named_axis=named_axis,
+ ndim=layout.minmax_depth[1],
+ )
+ # now we're good, set the named axis
+ if attrs is None:
+ attrs = {}
+ # if NAMED_AXIS_KEY is already in attrs, it will be overwritten
+ attrs[NAMED_AXIS_KEY] = _named_axis
+
self._layout = layout
self._behavior = behavior
self._attrs = attrs
@@ -357,7 +379,7 @@ def _update_class(self):
self.__class__ = get_array_class(self._layout, self._behavior)
@property
- def attrs(self) -> Mapping[str, Any]:
+ def attrs(self) -> Mapping:
"""
The mutable mapping containing top-level metadata, which is serialised
with the array during pickling.
@@ -455,6 +477,15 @@ def behavior(self, behavior):
else:
raise TypeError("behavior must be None or a dict")
+ @property
+ def positional_axis(self) -> tuple[int, ...]:
+ (_, ndim) = self._layout.minmax_depth
+ return _make_positional_axis_tuple(ndim)
+
+ @property
+ def named_axis(self) -> AxisMapping:
+ return _get_named_axis(self)
+
class Mask:
def __init__(self, array):
self._array = array
@@ -1062,12 +1093,30 @@ def __getitem__(self, where):
have the same dimension as the array being indexed.
"""
with ak._errors.SlicingErrorContext(self, where):
- return wrap_layout(
- prepare_layout(self._layout[where]),
- self._behavior,
- allow_other=True,
- attrs=self._attrs,
- )
+ # Handle named axis
+ (_, ndim) = self._layout.minmax_depth
+ named_axis = _get_named_axis(self)
+ where = _normalize_named_slice(named_axis, where, ndim)
+
+ NamedAxis.mapping = named_axis
+
+ indexed_layout = prepare_layout(self._layout._getitem(where, NamedAxis))
+
+ if NamedAxis.mapping:
+ return ak.operations.ak_with_named_axis._impl(
+ indexed_layout,
+ named_axis=NamedAxis.mapping,
+ highlevel=True,
+ behavior=self._behavior,
+ attrs=self._attrs,
+ )
+ else:
+ return wrap_layout(
+ indexed_layout,
+ self._behavior,
+ allow_other=True,
+ attrs=self._attrs,
+ )
def __bytes__(self) -> bytes:
if isinstance(self._layout, ak.contents.NumpyArray) and self._layout.parameter(
@@ -1309,6 +1358,15 @@ def _repr(self, limit_cols):
else:
valuestr = "-typetracer"
+ # prepare named_axis str for repr
+ axisstr = ""
+ if self.named_axis:
+ # we reserve at maximum 20 characters for the named axis string
+ axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20)
+ axisstr = f" {axisstr}"
+ # subtract the reserved space from the limit_cols
+ limit_cols -= len(axisstr)
+
if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2:
strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3)
else:
@@ -1327,13 +1385,14 @@ def _repr(self, limit_cols):
else:
typestr = "'" + typestr + "'"
- return f"<{pytype}{valuestr} type={typestr}>"
+ return f"<{pytype}{valuestr}{axisstr} type={typestr}>"
def show(
self,
limit_rows=20,
limit_cols=80,
type=False,
+ named_axis=False,
stream=STDOUT,
*,
formatter=None,
@@ -1365,25 +1424,41 @@ def show(
valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
)
+
+ out_io = io.StringIO()
if type:
- tmp = io.StringIO()
- self.type.show(stream=tmp)
- out = "type: " + tmp.getvalue() + valuestr
- else:
- out = valuestr
+ out_io.write("type: ")
+ self.type.show(stream=out_io)
+ if named_axis and self.named_axis:
+ out_io.write("axes: ")
+ out_io.write(
+ _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
+ )
+ out_io.write("\n")
+ out_io.write(valuestr)
if stream is None:
- return out
+ return out_io
else:
if stream is STDOUT:
stream = STDOUT.stream
- stream.write(out + "\n")
+ stream.write(out_io.getvalue() + "\n")
def _repr_mimebundle_(self, include=None, exclude=None):
+ # order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
- self.show(type=False, stream=value_buff)
+ self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()
+ named_axis_line = ""
+ if self.named_axis:
+ named_axis_buff = io.StringIO()
+ named_axis_buff.write("axes: ")
+ named_axis_buff.write(
+ _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
+ )
+ named_axis_line = named_axis_buff.getvalue()
+
type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
@@ -1393,8 +1468,16 @@ def _repr_mimebundle_(self, include=None, exclude=None):
if header_lines[-1] == "":
del header_lines[-1]
- n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines))
- body = "\n".join([*header_lines, "-" * n_cols, *footer_lines])
+ n_cols = max(
+ len(line)
+ for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
+ )
+ body_lines = header_lines
+ body_lines.append("-" * n_cols)
+ if named_axis_line:
+ body_lines.append(named_axis_line)
+ body_lines.extend(footer_lines)
+ body = "\n".join(body_lines)
return {
"text/html": f"{html.escape(body)}
",
@@ -1719,6 +1802,7 @@ def __init__(
check_valid=False,
backend=None,
attrs=None,
+ named_axis=None,
):
if isinstance(data, ak.record.Record):
layout = data
@@ -1762,6 +1846,20 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")
+ if attrs is not None and not isinstance(attrs, MutableMapping):
+ raise TypeError("attrs must be None or a mapping")
+
+ if named_axis:
+ _named_axis = _prepare_named_axis_for_attrs(
+ named_axis=named_axis,
+ ndim=layout.minmax_depth[1],
+ )
+ # now we're good, set the named axis
+ if attrs is None:
+ attrs = {}
+ # if NAMED_AXIS_KEY is already in attrs, it will be overwritten
+ attrs[NAMED_AXIS_KEY] = _named_axis
+
self._layout = layout
self._behavior = behavior
self._attrs = attrs
@@ -1877,6 +1975,15 @@ def behavior(self, behavior):
else:
raise TypeError("behavior must be None or a dict")
+ @property
+ def positional_axis(self) -> tuple[int, ...]:
+ (_, ndim) = self._layout.minmax_depth
+ return _make_positional_axis_tuple(ndim)
+
+ @property
+ def named_axis(self) -> AxisMapping:
+ return _get_named_axis(self)
+
def tolist(self):
"""
Converts this Record into Python objects; same as #ak.to_list
@@ -2170,6 +2277,15 @@ def _repr(self, limit_cols):
else:
valuestr = "-typetracer"
+ # prepare named_axis str for repr
+ axisstr = ""
+ if self.named_axis:
+ # we reserve at maximum 20 characters for the named axis string
+ axisstr = _prettify_named_axes(self.named_axis, delimiter=",", maxlen=20)
+ axisstr = f" {axisstr}"
+ # subtract the reserved space from the limit_cols
+ limit_cols -= len(axisstr)
+
if len(typestr) + len(pytype) + len(" type=''") + 3 < limit_cols // 2:
strwidth = limit_cols - (len(typestr) + len(pytype) + len(" type=''") + 3)
else:
@@ -2188,13 +2304,14 @@ def _repr(self, limit_cols):
else:
typestr = "'" + typestr + "'"
- return f"<{pytype}{valuestr} type={typestr}>"
+ return f"<{pytype}{valuestr}{axisstr} type={typestr}>"
def show(
self,
limit_rows=20,
limit_cols=80,
type=False,
+ named_axis=False,
stream=STDOUT,
*,
formatter=None,
@@ -2224,25 +2341,41 @@ def show(
valuestr = prettyprint_valuestr(
self, limit_rows, limit_cols, formatter=formatter_impl
)
+
+ out_io = io.StringIO()
if type:
- tmp = io.StringIO()
- self.type.show(stream=tmp)
- out = "type: " + tmp.getvalue() + valuestr
- else:
- out = valuestr
+ out_io.write("type: ")
+ self.type.show(stream=out_io)
+ if named_axis and self.named_axis:
+ out_io.write("axes: ")
+ out_io.write(
+ _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
+ )
+ out_io.write("\n")
+ out_io.write(valuestr)
if stream is None:
- return out
+ return out_io.getvalue()
else:
if stream is STDOUT:
stream = STDOUT.stream
- stream.write(out + "\n")
+ stream.write(out_io.getvalue() + "\n")
def _repr_mimebundle_(self, include=None, exclude=None):
+ # order: 1. array, 2. named_axis, 3. type
value_buff = io.StringIO()
- self.show(type=False, stream=value_buff)
+ self.show(type=False, named_axis=False, stream=value_buff)
header_lines = value_buff.getvalue().splitlines()
+ named_axis_line = ""
+ if self.named_axis:
+ named_axis_buff = io.StringIO()
+ named_axis_buff.write("axes: ")
+ named_axis_buff.write(
+ _prettify_named_axes(self.named_axis, delimiter=", ", maxlen=None)
+ )
+ named_axis_line = named_axis_buff.getvalue()
+
type_buff = io.StringIO()
self.type.show(stream=type_buff)
footer_lines = type_buff.getvalue().splitlines()
@@ -2252,8 +2385,16 @@ def _repr_mimebundle_(self, include=None, exclude=None):
if header_lines[-1] == "":
del header_lines[-1]
- n_cols = max(len(line) for line in itertools.chain(header_lines, footer_lines))
- body = "\n".join([*header_lines, "-" * n_cols, *footer_lines])
+ n_cols = max(
+ len(line)
+ for line in itertools.chain(header_lines, [named_axis_line], footer_lines)
+ )
+ body_lines = header_lines
+ body_lines.append("-" * n_cols)
+ if named_axis_line:
+ body_lines.append(named_axis_line)
+ body_lines.extend(footer_lines)
+ body = "\n".join(body_lines)
return {
"text/html": f"{html.escape(body)}
",
diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py
index d76d8e2688..94dbd9ffac 100644
--- a/src/awkward/operations/__init__.py
+++ b/src/awkward/operations/__init__.py
@@ -114,8 +114,10 @@
from awkward.operations.ak_where import *
from awkward.operations.ak_with_field import *
from awkward.operations.ak_with_name import *
+from awkward.operations.ak_with_named_axis import *
from awkward.operations.ak_with_parameter import *
from awkward.operations.ak_without_field import *
+from awkward.operations.ak_without_named_axis import *
from awkward.operations.ak_without_parameters import *
from awkward.operations.ak_zeros_like import *
from awkward.operations.ak_zip import *
diff --git a/src/awkward/operations/ak_all.py b/src/awkward/operations/ak_all.py
index 859bfd98cb..98a22520ba 100644
--- a/src/awkward/operations/ak_all.py
+++ b/src/awkward/operations/ak_all.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -67,9 +73,26 @@ def all(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.All()
out = ak._do.reduce(
@@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("all")
diff --git a/src/awkward/operations/ak_almost_equal.py b/src/awkward/operations/ak_almost_equal.py
index 66f67e4d8a..949f955a45 100644
--- a/src/awkward/operations/ak_almost_equal.py
+++ b/src/awkward/operations/ak_almost_equal.py
@@ -7,6 +7,7 @@
from awkward._behavior import behavior_of, get_array_class, get_record_class
from awkward._dispatch import high_level_function
from awkward._layout import ensure_same_backend
+from awkward._namedaxis import _get_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import parameters_are_equal
from awkward.operations.ak_to_layout import to_layout
@@ -27,6 +28,7 @@ def almost_equal(
dtype_exact: bool = True,
check_parameters: bool = True,
check_regular: bool = True,
+ check_named_axis: bool = True,
):
"""
Args:
@@ -39,6 +41,7 @@ def almost_equal(
check_parameters: whether to compare parameters.
check_regular: whether to consider ragged and regular dimensions as
unequal.
+ check_named_axis: bool (default=True) whether to consider named axes as unequal.
Return True if the two array-like arguments are considered equal for the
given options. Otherwise, return False.
@@ -61,6 +64,7 @@ def almost_equal(
dtype_exact=dtype_exact,
check_parameters=check_parameters,
check_regular=check_regular,
+ check_named_axis=check_named_axis,
exact_eq=False,
same_content_types=False,
equal_nan=False,
@@ -75,6 +79,7 @@ def _impl(
dtype_exact: bool,
check_parameters: bool,
check_regular: bool,
+ check_named_axis: bool,
exact_eq: bool,
same_content_types: bool,
equal_nan: bool,
@@ -91,6 +96,10 @@ def _impl(
right_layout = layouts[1].to_packed()
backend = backend_of(left_layout)
+ if check_named_axis and _get_named_axis(left) and _get_named_axis(right):
+ if left.named_axis != right.named_axis:
+ return False
+
if not backend.nplike.known_data:
raise NotImplementedError(
"Awkward Arrays with typetracer backends cannot yet be compared with `ak.almost_equal`."
diff --git a/src/awkward/operations/ak_any.py b/src/awkward/operations/ak_any.py
index 79c9cc6b83..e99065d97c 100644
--- a/src/awkward/operations/ak_any.py
+++ b/src/awkward/operations/ak_any.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -67,9 +73,26 @@ def any(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Any()
out = ak._do.reduce(
@@ -80,7 +103,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("any")
diff --git a/src/awkward/operations/ak_argcartesian.py b/src/awkward/operations/ak_argcartesian.py
index 12deed5749..f012290cbe 100644
--- a/src/awkward/operations/ak_argcartesian.py
+++ b/src/awkward/operations/ak_argcartesian.py
@@ -7,7 +7,6 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import regularize_axis
__all__ = ("argcartesian",)
@@ -107,8 +106,6 @@ def argcartesian(
def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
if isinstance(arrays, Mapping):
index_arrays = {n: ak.local_index(x, axis) for n, x in arrays.items()}
else:
diff --git a/src/awkward/operations/ak_argcombinations.py b/src/awkward/operations/ak_argcombinations.py
index 98a2643855..337f77cec1 100644
--- a/src/awkward/operations/ak_argcombinations.py
+++ b/src/awkward/operations/ak_argcombinations.py
@@ -5,6 +5,7 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import _get_named_axis, _named_axis_to_positional_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -93,7 +94,6 @@ def _impl(
behavior,
attrs,
):
- axis = regularize_axis(axis)
if parameters is None:
parameters = {}
else:
@@ -101,6 +101,13 @@ def _impl(
if with_name is not None:
parameters["__record__"] = with_name
+ # Handle named axis
+ named_axis = _get_named_axis(array)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
if axis < 0:
raise ValueError("the 'axis' for argcombinations must be non-negative")
else:
diff --git a/src/awkward/operations/ak_argmax.py b/src/awkward/operations/ak_argmax.py
index a4dbe947bd..ef9b37e57c 100644
--- a/src/awkward/operations/ak_argmax.py
+++ b/src/awkward/operations/ak_argmax.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -132,9 +138,26 @@ def nanargmax(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.ArgMax()
out = ak._do.reduce(
@@ -145,7 +168,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("argmax")
diff --git a/src/awkward/operations/ak_argmin.py b/src/awkward/operations/ak_argmin.py
index 7f21fb3aa8..6982a4d407 100644
--- a/src/awkward/operations/ak_argmin.py
+++ b/src/awkward/operations/ak_argmin.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -129,10 +135,26 @@ def nanargmin(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.ArgMin()
out = ak._do.reduce(
@@ -143,7 +165,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("argmin")
diff --git a/src/awkward/operations/ak_argsort.py b/src/awkward/operations/ak_argsort.py
index bade378b20..7c92d6a645 100644
--- a/src/awkward/operations/ak_argsort.py
+++ b/src/awkward/operations/ak_argsort.py
@@ -6,6 +6,10 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -70,11 +74,22 @@ def argsort(
def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
out = ak._do.argsort(layout, axis, ascending, stable)
- return ctx.wrap(out, highlevel=highlevel)
+
+ return ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
@ak._connect.numpy.implements("argsort")
diff --git a/src/awkward/operations/ak_array_equal.py b/src/awkward/operations/ak_array_equal.py
index 398db6b2a6..1dabd60f31 100644
--- a/src/awkward/operations/ak_array_equal.py
+++ b/src/awkward/operations/ak_array_equal.py
@@ -18,6 +18,7 @@ def array_equal(
same_content_types: bool = True,
check_parameters: bool = True,
check_regular: bool = True,
+ check_named_axis: bool = True,
):
"""
True if two arrays have the same shape and elements, False otherwise.
@@ -34,6 +35,7 @@ def array_equal(
check_parameters: bool (default=True) whether to compare parameters.
check_regular: bool (default=True) whether to consider ragged and regular dimensions as
unequal.
+ check_named_axis: bool (default=True) whether to consider named axes as unequal.
TypeTracer arrays are not supported, as there is very little information to
be compared.
@@ -49,6 +51,7 @@ def array_equal(
dtype_exact=dtype_exact,
check_parameters=check_parameters,
check_regular=check_regular,
+ check_named_axis=check_named_axis,
exact_eq=True,
same_content_types=same_content_types and check_regular,
equal_nan=equal_nan,
diff --git a/src/awkward/operations/ak_broadcast_arrays.py b/src/awkward/operations/ak_broadcast_arrays.py
index 877c69f9c0..feef9b5138 100644
--- a/src/awkward/operations/ak_broadcast_arrays.py
+++ b/src/awkward/operations/ak_broadcast_arrays.py
@@ -2,6 +2,8 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
from awkward._attrs import attrs_of_obj
from awkward._backends.dispatch import backend_of
@@ -10,6 +12,11 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ NamedAxesWithDims,
+ _unify_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("broadcast_arrays",)
@@ -243,24 +250,43 @@ def action(inputs, depth, **kwargs):
else:
return None
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays)
out = ak._broadcasting.broadcast_and_apply(
inputs,
action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
left_broadcast=left_broadcast,
right_broadcast=right_broadcast,
broadcast_parameters_rule=broadcast_parameters_rule,
numpy_to_regular=True,
)
assert isinstance(out, tuple)
- return [
- wrap_layout(
+
+ # unify named axes
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = []
+ for layout_out, array_in in zip(out, arrays):
+ _behavior = behavior_of_obj(array_in, behavior=behavior)
+ _attrs = attrs_of_obj(array_in, attrs=attrs)
+ wrapped = wrap_layout(
layout_out,
- behavior=behavior_of_obj(array_in, behavior=behavior),
+ behavior=_behavior,
highlevel=highlevel,
- attrs=attrs_of_obj(array_in, attrs=attrs),
+ attrs=_attrs,
)
- for layout_out, array_in in zip(out, arrays)
- ]
+ wrapped_out.append(
+ ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=_behavior,
+ attrs=_attrs,
+ )
+ )
+ return wrapped_out
@ak._connect.numpy.implements("broadcast_arrays")
diff --git a/src/awkward/operations/ak_cartesian.py b/src/awkward/operations/ak_cartesian.py
index 91767d27d8..0f46f449c9 100644
--- a/src/awkward/operations/ak_cartesian.py
+++ b/src/awkward/operations/ak_cartesian.py
@@ -3,11 +3,20 @@
from __future__ import annotations
from collections.abc import Mapping
+from functools import reduce
import awkward as ak
from awkward._backends.numpy import NumpyBackend
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ NamedAxesWithDims,
+ _add_named_axis,
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+ _unify_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
@@ -214,7 +223,6 @@ def cartesian(
def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
if isinstance(arrays, Mapping):
layouts = ensure_same_backend(
@@ -226,6 +234,11 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
fields = list(arrays.keys())
array_layouts = dict(zip(fields, layouts))
+ # propagate named axis from input to output,
+ # use strategy "unify" (see: awkward._namedaxis)
+ out_named_axis = reduce(
+ _unify_named_axis, map(_get_named_axis, arrays.values())
+ )
else:
layouts = array_layouts = ensure_same_backend(
*(
@@ -234,6 +247,15 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
)
)
fields = None
+ # propagate named axis from input to output,
+ # use strategy "unify" (see: awkward._namedaxis)
+ out_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays))
+
+ # Handle named axis
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(out_named_axis, axis)
+ axis = regularize_axis(axis, none_allowed=False)
+ max_ndim = max(layout.minmax_depth[1] for layout in layouts)
if with_name is not None:
if parameters is None:
@@ -262,6 +284,7 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
if nested is None or nested is False:
nested = []
elif nested is True:
+ out_named_axis = _add_named_axis(out_named_axis, 0, max_ndim)
if fields is not None:
nested = list(fields)[:-1]
else:
@@ -287,6 +310,8 @@ def _impl(arrays, axis, nested, parameters, with_name, highlevel, behavior, attr
"the 'nested' parameter of cartesian must be integers in "
"[0, len(arrays) - 1) for an iterable of arrays"
)
+ for n in nested:
+ out_named_axis = _add_named_axis(out_named_axis, n, max_ndim)
backend = next((layout.backend for layout in layouts), cpu)
if posaxis == 0:
@@ -398,16 +423,48 @@ def apply_build_record(inputs, depth, **kwargs):
else:
return None
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays)
+ )
out = ak._broadcasting.broadcast_and_apply(
- new_layouts, apply_build_record, right_broadcast=False
+ new_layouts,
+ apply_build_record,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ right_broadcast=False,
)
assert isinstance(out, tuple) and len(out) == 1
result = out[0]
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(result, highlevel=highlevel)
+ # propagate named axis to output
+ result = ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
+
# Remove surplus dimensions, iterating from smallest to greatest
for axis_to_flatten in axes_to_flatten:
result = ak.operations.flatten(
- result, axis=axis_to_flatten, highlevel=False, behavior=behavior
+ result, axis=axis_to_flatten, highlevel=highlevel, behavior=behavior
)
- return ctx.wrap(result, highlevel=highlevel)
+ return result
+
+ wrapped_out = ctx.wrap(result, highlevel=highlevel)
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_categories.py b/src/awkward/operations/ak_categories.py
index cd7f6ccf4c..e723d098da 100644
--- a/src/awkward/operations/ak_categories.py
+++ b/src/awkward/operations/ak_categories.py
@@ -49,6 +49,16 @@ def action(layout, **kwargs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
ak._do.recursively_apply(layout, action)
- return ctx.wrap(output, highlevel=highlevel)
+ wrapped_out = ctx.wrap(output, highlevel=highlevel)
+
+ # propagate named axis from input to output,
+ # use strategy "drop all" (see: awkward._namedaxis)
+ return ak.operations.ak_without_named_axis._impl(
+ wrapped_out,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_combinations.py b/src/awkward/operations/ak_combinations.py
index d22708cb4a..284023f2cd 100644
--- a/src/awkward/operations/ak_combinations.py
+++ b/src/awkward/operations/ak_combinations.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -214,7 +218,15 @@ def _impl(
behavior,
attrs,
):
- axis = regularize_axis(axis)
+ with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
+ layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
if with_name is None:
pass
@@ -223,8 +235,6 @@ def _impl(
else:
parameters = {**parameters, "__record__": with_name}
- with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
- layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
out = ak._do.combinations(
layout,
n,
diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py
index fb8fcf94ae..3e086f7e8c 100644
--- a/src/awkward/operations/ak_concatenate.py
+++ b/src/awkward/operations/ak_concatenate.py
@@ -2,6 +2,7 @@
from __future__ import annotations
+from functools import reduce
from itertools import permutations
import awkward as ak
@@ -9,6 +10,13 @@
from awkward._dispatch import high_level_function
from awkward._do import mergeable
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ NamedAxesWithDims,
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+ _unify_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._parameters import type_parameters_equal
@@ -92,7 +100,6 @@ def _merge_as_union(
def _impl(arrays, axis, mergebool, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
# Simple single-array, axis=0 fast-path
if (
# Is an array with a known backend
@@ -121,6 +128,15 @@ def _impl(arrays, axis, mergebool, highlevel, behavior, attrs):
)
)
+ # Handle named axis
+ merged_named_axis = reduce(_unify_named_axis, map(_get_named_axis, arrays))
+ # Step 1: normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(merged_named_axis, axis)
+ axis = regularize_axis(axis, none_allowed=False)
+ # Step 2: propagate named axis from input to output,
+ # use strategy "unify" (see: awkward._namedaxis)
+ out_named_axis = merged_named_axis
+
contents = [x for x in content_or_others if isinstance(x, ak.contents.Content)]
if len(contents) == 0:
raise ValueError("need at least one array to concatenate")
@@ -342,11 +358,35 @@ def action(inputs, depth, backend, **kwargs):
else:
return None
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ list(arrays)
+ )
out = ak._broadcasting.broadcast_and_apply(
- content_or_others, action, allow_records=True, right_broadcast=False
+ content_or_others,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ allow_records=True,
+ right_broadcast=False,
)[0]
+ # Unify named axes
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
- return ctx.wrap(out, highlevel=highlevel)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
def _form_has_type(form, type_):
diff --git a/src/awkward/operations/ak_corr.py b/src/awkward/operations/ak_corr.py
index 74d148831d..e646a43b0f 100644
--- a/src/awkward/operations/ak_corr.py
+++ b/src/awkward/operations/ak_corr.py
@@ -3,12 +3,14 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import (
HighLevelContext,
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
+from awkward._namedaxis import _get_named_axis, _is_valid_named_axis
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -86,7 +88,10 @@ def corr(
def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
+ if _is_valid_named_axis(axis):
+ raise NotImplementedError("named axis not yet supported for ak.corr")
+
+ axis = regularize_axis(axis, none_allowed=True)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
@@ -110,7 +115,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
x,
weight,
axis,
- False,
+ True,
mask_identity,
highlevel=True,
behavior=ctx.behavior,
@@ -120,7 +125,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
y,
weight,
axis,
- False,
+ True,
mask_identity,
highlevel=True,
behavior=ctx.behavior,
@@ -184,8 +189,19 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
behavior=ctx.behavior,
attrs=ctx.attrs,
)
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(sumwxy / ufuncs.sqrt(sumwxx * sumwyy)),
+
+ out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy)
+
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
allow_other=True,
)
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
+ )
diff --git a/src/awkward/operations/ak_count.py b/src/awkward/operations/ak_count.py
index 85f43a27ee..f9b8c48481 100644
--- a/src/awkward/operations/ak_count.py
+++ b/src/awkward/operations/ak_count.py
@@ -5,6 +5,12 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -109,9 +115,26 @@ def count(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Count()
out = ak._do.reduce(
@@ -122,4 +145,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_count_nonzero.py b/src/awkward/operations/ak_count_nonzero.py
index 919a6abf22..74a8b23033 100644
--- a/src/awkward/operations/ak_count_nonzero.py
+++ b/src/awkward/operations/ak_count_nonzero.py
@@ -5,6 +5,12 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -68,7 +74,26 @@ def count_nonzero(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
+ with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
+ layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
reducer = ak._reducers.CountNonzero()
@@ -81,7 +106,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("count_nonzero")
diff --git a/src/awkward/operations/ak_covar.py b/src/awkward/operations/ak_covar.py
index a070ac6895..7c8fe930fe 100644
--- a/src/awkward/operations/ak_covar.py
+++ b/src/awkward/operations/ak_covar.py
@@ -3,12 +3,14 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import (
HighLevelContext,
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
+from awkward._namedaxis import _get_named_axis, _is_valid_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -83,7 +85,9 @@ def covar(
def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
+ if _is_valid_named_axis(axis):
+ raise NotImplementedError("named axis not yet supported for ak.covar")
+ axis = regularize_axis(axis, none_allowed=True)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
@@ -107,7 +111,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
x,
weight,
axis,
- False,
+ True,
mask_identity,
highlevel=True,
behavior=None,
@@ -117,7 +121,7 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
y,
weight,
axis,
- False,
+ True,
mask_identity,
highlevel=True,
behavior=None,
@@ -161,8 +165,18 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
behavior=None,
attrs=None,
)
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(sumwxy / sumw),
+
+ out = sumwxy / sumw
+
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
allow_other=True,
)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
+ )
diff --git a/src/awkward/operations/ak_drop_none.py b/src/awkward/operations/ak_drop_none.py
index c6c06014db..d81770f78f 100644
--- a/src/awkward/operations/ak_drop_none.py
+++ b/src/awkward/operations/ak_drop_none.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
@@ -65,10 +69,16 @@ def _drop_none_if_list(layout):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
+
if axis is None:
# if the outer layout is_option, drop_nones without affecting offsets
if layout.is_option:
@@ -120,4 +130,7 @@ def action(layout, depth, **kwargs):
if len(options["none_indexes"]) > 0:
out = ak._do.recursively_apply(out, recompute_offsets, depth_context=options)
- return ctx.wrap(out, highlevel=highlevel)
+ return ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
diff --git a/src/awkward/operations/ak_fill_none.py b/src/awkward/operations/ak_fill_none.py
index 89834689cd..fb3dbfd019 100644
--- a/src/awkward/operations/ak_fill_none.py
+++ b/src/awkward/operations/ak_fill_none.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
@@ -69,8 +73,6 @@ def fill_none(array, value, axis=-1, *, highlevel=True, behavior=None, attrs=Non
def _impl(array, value, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
array_layout, value_layout = ensure_same_backend(
ctx.unwrap(array, allow_record=True, allow_unknown=False),
@@ -84,6 +86,13 @@ def _impl(array, value, axis, highlevel, behavior, attrs):
),
)
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
+
if isinstance(value_layout, ak.record.Record):
value_layout = value_layout.array[value_layout.at : value_layout.at + 1]
elif isinstance(value_layout, ak.contents.Content):
diff --git a/src/awkward/operations/ak_firsts.py b/src/awkward/operations/ak_firsts.py
index f67da6dde1..79fba6eb51 100644
--- a/src/awkward/operations/ak_firsts.py
+++ b/src/awkward/operations/ak_firsts.py
@@ -5,8 +5,13 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import is_integer, regularize_axis
+from awkward._regularize import regularize_axis
from awkward.errors import AxisError
__all__ = ("firsts",)
@@ -58,10 +63,20 @@ def firsts(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False)
- axis = regularize_axis(axis)
- if not is_integer(axis):
- raise TypeError(f"'axis' must be an integer, not {axis!r}")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _remove_named_axis(
+ named_axis=named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=False)
if maybe_posaxis(layout, axis, 1) == 0:
# specialized logic; it's tested in test_0582-propagate-context-in-broadcast_and_apply.py
@@ -103,4 +118,17 @@ def action(layout, depth, backend, **kwargs):
out = ak._do.recursively_apply(layout, action, numpy_to_regular=True)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_flatten.py b/src/awkward/operations/ak_flatten.py
index b246870463..3805d28e71 100644
--- a/src/awkward/operations/ak_flatten.py
+++ b/src/awkward/operations/ak_flatten.py
@@ -5,6 +5,12 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -173,10 +179,25 @@ def flatten(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ axis = regularize_axis(axis, none_allowed=True)
+ # Step 2: propagate named axis from input to output,
+ # if axis == None: use strategy "remove all" (see: awkward._namedaxis)
+ # if axis == 0: use strategy "keep all" (see: awkward._namedaxis)
+ # if axis != 0: use strategy "remove one" (see: awkward._namedaxis)
+ if axis is None:
+ pass
+ elif axis == 0 or maybe_posaxis(layout, axis, 1) == 0:
+ out_named_axis = _keep_named_axis(named_axis, None)
+ else:
+ out_named_axis = _remove_named_axis(named_axis, axis, layout.minmax_depth[1])
+
if axis is None:
out = ak._do.remove_structure(layout, function_name="ak.flatten")
assert isinstance(out, tuple) and all(
@@ -234,4 +255,27 @@ def apply(layout):
out = apply(layout)
else:
out = ak._do.flatten(layout, axis)
- return ctx.wrap(out, highlevel=highlevel)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ # if axis == None: use strategy "remove all" (see: awkward._namedaxis)
+ if axis is None:
+ return ak.operations.ak_without_named_axis._impl(
+ wrapped_out,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
+ # if axis == 0: use strategy "keep all" (see: awkward._namedaxis)
+ # if axis != 0: use strategy "remove one" (see: awkward._namedaxis)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_from_regular.py b/src/awkward/operations/ak_from_regular.py
index b3f840ef31..9fe2800a2b 100644
--- a/src/awkward/operations/ak_from_regular.py
+++ b/src/awkward/operations/ak_from_regular.py
@@ -55,7 +55,8 @@ def from_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False)
- axis = regularize_axis(axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
if axis is None:
diff --git a/src/awkward/operations/ak_is_none.py b/src/awkward/operations/ak_is_none.py
index 078c86bde6..d9a58a5478 100644
--- a/src/awkward/operations/ak_is_none.py
+++ b/src/awkward/operations/ak_is_none.py
@@ -5,8 +5,13 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis_up_to,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import is_integer, regularize_axis
+from awkward._regularize import regularize_axis
from awkward.errors import AxisError
__all__ = ("is_none",)
@@ -41,12 +46,19 @@ def is_none(array, axis=0, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
- if not is_integer(axis):
- raise TypeError(f"'axis' must be an integer, not {axis!r}")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
+ # Step 2: propagate named axis from input to output,
+ # use strategy "keep up to" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1])
def action(layout, depth, backend, lateral_context, **kwargs):
posaxis = maybe_posaxis(layout, axis, depth)
@@ -68,4 +80,16 @@ def action(layout, depth, backend, lateral_context, **kwargs):
out = ak._do.recursively_apply(layout, action, numpy_to_regular=True)
- return ctx.wrap(out, highlevel=highlevel)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_isclose.py b/src/awkward/operations/ak_isclose.py
index 8797c36752..d5ff825c61 100644
--- a/src/awkward/operations/ak_isclose.py
+++ b/src/awkward/operations/ak_isclose.py
@@ -2,9 +2,12 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("isclose",)
@@ -70,10 +73,26 @@ def action(inputs, backend, **kwargs):
),
)
- out = ak._broadcasting.broadcast_and_apply(layouts, action)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([a, b])
+ out = ak._broadcasting.broadcast_and_apply(
+ layouts,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ )
assert isinstance(out, tuple) and len(out) == 1
- return ctx.wrap(out[0], highlevel=highlevel)
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("isclose")
diff --git a/src/awkward/operations/ak_linear_fit.py b/src/awkward/operations/ak_linear_fit.py
index 971fea64fe..01ac0f3297 100644
--- a/src/awkward/operations/ak_linear_fit.py
+++ b/src/awkward/operations/ak_linear_fit.py
@@ -7,7 +7,6 @@
from awkward._layout import HighLevelContext, ensure_same_backend
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import regularize_axis
__all__ = ("linear_fit",)
@@ -95,8 +94,6 @@ def linear_fit(
def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, y_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
@@ -231,4 +228,13 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
if is_scalar:
out = out[0]
- return ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar)
+ wrapped_out = ctx.wrap(out, highlevel=highlevel, allow_other=is_scalar)
+
+ # propagate named axis
+ # use strategy "remove all" (see: awkward._namedaxis)
+ return ak.operations.ak_without_named_axis._impl(
+ wrapped_out,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_local_index.py b/src/awkward/operations/ak_local_index.py
index 2231ac229f..d5e7089dbc 100644
--- a/src/awkward/operations/ak_local_index.py
+++ b/src/awkward/operations/ak_local_index.py
@@ -5,6 +5,11 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis_up_to,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -88,8 +93,32 @@ def local_index(array, axis=-1, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
+ # Step 2: propagate named axis from input to output,
+ # use strategy "keep up to" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis_up_to(named_axis, axis, layout.minmax_depth[1])
+
out = ak._do.local_index(layout, axis)
- return ctx.wrap(out, highlevel=highlevel)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_mask.py b/src/awkward/operations/ak_mask.py
index 54d9a5e04b..b18273047a 100644
--- a/src/awkward/operations/ak_mask.py
+++ b/src/awkward/operations/ak_mask.py
@@ -2,9 +2,12 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("mask",)
@@ -124,8 +127,26 @@ def action(inputs, backend, **kwargs):
ctx.unwrap(mask, allow_record=False, primitive_policy="error"),
)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts([array, mask])
out = ak._broadcasting.broadcast_and_apply(
- layouts, action, numpy_to_regular=True, right_broadcast=False
+ layouts,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ numpy_to_regular=True,
+ right_broadcast=False,
)
assert isinstance(out, tuple) and len(out) == 1
- return ctx.wrap(out[0], highlevel=highlevel)
+
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_max.py b/src/awkward/operations/ak_max.py
index a01a0d64c5..319b2c7bed 100644
--- a/src/awkward/operations/ak_max.py
+++ b/src/awkward/operations/ak_max.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -142,9 +148,26 @@ def nanmax(
def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Max(initial)
out = ak._do.reduce(
@@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("amax")
diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py
index fa74a89b61..a9b38ce1f0 100644
--- a/src/awkward/operations/ak_mean.py
+++ b/src/awkward/operations/ak_mean.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
@@ -11,6 +12,10 @@
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -174,8 +179,6 @@ def nanmean(
def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
@@ -191,6 +194,13 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
x = ctx.wrap(x_layout)
weight = ctx.wrap(weight_layout, allow_other=True)
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
+
with np.errstate(invalid="ignore", divide="ignore"):
if weight is None:
sumw = ak.operations.ak_count._impl(
@@ -245,14 +255,25 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
if axis is None:
if not keepdims:
+ # remove all dimensions
out = out[(0,) * out.ndim]
else:
if not keepdims:
+ # remove reduced dimension
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
+ highlevel=highlevel,
+ allow_other=True,
+ )
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out) or {}),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
)
diff --git a/src/awkward/operations/ak_merge_option_of_records.py b/src/awkward/operations/ak_merge_option_of_records.py
index c3e1095ba4..17402e77a6 100644
--- a/src/awkward/operations/ak_merge_option_of_records.py
+++ b/src/awkward/operations/ak_merge_option_of_records.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
@@ -49,10 +53,15 @@ def merge_option_of_records(
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
# First, normalise type-invsible "index-of-records" to "record-of-index"
def apply_displace_index(layout, backend, **kwargs):
if (layout.is_indexed and not layout.is_option) and layout.content.is_record:
diff --git a/src/awkward/operations/ak_merge_union_of_records.py b/src/awkward/operations/ak_merge_union_of_records.py
index d523c0b5f8..0094203947 100644
--- a/src/awkward/operations/ak_merge_union_of_records.py
+++ b/src/awkward/operations/ak_merge_union_of_records.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import ArrayLike, NumpyMetadata
from awkward._regularize import regularize_axis
from awkward.errors import AxisError
@@ -59,10 +63,15 @@ def merge_union_of_records(
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
def invert_record_union(
tags: ArrayLike, index: ArrayLike, contents
) -> ak.contents.RecordArray:
diff --git a/src/awkward/operations/ak_min.py b/src/awkward/operations/ak_min.py
index 05e583d430..1b9189f740 100644
--- a/src/awkward/operations/ak_min.py
+++ b/src/awkward/operations/ak_min.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -142,9 +148,26 @@ def nanmin(
def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Min(initial)
out = ak._do.reduce(
@@ -155,7 +178,21 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("amin")
diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py
index 7cac2498ee..2c8e29adb1 100644
--- a/src/awkward/operations/ak_moment.py
+++ b/src/awkward/operations/ak_moment.py
@@ -3,14 +3,19 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import (
HighLevelContext,
ensure_same_backend,
maybe_highlevel_to_lowlevel,
)
+from awkward._namedaxis import (
+ AxisName,
+ _get_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import regularize_axis
+from awkward._typing import Mapping
__all__ = ("moment",)
@@ -22,13 +27,13 @@ def moment(
x,
n,
weight=None,
- axis=None,
+ axis: AxisName = None,
*,
- keepdims=False,
- mask_identity=False,
- highlevel=True,
- behavior=None,
- attrs=None,
+ keepdims: bool = False,
+ mask_identity: bool = False,
+ highlevel: bool = True,
+ behavior: Mapping | None = None,
+ attrs: Mapping | None = None,
):
"""
Args:
@@ -86,9 +91,17 @@ def moment(
)
-def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
+def _impl(
+ x,
+ n,
+ weight,
+ axis: AxisName,
+ keepdims: bool,
+ mask_identity: bool,
+ highlevel: bool,
+ behavior: Mapping | None,
+ attrs: Mapping | None,
+):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
@@ -143,8 +156,20 @@ def _impl(x, n, weight, axis, keepdims, mask_identity, highlevel, behavior, attr
behavior=ctx.behavior,
attrs=ctx.attrs,
)
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(sumwxn / sumw),
+
+ out = sumwxn / sumw
+
+ # propagate named axis to output
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
allow_other=True,
)
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
+ )
diff --git a/src/awkward/operations/ak_nan_to_none.py b/src/awkward/operations/ak_nan_to_none.py
index 7dabbfe828..23ef938dbe 100644
--- a/src/awkward/operations/ak_nan_to_none.py
+++ b/src/awkward/operations/ak_nan_to_none.py
@@ -6,6 +6,7 @@
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
from awkward._nplikes.numpy_like import NumpyMetadata
+from awkward._typing import Mapping
__all__ = ("nan_to_none",)
@@ -13,7 +14,13 @@
@high_level_function()
-def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None):
+def nan_to_none(
+ array,
+ *,
+ highlevel: bool = True,
+ behavior: Mapping | None = None,
+ attrs: Mapping | None = None,
+):
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
@@ -35,7 +42,7 @@ def nan_to_none(array, *, highlevel=True, behavior=None, attrs=None):
return _impl(array, highlevel, behavior, attrs)
-def _impl(array, highlevel, behavior, attrs):
+def _impl(array, highlevel: bool, behavior: Mapping | None, attrs: Mapping | None):
def action(layout, continuation, backend, **kwargs):
if layout.is_numpy and np.issubdtype(layout.dtype, np.floating):
mask = backend.nplike.isnan(layout.data)
@@ -55,5 +62,6 @@ def action(layout, continuation, backend, **kwargs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
out = ak._do.recursively_apply(layout, action)
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_nan_to_num.py b/src/awkward/operations/ak_nan_to_num.py
index 4c7472a06f..69e2617c00 100644
--- a/src/awkward/operations/ak_nan_to_num.py
+++ b/src/awkward/operations/ak_nan_to_num.py
@@ -2,10 +2,14 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
+from awkward._typing import Mapping
__all__ = ("nan_to_num",)
@@ -15,14 +19,14 @@
@high_level_function()
def nan_to_num(
array,
- copy=True,
+ copy: bool = True,
nan=0.0,
posinf=None,
neginf=None,
*,
- highlevel=True,
- behavior=None,
- attrs=None,
+ highlevel: bool = True,
+ behavior: Mapping | None = None,
+ attrs: Mapping | None = None,
):
"""
Args:
@@ -52,7 +56,16 @@ def nan_to_num(
return _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs)
-def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs):
+def _impl(
+ array,
+ copy: bool,
+ nan,
+ posinf,
+ neginf,
+ highlevel: bool,
+ behavior: Mapping | None,
+ attrs: Mapping | None,
+):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout, nan_layout, posinf_layout, neginf_layout = ensure_same_backend(
ctx.unwrap(array),
@@ -81,15 +94,19 @@ def _impl(array, copy, nan, posinf, neginf, highlevel, behavior, attrs):
broadcasting_ids = {}
broadcasting = [layout]
+ arrays_to_broadcast = [array]
if isinstance(nan_layout, ak.contents.Content):
broadcasting_ids[id(nan)] = len(broadcasting)
broadcasting.append(nan_layout)
+ arrays_to_broadcast.append(nan)
if isinstance(posinf_layout, ak.contents.Content):
broadcasting_ids[id(posinf)] = len(broadcasting)
broadcasting.append(posinf_layout)
+ arrays_to_broadcast.append(posinf)
if isinstance(neginf_layout, ak.contents.Content):
broadcasting_ids[id(neginf)] = len(broadcasting)
broadcasting.append(neginf_layout)
+ arrays_to_broadcast.append(neginf)
if len(broadcasting) == 1:
@@ -138,9 +155,29 @@ def action(inputs, backend, **kwargs):
else:
return None
- out = ak._broadcasting.broadcast_and_apply(broadcasting, action)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ arrays_to_broadcast
+ )
+ out = ak._broadcasting.broadcast_and_apply(
+ broadcasting,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ )
assert isinstance(out, tuple) and len(out) == 1
- out = out[0]
+
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py
index ad9b4e746c..705a1e1c63 100644
--- a/src/awkward/operations/ak_num.py
+++ b/src/awkward/operations/ak_num.py
@@ -5,8 +5,14 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import is_integer, regularize_axis
+from awkward._regularize import regularize_axis
+from awkward._typing import Mapping
from awkward.errors import AxisError
__all__ = ("num",)
@@ -15,7 +21,14 @@
@high_level_function()
-def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
+def num(
+ array,
+ axis=1,
+ *,
+ highlevel: bool = True,
+ behavior: Mapping | None = None,
+ attrs: Mapping | None = None,
+):
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
@@ -83,13 +96,25 @@ def num(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
return _impl(array, axis, highlevel, behavior, attrs)
-def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
+def _impl(
+ array,
+ axis,
+ highlevel: bool,
+ behavior: Mapping | None,
+ attrs: Mapping | None,
+):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
- if not is_integer(axis):
- raise TypeError(f"'axis' must be an integer, not {axis!r}")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # use strategy "keep one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
if maybe_posaxis(layout, axis, 1) == 0:
index_nplike = layout.backend.index_nplike
@@ -109,4 +134,16 @@ def action(layout, depth, **kwargs):
out = ak._do.recursively_apply(layout, action, numpy_to_regular=True)
- return ctx.wrap(out, highlevel=highlevel)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_pad_none.py b/src/awkward/operations/ak_pad_none.py
index 34355a8546..17bb3035ac 100644
--- a/src/awkward/operations/ak_pad_none.py
+++ b/src/awkward/operations/ak_pad_none.py
@@ -5,6 +5,10 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -113,9 +117,15 @@ def pad_none(
def _impl(array, target, axis, clip, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
out = ak._do.pad_none(layout, target, axis, clip=clip)
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_prod.py b/src/awkward/operations/ak_prod.py
index cde898f174..d3d1a050c3 100644
--- a/src/awkward/operations/ak_prod.py
+++ b/src/awkward/operations/ak_prod.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -119,9 +125,26 @@ def nanprod(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Prod()
out = ak._do.reduce(
@@ -132,7 +155,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("prod")
diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py
index 56daaa6980..6d4beafbd5 100644
--- a/src/awkward/operations/ak_ptp.py
+++ b/src/awkward/operations/ak_ptp.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
@@ -10,6 +11,10 @@
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -83,10 +88,16 @@ def ptp(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
+
with np.errstate(invalid="ignore", divide="ignore"):
maxi = ak.operations.ak_max._impl(
layout,
@@ -126,8 +137,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
+ highlevel=highlevel,
+ allow_other=True,
+ )
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
)
diff --git a/src/awkward/operations/ak_ravel.py b/src/awkward/operations/ak_ravel.py
index 66a3e3a55d..062601eff4 100644
--- a/src/awkward/operations/ak_ravel.py
+++ b/src/awkward/operations/ak_ravel.py
@@ -75,7 +75,16 @@ def _impl(array, highlevel, behavior, attrs):
result = ak._do.mergemany(out)
- return ctx.wrap(result, highlevel=highlevel)
+ wrapped_out = ctx.wrap(result, highlevel=highlevel)
+
+ # propagate named axis to output
+ # use strategy "remove all" (see: awkward._namedaxis)
+ return ak.operations.ak_without_named_axis._impl(
+ wrapped_out,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("ravel")
diff --git a/src/awkward/operations/ak_real.py b/src/awkward/operations/ak_real.py
index 6d52971dab..655e4e8007 100644
--- a/src/awkward/operations/ak_real.py
+++ b/src/awkward/operations/ak_real.py
@@ -14,10 +14,10 @@
@ak._connect.numpy.implements("real")
@high_level_function()
-def real(val, highlevel=True, behavior=None, attrs=None):
+def real(array, highlevel=True, behavior=None, attrs=None):
"""
Args:
- val : array_like
+ array : array_like
Input array.
highlevel (bool, default is True): If True, return an #ak.Array;
otherwise, return a low-level #ak.contents.Content subclass.
@@ -30,15 +30,15 @@ def real(val, highlevel=True, behavior=None, attrs=None):
If the arrays have complex elements, the returned arrays are floats.
"""
# Dispatch
- yield (val,)
+ yield (array,)
# Implementation
- return _impl_real(val, highlevel, behavior, attrs)
+ return _impl(array, highlevel, behavior, attrs)
-def _impl_real(val, highlevel, behavior, attrs):
+def _impl(array, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
- layout = ctx.unwrap(val, allow_record=False, primitive_policy="error")
+ layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
out = ak._do.recursively_apply(layout, _action_real)
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_singletons.py b/src/awkward/operations/ak_singletons.py
index 35f60d5c97..4de6a59151 100644
--- a/src/awkward/operations/ak_singletons.py
+++ b/src/awkward/operations/ak_singletons.py
@@ -5,8 +5,13 @@
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, maybe_posaxis
+from awkward._namedaxis import (
+ _add_named_axis,
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
-from awkward._regularize import is_integer, regularize_axis
+from awkward._regularize import regularize_axis
from awkward.errors import AxisError
__all__ = ("singletons",)
@@ -56,12 +61,21 @@ def singletons(array, axis=0, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
- if not is_integer(axis):
- raise TypeError(f"'axis' must be an integer, not {axis!r}")
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
+ # Step 2: propagate named axis from input to output,
+ # use strategy "add one" (see: awkward._namedaxis)
+ out_named_axis = _add_named_axis(
+ named_axis, (axis + 1) if axis >= 0 else axis, layout.minmax_depth[1]
+ )
def action(layout, depth, backend, **kwargs):
posaxis = maybe_posaxis(layout, axis, depth)
@@ -90,4 +104,16 @@ def action(layout, depth, backend, **kwargs):
out = ak._do.recursively_apply(layout, action, numpy_to_regular=True)
- return ctx.wrap(out, highlevel=highlevel)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_softmax.py b/src/awkward/operations/ak_softmax.py
index e86cbe9cf0..b2cb11bff0 100644
--- a/src/awkward/operations/ak_softmax.py
+++ b/src/awkward/operations/ak_softmax.py
@@ -3,12 +3,17 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import (
HighLevelContext,
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -75,10 +80,16 @@ def softmax(
def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs):
original_axis = axis
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ axis = regularize_axis(axis, none_allowed=True)
+
x = ctx.wrap(x_layout)
if maybe_posaxis(x_layout, axis, 1) != maybe_posaxis(x_layout, -1, 1):
@@ -97,8 +108,19 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs):
behavior=ctx.behavior,
attrs=ctx.attrs,
)
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(expx / denom),
+
+ out = expx / denom
+
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
highlevel=highlevel,
allow_other=True,
)
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
+ )
diff --git a/src/awkward/operations/ak_sort.py b/src/awkward/operations/ak_sort.py
index 5e82e91604..0864fc5d98 100644
--- a/src/awkward/operations/ak_sort.py
+++ b/src/awkward/operations/ak_sort.py
@@ -6,6 +6,10 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -59,11 +63,22 @@ def sort(
def _impl(array, axis, ascending, stable, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
out = ak._do.sort(layout, axis, ascending, stable)
- return ctx.wrap(out, highlevel=highlevel)
+
+ return ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
@ak._connect.numpy.implements("sort")
diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py
index 0385032440..7926b341fe 100644
--- a/src/awkward/operations/ak_std.py
+++ b/src/awkward/operations/ak_std.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
@@ -11,6 +12,10 @@
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes import ufuncs
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -165,8 +170,6 @@ def nanstd(
def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
@@ -182,6 +185,13 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
x = ctx.wrap(x_layout)
weight = ctx.wrap(weight_layout, allow_other=True)
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=True)
+
with np.errstate(invalid="ignore", divide="ignore"):
out = ufuncs.sqrt(
ak.operations.ak_var._impl(
@@ -215,8 +225,18 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
+ highlevel=highlevel,
+ allow_other=True,
+ )
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
)
diff --git a/src/awkward/operations/ak_strings_astype.py b/src/awkward/operations/ak_strings_astype.py
index b0834db3a6..479232cf01 100644
--- a/src/awkward/operations/ak_strings_astype.py
+++ b/src/awkward/operations/ak_strings_astype.py
@@ -82,5 +82,6 @@ def action(layout, **kwargs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
out = ak._do.recursively_apply(layout, action)
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_sum.py b/src/awkward/operations/ak_sum.py
index f00434083e..ae6a40aef8 100644
--- a/src/awkward/operations/ak_sum.py
+++ b/src/awkward/operations/ak_sum.py
@@ -6,6 +6,12 @@
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ _get_named_axis,
+ _keep_named_axis,
+ _named_axis_to_positional_axis,
+ _remove_named_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -269,9 +275,26 @@ def nansum(
def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ # Step 2: propagate named axis from input to output,
+ # keepdims=True: use strategy "keep all" (see: awkward._namedaxis)
+ # keepdims=False: use strategy "remove one" (see: awkward._namedaxis)
+ out_named_axis = _keep_named_axis(named_axis, None)
+ if not keepdims:
+ out_named_axis = _remove_named_axis(
+ named_axis=out_named_axis,
+ axis=axis,
+ total=layout.minmax_depth[1],
+ )
+
+ axis = regularize_axis(axis, none_allowed=True)
+
reducer = ak._reducers.Sum()
out = ak._do.reduce(
@@ -282,7 +305,21 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
keepdims=keepdims,
behavior=ctx.behavior,
)
- return ctx.wrap(out, highlevel=highlevel, allow_other=True)
+
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
@ak._connect.numpy.implements("sum")
diff --git a/src/awkward/operations/ak_to_backend.py b/src/awkward/operations/ak_to_backend.py
index f65a2c0a81..8d93e2de94 100644
--- a/src/awkward/operations/ak_to_backend.py
+++ b/src/awkward/operations/ak_to_backend.py
@@ -17,7 +17,7 @@ def to_backend(array, backend, *, highlevel=True, behavior=None, attrs=None):
"""
Args:
array: Array-like data (anything #ak.to_layout recognizes).
- backend (`"cpu"`, `"cuda"`, or `"jax"`): If `"cpu"`, the array structure is
+ backend (`"cpu"`, `"cuda"`, `"jax"`, or `"typetracer"`): If `"cpu"`, the array structure is
recursively copied (if need be) to main memory for use with
the default Numpy backend; if `"cuda"`, the structure is copied
to the GPU(s) for use with CuPy. If `"jax"`, the structure is
diff --git a/src/awkward/operations/ak_to_regular.py b/src/awkward/operations/ak_to_regular.py
index b72e48d7c5..ae9f9cc3da 100644
--- a/src/awkward/operations/ak_to_regular.py
+++ b/src/awkward/operations/ak_to_regular.py
@@ -66,7 +66,7 @@ def to_regular(array, axis=1, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
+ axis = regularize_axis(axis, none_allowed=True)
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
diff --git a/src/awkward/operations/ak_transform.py b/src/awkward/operations/ak_transform.py
index 23b4dbfd4e..93a4911914 100644
--- a/src/awkward/operations/ak_transform.py
+++ b/src/awkward/operations/ak_transform.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import copy
+from functools import reduce
import awkward as ak
from awkward._backends.numpy import NumpyBackend
@@ -15,6 +16,7 @@
)
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
__all__ = ("transform",)
@@ -580,6 +582,17 @@ def action(inputs, **kwargs):
f"transformation must return a Content, tuple of Contents, or None, not {type(out)}\n\n{out!r}"
)
+ if depth_context is None:
+ depth_context = {}
+ if lateral_context is None:
+ lateral_context = {}
+ assert NAMED_AXIS_KEY not in depth_context
+ assert NAMED_AXIS_KEY not in lateral_context
+ _depth_context, _lateral_context = NamedAxesWithDims.prepare_contexts(
+ [array, *more_arrays]
+ )
+ depth_context.update(_depth_context)
+ lateral_context.update(_lateral_context)
backend = next((layout.backend for layout in layouts), cpu)
isscalar = []
out = apply_broadcasting_step(
@@ -594,6 +607,11 @@ def action(inputs, **kwargs):
assert isinstance(out, tuple)
out = [broadcast_unpack(x, isscalar) for x in out]
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+
if return_value == "none":
return
elif expect_return_value and not transformer_did_terminate:
@@ -602,6 +620,25 @@ def action(inputs, **kwargs):
"or tuple of Contents, but instead only returned None."
)
elif len(out) == 1:
- return ctx.wrap(out[0], highlevel=highlevel)
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
else:
- return tuple(ctx.wrap(x, highlevel=highlevel) for x in out)
+ wrapped_out = []
+ for x in out:
+ wrapped = ctx.wrap(x, highlevel=highlevel)
+ wrapped_out.append(
+ ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
+ )
+ return tuple(wrapped_out)
diff --git a/src/awkward/operations/ak_unflatten.py b/src/awkward/operations/ak_unflatten.py
index 78c2631e31..83a3b8f2b4 100644
--- a/src/awkward/operations/ak_unflatten.py
+++ b/src/awkward/operations/ak_unflatten.py
@@ -6,6 +6,10 @@
from awkward._backends.numpy import NumpyBackend
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.typetracer import is_unknown_scalar
@@ -91,8 +95,6 @@ def unflatten(array, counts, axis=0, *, highlevel=True, behavior=None, attrs=Non
def _impl(array, counts, axis, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout, maybe_counts_layout = ensure_same_backend(
ctx.unwrap(array, allow_record=False, primitive_policy="error"),
@@ -105,6 +107,13 @@ def _impl(array, counts, axis, highlevel, behavior, attrs):
),
)
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+
+ axis = regularize_axis(axis, none_allowed=False)
+
if is_integer_like(maybe_counts_layout):
# Regularize unknown values to unknown lengths
if (
@@ -292,4 +301,16 @@ def apply(layout, depth, backend, **kwargs):
f"at axis={axis}"
)
- return ctx.wrap(out, highlevel=highlevel)
+ wrapped_out = ctx.wrap(
+ out,
+ highlevel=highlevel,
+ )
+
+ # Step 2: propagate named axis from input to output,
+ # use strategy "remove all" (see: awkward._namedaxis)
+ return ak.operations.ak_without_named_axis._impl(
+ wrapped_out,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_unzip.py b/src/awkward/operations/ak_unzip.py
index 8d0bfc229a..8c19380133 100644
--- a/src/awkward/operations/ak_unzip.py
+++ b/src/awkward/operations/ak_unzip.py
@@ -51,6 +51,7 @@ def unzip(array, *, highlevel=True, behavior=None, attrs=None):
def _impl(array, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=True, primitive_policy="error")
+
fields = ak.operations.fields(layout)
def check_for_union(layout, **kwargs):
@@ -70,5 +71,10 @@ def check_for_union(layout, **kwargs):
return (ctx.wrap(layout, highlevel=highlevel, allow_other=True),)
else:
return tuple(
- ctx.wrap(layout[n], highlevel=highlevel, allow_other=True) for n in fields
+ ctx.wrap(
+ layout[n],
+ highlevel=highlevel,
+ allow_other=True,
+ )
+ for n in fields
)
diff --git a/src/awkward/operations/ak_values_astype.py b/src/awkward/operations/ak_values_astype.py
index 714a4320d9..fa25ca5a35 100644
--- a/src/awkward/operations/ak_values_astype.py
+++ b/src/awkward/operations/ak_values_astype.py
@@ -72,6 +72,7 @@ def values_astype(
def _impl(array, to, including_unknown, highlevel, behavior, attrs):
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ctx.unwrap(array, allow_record=False, primitive_policy="error")
+
to_str = ak.types.numpytype.dtype_to_primitive(np.dtype(to))
out = ak._do.numbers_to_type(layout, to_str, including_unknown)
return ctx.wrap(out, highlevel=highlevel)
diff --git a/src/awkward/operations/ak_var.py b/src/awkward/operations/ak_var.py
index d1139d8b4c..759f5edf1c 100644
--- a/src/awkward/operations/ak_var.py
+++ b/src/awkward/operations/ak_var.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import (
@@ -11,6 +12,10 @@
maybe_highlevel_to_lowlevel,
maybe_posaxis,
)
+from awkward._namedaxis import (
+ _get_named_axis,
+ _named_axis_to_positional_axis,
+)
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import regularize_axis
@@ -170,8 +175,6 @@ def nanvar(
def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, attrs):
- axis = regularize_axis(axis)
-
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
x_layout, weight_layout = ensure_same_backend(
ctx.unwrap(x, allow_record=False, primitive_policy="error"),
@@ -187,6 +190,12 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
x = ctx.wrap(x_layout)
weight = ctx.wrap(weight_layout, allow_other=True)
+ # Handle named axis
+ named_axis = _get_named_axis(ctx)
+ # Step 1: Normalize named axis to positional axis
+ axis = _named_axis_to_positional_axis(named_axis, axis)
+ axis = regularize_axis(axis, none_allowed=True)
+
with np.errstate(invalid="ignore", divide="ignore"):
if weight is None:
sumw = ak.operations.ak_count._impl(
@@ -267,8 +276,19 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
posaxis = maybe_posaxis(out.layout, axis, 1)
out = out[(slice(None, None),) * posaxis + (0,)]
- return ctx.wrap(
- maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True
+ wrapped = ctx.wrap(
+ maybe_highlevel_to_lowlevel(out),
+ highlevel=highlevel,
+ allow_other=True,
+ )
+
+ # propagate named axis to output
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=_get_named_axis(attrs_of_obj(out)),
+ highlevel=highlevel,
+ behavior=None,
+ attrs=None,
)
diff --git a/src/awkward/operations/ak_where.py b/src/awkward/operations/ak_where.py
index dda7d99f42..07f5f2f7bd 100644
--- a/src/awkward/operations/ak_where.py
+++ b/src/awkward/operations/ak_where.py
@@ -2,9 +2,12 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("where",)
@@ -121,8 +124,26 @@ def action(inputs, backend, **kwargs):
else:
return None
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ [x, y, condition]
+ )
out = ak._broadcasting.broadcast_and_apply(
- layouts, action, numpy_to_regular=True, function_name="ak.where"
+ layouts,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ numpy_to_regular=True,
+ function_name="ak.where",
+ )
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
)
-
- return ctx.wrap(out[0], highlevel=highlevel)
diff --git a/src/awkward/operations/ak_with_field.py b/src/awkward/operations/ak_with_field.py
index 671a061978..3adb5c33a1 100644
--- a/src/awkward/operations/ak_with_field.py
+++ b/src/awkward/operations/ak_with_field.py
@@ -3,10 +3,12 @@
from __future__ import annotations
import copy
+from functools import reduce
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._regularize import is_non_string_like_sequence
@@ -76,6 +78,11 @@ def _impl(base, what, where, highlevel, behavior, attrs):
if is_non_string_like_sequence(where):
where = where[0]
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ [base, what],
+ unwrap_kwargs={"none_policy": "promote"},
+ )
+
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
base, what = ensure_same_backend(
ctx.unwrap(base, allow_record=True, primitive_policy="error"),
@@ -156,9 +163,24 @@ def action(inputs, **kwargs):
return None
out = ak._broadcasting.broadcast_and_apply(
- [base, what], action, right_broadcast=False
+ [base, what],
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ right_broadcast=False,
)
assert isinstance(out, tuple) and len(out) == 1
- return ctx.wrap(out[0], highlevel=highlevel)
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out[0], highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/ak_with_named_axis.py b/src/awkward/operations/ak_with_named_axis.py
new file mode 100644
index 0000000000..507acc485c
--- /dev/null
+++ b/src/awkward/operations/ak_with_named_axis.py
@@ -0,0 +1,72 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+
+from __future__ import annotations
+
+from awkward._dispatch import high_level_function
+from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+ AxisMapping,
+ AxisTuple,
+ _prepare_named_axis_for_attrs,
+)
+from awkward._nplikes.numpy_like import NumpyMetadata
+
+__all__ = ("with_named_axis",)
+
+np = NumpyMetadata.instance()
+
+
+@high_level_function()
+def with_named_axis(
+ array,
+ named_axis: AxisTuple | AxisMapping,
+ *,
+ highlevel=True,
+ behavior=None,
+ attrs=None,
+):
+ """
+ Args:
+ array: Array-like data (anything #ak.to_layout recognizes).
+ named_axis: AxisTuple | AxisMapping: Names to give to the array axis; this assigns
+ the `"__named_axis__"` attr. If None, any existing name is unset.
+ highlevel (bool): If True, return an #ak.Array; otherwise, return
+ a low-level #ak.contents.Content subclass.
+ behavior (None or dict): Custom #ak.behavior for the output array, if
+ high-level.
+ attrs (None or dict): Custom attributes for the output array, if
+ high-level.
+
+ Returns an #ak.Array or #ak.Record (or low-level equivalent, if
+ `highlevel=False`) with a new name. This function does not change the
+ array in-place. If the new name is None, then the array is returned as it is.
+ """
+ # Dispatch
+ yield (array,)
+
+ # Implementation
+ return _impl(array, named_axis, highlevel, behavior, attrs)
+
+
+def _impl(array, named_axis, highlevel, behavior, attrs):
+ # Named axis handling
+ if not named_axis: # no-op, e.g. named_axis is None, (), {}
+ return array
+
+ with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
+ layout = ctx.unwrap(array, allow_record=True)
+
+ _named_axis = _prepare_named_axis_for_attrs(
+ named_axis=named_axis,
+ ndim=layout.minmax_depth[1],
+ )
+ # now we're good, set the named axis
+ return ctx.with_attr(
+ key=NAMED_AXIS_KEY,
+ value=_named_axis,
+ ).wrap(
+ layout,
+ highlevel=highlevel,
+ allow_other=True,
+ )
diff --git a/src/awkward/operations/ak_without_named_axis.py b/src/awkward/operations/ak_without_named_axis.py
new file mode 100644
index 0000000000..3697344a4b
--- /dev/null
+++ b/src/awkward/operations/ak_without_named_axis.py
@@ -0,0 +1,54 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+
+from __future__ import annotations
+
+from awkward._dispatch import high_level_function
+from awkward._layout import HighLevelContext
+from awkward._namedaxis import (
+ NAMED_AXIS_KEY,
+)
+from awkward._nplikes.numpy_like import NumpyMetadata
+
+__all__ = ("without_named_axis",)
+
+np = NumpyMetadata.instance()
+
+
+@high_level_function()
+def without_named_axis(
+ array,
+ *,
+ highlevel=True,
+ behavior=None,
+ attrs=None,
+):
+ """
+ Args:
+ array: Array-like data (anything #ak.to_layout recognizes).
+ highlevel (bool): If True, return an #ak.Array; otherwise, return
+ a low-level #ak.contents.Content subclass.
+ behavior (None or dict): Custom #ak.behavior for the output array, if
+ high-level.
+ attrs (None or dict): Custom attributes for the output array, if
+ high-level.
+
+ Returns an #ak.Array or #ak.Record (or low-level equivalent, if
+ `highlevel=False`) without named axes. This function does not change the
+ array in-place.
+ """
+ # Dispatch
+ yield (array,)
+
+ # Implementation
+ return _impl(array, highlevel, behavior, attrs)
+
+
+def _impl(array, highlevel, behavior, attrs):
+ with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
+ layout = ctx.unwrap(array, allow_record=True)
+
+ return ctx.without_attr(key=NAMED_AXIS_KEY).wrap(
+ layout,
+ highlevel=highlevel,
+ allow_other=True,
+ )
diff --git a/src/awkward/operations/ak_zip.py b/src/awkward/operations/ak_zip.py
index bed5c233e5..5ce58f8b1a 100644
--- a/src/awkward/operations/ak_zip.py
+++ b/src/awkward/operations/ak_zip.py
@@ -3,10 +3,12 @@
from __future__ import annotations
from collections.abc import Mapping
+from functools import reduce
import awkward as ak
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("zip",)
@@ -174,6 +176,7 @@ def _impl(
):
if depth_limit is not None and depth_limit <= 0:
raise ValueError("depth_limit must be None or at least 1")
+
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
if isinstance(arrays, Mapping):
layouts = ensure_same_backend(
@@ -238,8 +241,15 @@ def action(inputs, depth, backend, **ignore):
else:
return None
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ list(arrays.values()) if isinstance(arrays, Mapping) else list(arrays)
+ )
out = ak._broadcasting.broadcast_and_apply(
- layouts, action, right_broadcast=right_broadcast
+ layouts,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ right_broadcast=right_broadcast,
)
assert isinstance(out, tuple) and len(out) == 1
out = out[0]
@@ -248,4 +258,15 @@ def action(inputs, depth, backend, **ignore):
out = out[0]
assert isinstance(out, ak.record.Record)
- return ctx.wrap(out, highlevel=highlevel)
+ # Unify named axes propagated through the broadcast
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped_out = ctx.wrap(out, highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped_out,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=ctx.behavior,
+ attrs=ctx.attrs,
+ )
diff --git a/src/awkward/operations/str/akstr_join.py b/src/awkward/operations/str/akstr_join.py
index a5ab638ba5..d18dc0174e 100644
--- a/src/awkward/operations/str/akstr_join.py
+++ b/src/awkward/operations/str/akstr_join.py
@@ -2,10 +2,15 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._backends.typetracer import TypeTracerBackend
+from awkward._behavior import behavior_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
__all__ = ("join",)
@@ -95,6 +100,7 @@ def apply_unary(layout, **kwargs):
)
out = ak._do.recursively_apply(layout, apply_unary)
+ return ctx.wrap(out, highlevel=highlevel)
else:
def apply_binary(layouts, **kwargs):
@@ -123,8 +129,24 @@ def apply_binary(layouts, **kwargs):
),
)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ [array, separator]
+ )
(out,) = ak._broadcasting.broadcast_and_apply(
- (layout, maybe_separator_layout), apply_binary
+ (layout, maybe_separator_layout),
+ apply_binary,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
)
- return ctx.wrap(out, highlevel=highlevel)
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped = ctx.wrap(out, highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=behavior_of_obj(wrapped),
+ attrs=attrs_of_obj(wrapped),
+ )
diff --git a/src/awkward/operations/str/akstr_join_element_wise.py b/src/awkward/operations/str/akstr_join_element_wise.py
index 98f4e42f91..cd2ed0184f 100644
--- a/src/awkward/operations/str/akstr_join_element_wise.py
+++ b/src/awkward/operations/str/akstr_join_element_wise.py
@@ -2,10 +2,15 @@
from __future__ import annotations
+from functools import reduce
+
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._backends.typetracer import TypeTracerBackend
+from awkward._behavior import behavior_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
__all__ = ("join_element_wise",)
@@ -66,6 +71,22 @@ def action(layouts, **kwargs):
):
return (_apply_through_arrow(pc.binary_join_element_wise, *layouts),)
- (out,) = ak._broadcasting.broadcast_and_apply(layouts, action)
-
- return ctx.wrap(out, highlevel=highlevel)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(arrays)
+ (out,) = ak._broadcasting.broadcast_and_apply(
+ layouts,
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ )
+
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped = ctx.wrap(out, highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=behavior_of_obj(wrapped),
+ attrs=attrs_of_obj(wrapped),
+ )
diff --git a/src/awkward/operations/str/akstr_repeat.py b/src/awkward/operations/str/akstr_repeat.py
index de929c57b7..49bec96569 100644
--- a/src/awkward/operations/str/akstr_repeat.py
+++ b/src/awkward/operations/str/akstr_repeat.py
@@ -3,11 +3,15 @@
from __future__ import annotations
import numbers
+from functools import reduce
import awkward as ak
+from awkward._attrs import attrs_of_obj
from awkward._backends.typetracer import TypeTracerBackend
+from awkward._behavior import behavior_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import HighLevelContext, ensure_same_backend
+from awkward._namedaxis import NAMED_AXIS_KEY, NamedAxesWithDims, _unify_named_axis
from awkward._nplikes.numpy_like import NumpyMetadata
__all__ = ("repeat",)
@@ -79,8 +83,26 @@ def action(inputs, **kwargs):
return (_apply_through_arrow(pc.binary_repeat, *inputs),)
+ depth_context, lateral_context = NamedAxesWithDims.prepare_contexts(
+ [array, num_repeats]
+ )
(out,) = ak._broadcasting.broadcast_and_apply(
- (layout, num_repeats_layout), action
+ (layout, num_repeats_layout),
+ action,
+ depth_context=depth_context,
+ lateral_context=lateral_context,
+ )
+
+ out_named_axis = reduce(
+ _unify_named_axis, lateral_context[NAMED_AXIS_KEY].named_axis
+ )
+ wrapped = ctx.wrap(out, highlevel=highlevel)
+ return ak.operations.ak_with_named_axis._impl(
+ wrapped,
+ named_axis=out_named_axis,
+ highlevel=highlevel,
+ behavior=behavior_of_obj(wrapped),
+ attrs=attrs_of_obj(wrapped),
)
else:
@@ -98,4 +120,4 @@ def action(layout, **kwargs):
out = ak._do.recursively_apply(layout, action)
- return ctx.wrap(out, highlevel=highlevel)
+ return ctx.wrap(out, highlevel=highlevel)
diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py
new file mode 100644
index 0000000000..acfa1e9e34
--- /dev/null
+++ b/tests/test_2596_named_axis.py
@@ -0,0 +1,2243 @@
+# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE
+
+from __future__ import annotations
+
+import sys
+
+import numpy as np
+import pytest
+
+import awkward as ak
+from awkward._namedaxis import _get_named_axis
+
+
+def test_constructor():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis=("x", "y"))
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 0, "y": 1}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]], named_axis={"x": 0, "y": 1})
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 0, "y": 1}
+ assert array.positional_axis == (0, 1)
+
+
+def test_with_named_axis():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ assert not _get_named_axis(array)
+ assert array.named_axis == {}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis=("x", "y"))
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 0, "y": 1}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis=("x", None))
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 0}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis=(None, "x"))
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 1}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1})
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 0, "y": 1}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis={"x": 1})
+ assert _get_named_axis(array)
+ assert array.named_axis == {"x": 1}
+ assert array.positional_axis == (0, 1)
+
+ array = ak.with_named_axis(array, named_axis={"y": -1})
+ assert _get_named_axis(array)
+ assert array.named_axis == {"y": -1}
+ assert array.positional_axis == (0, 1)
+
+ # This is possible in a future version of named axis, but currently only strings are supported
+ # from dataclasses import dataclass
+
+ # @dataclass(frozen=True)
+ # class exotic_axis:
+ # attr: str
+
+ # ax1 = exotic_axis(attr="I'm not the type of axis that you're used to")
+ # ax2 = exotic_axis(attr="...me neither!")
+
+ # array = ak.with_named_axis(array, named_axis=(ax1, ax2))
+ # assert array.named_axis == (ax1, ax2)
+ # assert array.positional_axis == (0, 1)
+
+
+def test_named_axis_indexing():
+ array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y", "z"))
+
+ # test indexing
+ assert ak.all(array[...] == named_array[...])
+ assert ak.all(array[()] == named_array[()])
+
+ assert ak.all(array[None, :, :, :] == named_array[None, :, :, :])
+ assert ak.all(array[:, None, :, :] == named_array[:, None, :, :])
+ assert ak.all(array[:, :, None, :] == named_array[:, :, None, :])
+ assert ak.all(array[:, :, :, None] == named_array[:, :, :, None])
+
+ assert ak.all(array[0, :, :] == named_array[{"x": 0}])
+ assert ak.all(array[:, 0, :] == named_array[{"y": 0}])
+ assert ak.all(array[:, :, 0] == named_array[{"z": 0}])
+
+ assert ak.all(array[0, :, :] == named_array[{0: 0}])
+ assert ak.all(array[:, 0, :] == named_array[{1: 0}])
+ assert ak.all(array[:, :, 0] == named_array[{2: 0}])
+
+ assert ak.all(array[0, :, :] == named_array[{-3: 0}])
+ assert ak.all(array[:, 0, :] == named_array[{-2: 0}])
+ assert ak.all(array[:, :, 0] == named_array[{-1: 0}])
+
+ assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}])
+ assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}])
+ assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}])
+ assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}]
+
+ assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}])
+ assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}])
+ assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}])
+
+ assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}])
+ assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}])
+ assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}])
+
+ assert ak.all(array[array > 3] == named_array[named_array > 3])
+
+ # test naming propagation
+ assert (
+ named_array[...].named_axis
+ == named_array.named_axis
+ == {"x": 0, "y": 1, "z": 2}
+ )
+ assert (
+ named_array[()].named_axis == named_array.named_axis == {"x": 0, "y": 1, "z": 2}
+ )
+
+ assert named_array[None, :, :, :].named_axis == {"x": 1, "y": 2, "z": 3}
+ assert named_array[:, None, :, :].named_axis == {"x": 0, "y": 2, "z": 3}
+ assert named_array[:, :, None, :].named_axis == {"x": 0, "y": 1, "z": 3}
+ assert named_array[:, :, :, None].named_axis == {"x": 0, "y": 1, "z": 2}
+
+ assert named_array[None, ...].named_axis == {"x": 1, "y": 2, "z": 3}
+ assert named_array[:, None, ...].named_axis == {"x": 0, "y": 2, "z": 3}
+ assert named_array[..., None, :].named_axis == {"x": 0, "y": 1, "z": 3}
+ assert named_array[..., None].named_axis == {"x": 0, "y": 1, "z": 2}
+
+ assert (
+ named_array[0, :, :].named_axis
+ == named_array[{"x": 0}].named_axis
+ == {"y": 0, "z": 1}
+ )
+ assert (
+ named_array[:, 0, :].named_axis
+ == named_array[{"y": 0}].named_axis
+ == {"x": 0, "z": 1}
+ )
+ assert (
+ named_array[:, :, 0].named_axis
+ == named_array[{"z": 0}].named_axis
+ == {"x": 0, "y": 1}
+ )
+
+ assert (
+ named_array[0, ...].named_axis
+ == named_array[{"x": 0}].named_axis
+ == {"y": 0, "z": 1}
+ )
+ assert (
+ named_array[:, 0, :].named_axis
+ == named_array[{"y": 0}].named_axis
+ == {"x": 0, "z": 1}
+ )
+ assert (
+ named_array[..., 0].named_axis
+ == named_array[{"z": 0}].named_axis
+ == {"x": 0, "y": 1}
+ )
+
+ assert named_array[{0: 0}].named_axis == {"y": 0, "z": 1}
+ assert named_array[{1: 0}].named_axis == {"x": 0, "z": 1}
+ assert named_array[{2: 0}].named_axis == {"x": 0, "y": 1}
+
+ assert named_array[{-3: 0}].named_axis == {"y": 0, "z": 1}
+ assert named_array[{-2: 0}].named_axis == {"x": 0, "z": 1}
+ assert named_array[{-1: 0}].named_axis == {"x": 0, "y": 1}
+
+ assert (
+ named_array[0, 0, :].named_axis
+ == named_array[{"x": 0, "y": 0}].named_axis
+ == {"z": 0}
+ )
+ assert (
+ named_array[0, :, 0].named_axis
+ == named_array[{"x": 0, "z": 0}].named_axis
+ == {"y": 0}
+ )
+ assert (
+ named_array[:, 0, 0].named_axis
+ == named_array[{"y": 0, "z": 0}].named_axis
+ == {"x": 0}
+ )
+ assert not _get_named_axis(named_array[0, 0, 0])
+ assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}])
+
+ assert (
+ named_array[slice(0, 1), :, :].named_axis
+ == named_array[{"x": slice(0, 1)}].named_axis
+ == {"x": 0, "y": 1, "z": 2}
+ )
+ assert (
+ named_array[:, slice(0, 1), :].named_axis
+ == named_array[{"y": slice(0, 1)}].named_axis
+ == {"x": 0, "y": 1, "z": 2}
+ )
+ assert (
+ named_array[:, :, slice(0, 1)].named_axis
+ == named_array[{"z": slice(0, 1)}].named_axis
+ == {"x": 0, "y": 1, "z": 2}
+ )
+
+ assert (
+ named_array[0, :, slice(0, 1)].named_axis
+ == named_array[{"x": 0, "z": slice(0, 1)}].named_axis
+ == {"y": 0, "z": 1}
+ )
+ assert (
+ named_array[:, 0, slice(0, 1)].named_axis
+ == named_array[{"y": 0, "z": slice(0, 1)}].named_axis
+ == {"x": 0, "z": 1}
+ )
+ assert (
+ named_array[slice(0, 1), 0, :].named_axis
+ == named_array[{"x": slice(0, 1), "y": 0}].named_axis
+ == {"x": 0, "z": 1}
+ )
+
+
+def test_negative_named_axis_indexing():
+ array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1})
+
+ # test indexing
+ assert ak.all(array[...] == named_array[...])
+ assert ak.all(array[()] == named_array[()])
+
+ assert ak.all(array[None, :, :, :] == named_array[None, :, :, :])
+ assert ak.all(array[:, None, :, :] == named_array[:, None, :, :])
+ assert ak.all(array[:, :, None, :] == named_array[:, :, None, :])
+ assert ak.all(array[:, :, :, None] == named_array[:, :, :, None])
+
+ assert ak.all(array[0, :, :] == named_array[{"x": 0}])
+ assert ak.all(array[:, 0, :] == named_array[{"y": 0}])
+ assert ak.all(array[:, :, 0] == named_array[{"z": 0}])
+
+ assert ak.all(array[0, :, :] == named_array[{0: 0}])
+ assert ak.all(array[:, 0, :] == named_array[{1: 0}])
+ assert ak.all(array[:, :, 0] == named_array[{2: 0}])
+
+ assert ak.all(array[0, :, :] == named_array[{-3: 0}])
+ assert ak.all(array[:, 0, :] == named_array[{-2: 0}])
+ assert ak.all(array[:, :, 0] == named_array[{-1: 0}])
+
+ assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}])
+ assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}])
+ assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}])
+ assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}]
+
+ assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}])
+ assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}])
+ assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}])
+
+ assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}])
+ assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}])
+ assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}])
+
+ assert ak.all(array[array > 3] == named_array[named_array > 3])
+
+ # test naming propagation
+ assert (
+ named_array[...].named_axis
+ == named_array.named_axis
+ == {"x": -3, "y": -2, "z": -1}
+ )
+ assert (
+ named_array[()].named_axis
+ == named_array.named_axis
+ == {"x": -3, "y": -2, "z": -1}
+ )
+
+ assert named_array[None, :, :, :].named_axis == {"x": -3, "y": -2, "z": -1}
+ assert named_array[:, None, :, :].named_axis == {"x": -4, "y": -2, "z": -1}
+ assert named_array[:, :, None, :].named_axis == {"x": -4, "y": -3, "z": -1}
+ assert named_array[:, :, :, None].named_axis == {"x": -4, "y": -3, "z": -2}
+
+ assert named_array[None, ...].named_axis == {"x": -3, "y": -2, "z": -1}
+ assert named_array[:, None, ...].named_axis == {"x": -4, "y": -2, "z": -1}
+ assert named_array[..., None, :].named_axis == {"x": -4, "y": -3, "z": -1}
+ assert named_array[..., None].named_axis == {"x": -4, "y": -3, "z": -2}
+
+ assert (
+ named_array[0, :, :].named_axis
+ == named_array[{"x": 0}].named_axis
+ == {"y": -2, "z": -1}
+ )
+ assert (
+ named_array[:, 0, :].named_axis
+ == named_array[{"y": 0}].named_axis
+ == {"x": -2, "z": -1}
+ )
+ assert (
+ named_array[:, :, 0].named_axis
+ == named_array[{"z": 0}].named_axis
+ == {"x": -2, "y": -1}
+ )
+
+ assert (
+ named_array[0, ...].named_axis
+ == named_array[{"x": 0}].named_axis
+ == {"y": -2, "z": -1}
+ )
+ assert (
+ named_array[..., 0].named_axis
+ == named_array[{"z": 0}].named_axis
+ == {"x": -2, "y": -1}
+ )
+
+ assert named_array[{0: 0}].named_axis == {"y": -2, "z": -1}
+ assert named_array[{1: 0}].named_axis == {"x": -2, "z": -1}
+ assert named_array[{2: 0}].named_axis == {"x": -2, "y": -1}
+
+ assert named_array[{-3: 0}].named_axis == {"y": -2, "z": -1}
+ assert named_array[{-2: 0}].named_axis == {"x": -2, "z": -1}
+ assert named_array[{-1: 0}].named_axis == {"x": -2, "y": -1}
+
+ assert (
+ named_array[0, 0, :].named_axis
+ == named_array[{"x": 0, "y": 0}].named_axis
+ == {"z": -1}
+ )
+ assert (
+ named_array[0, :, 0].named_axis
+ == named_array[{"x": 0, "z": 0}].named_axis
+ == {"y": -1}
+ )
+ assert (
+ named_array[:, 0, 0].named_axis
+ == named_array[{"y": 0, "z": 0}].named_axis
+ == {"x": -1}
+ )
+ assert not _get_named_axis(named_array[0, 0, 0])
+ assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}])
+
+ assert (
+ named_array[slice(0, 1), :, :].named_axis
+ == named_array[{"x": slice(0, 1)}].named_axis
+ == {"x": -3, "y": -2, "z": -1}
+ )
+ assert (
+ named_array[:, slice(0, 1), :].named_axis
+ == named_array[{"y": slice(0, 1)}].named_axis
+ == {"x": -3, "y": -2, "z": -1}
+ )
+ assert (
+ named_array[:, :, slice(0, 1)].named_axis
+ == named_array[{"z": slice(0, 1)}].named_axis
+ == {"x": -3, "y": -2, "z": -1}
+ )
+
+ assert (
+ named_array[0, :, slice(0, 1)].named_axis
+ == named_array[{"x": 0, "z": slice(0, 1)}].named_axis
+ == {"y": -2, "z": -1}
+ )
+ assert (
+ named_array[:, 0, slice(0, 1)].named_axis
+ == named_array[{"y": 0, "z": slice(0, 1)}].named_axis
+ == {"x": -2, "z": -1}
+ )
+ assert (
+ named_array[slice(0, 1), 0, :].named_axis
+ == named_array[{"x": slice(0, 1), "y": 0}].named_axis
+ == {"x": -2, "z": -1}
+ )
+
+
+@pytest.mark.xfail(
+ sys.platform == "win32",
+ reason="right-broadcasting (NumPy-style) behaves differently for 32-bit windows",
+ strict=False,
+)
+def test_named_axis_right_broadcasting():
+ # [NumPy-style] rightbroadcasting: (n, m) -> (1, n, m)
+ a = ak.Array([1]) # (1,)
+ b = ak.Array([[10, 20], [30, 40], [50, 60]]) # (3, 2)
+
+ na = ak.with_named_axis(a, named_axis={"y": 0})
+ nb = ak.with_named_axis(b, named_axis={"x": 0, "y": 1})
+
+ naa, nbb = ak.broadcast_arrays(na, nb)
+
+ assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1}
+
+ na = ak.with_named_axis(a, named_axis={"y": -1})
+ nb = ak.with_named_axis(b, named_axis={"y": -2, "x": -1})
+
+ naa, nbb = ak.broadcast_arrays(na, nb)
+
+ assert naa.named_axis == nbb.named_axis == {"y": -2, "x": -1}
+
+
+def test_named_axis_left_broadcasting():
+ # [Awkward-style] leftbroadcasting: (n, m) -> (n, m, 1)
+ a = ak.Array([[[0, 1, 2], [], [3, 4]], [], [[5], [6, 7, 8, 9]]]) # (3, var, var)
+ b = ak.Array([[10, 20, 30], [], [40, 50]]) # (3, var)
+
+ na = ak.with_named_axis(a, named_axis=("x", "y", "z"))
+ nb = ak.with_named_axis(b, named_axis=("x", "y"))
+
+ naa, nbb = ak.broadcast_arrays(na, nb)
+
+ assert naa.named_axis == nbb.named_axis == {"x": 0, "y": 1, "z": 2}
+
+ na = ak.with_named_axis(a, named_axis={"x": -3, "y": -2, "z": -1})
+ nb = ak.with_named_axis(b, named_axis={"x": -2, "y": -1})
+
+ naa, nbb = ak.broadcast_arrays(na, nb)
+
+ assert naa.named_axis == nbb.named_axis == {"x": -3, "y": -2, "z": -1}
+
+ # this is not allowed!
+ a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1}
+ asum = ak.sum(a, axis="x") # {"y": 0}
+
+ with pytest.raises(ValueError):
+ _ = a + asum
+
+ # this is allowed!
+ a = ak.with_named_axis(ak.Array([[1, 2], [3, 4]]), ("x", "y")) # {"x": 0, "y": 1}
+ asum = ak.sum(a, axis="y") # {"x": 0}
+
+ assert (a + asum).named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_unary_ufuncs():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ assert (-named_array).named_axis == named_array.named_axis
+ assert (+named_array).named_axis == named_array.named_axis
+ assert (~named_array).named_axis == named_array.named_axis
+ assert abs(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_binary_ufuncs():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array1 = ak.with_named_axis(array, named_axis=(None, "y"))
+ named_array2 = ak.with_named_axis(array, named_axis=("x", None))
+ named_array3 = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # just for addition, the rest is the same
+ # __add__
+ assert (array + array).named_axis == {}
+ assert (named_array1 + array).named_axis == {"y": 1}
+ assert (named_array2 + array).named_axis == {"x": 0}
+ assert (named_array3 + array).named_axis == {"x": 0, "y": 1}
+
+ assert (named_array1 + named_array2).named_axis == {"x": 0, "y": 1}
+ assert (named_array3 + named_array3).named_axis == {"x": 0, "y": 1}
+
+ # __radd__
+ assert (array + named_array1).named_axis == {"y": 1}
+ assert (array + named_array2).named_axis == {"x": 0}
+ assert (array + named_array3).named_axis == {"x": 0, "y": 1}
+
+ a = ak.with_named_axis(array, named_axis=("x", None))
+ b = ak.with_named_axis(array, named_axis=("y", None))
+ with pytest.raises(
+ ValueError,
+ match="The named axes are incompatible. Got: x and y for positional axis 0",
+ ):
+ _ = a + b
+
+ a = ak.with_named_axis(array, named_axis=(None, "x"))
+ b = ak.with_named_axis(array, named_axis=(None, "y"))
+ with pytest.raises(
+ ValueError,
+ match="The named axes are incompatible. Got: x and y for positional axis 1",
+ ):
+ _ = a + b
+
+
+def test_named_axis_ak_all():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.all(array < 4, axis=0) == ak.all(named_array < 4, axis="x"))
+ assert ak.all(ak.all(array < 4, axis=1) == ak.all(named_array < 4, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.all(named_array < 4, axis=0).named_axis
+ == ak.all(named_array < 4, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.all(named_array < 4, axis=1).named_axis
+ == ak.all(named_array < 4, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.all(named_array < 4, axis=0, keepdims=True).named_axis
+ == ak.all(named_array < 4, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.all(named_array < 4, axis=1, keepdims=True).named_axis
+ == ak.all(named_array < 4, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.all(named_array < 4, axis=None))
+
+
+def test_negative_named_axis_ak_all():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.all(array < 4, axis=-2) == ak.all(named_array < 4, axis="x"))
+ assert ak.all(ak.all(array < 4, axis=-1) == ak.all(named_array < 4, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.all(named_array < 4, axis=-2).named_axis
+ == ak.all(named_array < 4, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.all(named_array < 4, axis=-1).named_axis
+ == ak.all(named_array < 4, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.all(named_array < 4, axis=-2, keepdims=True).named_axis
+ == ak.all(named_array < 4, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.all(named_array < 4, axis=-1, keepdims=True).named_axis
+ == ak.all(named_array < 4, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.all(named_array < 4, axis=None))
+
+
+def test_named_axis_ak_almost_equal():
+ array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y"))
+
+ assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal(
+ named_array1, named_array2, check_named_axis=False
+ )
+ assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal(
+ named_array1, named_array2, check_named_axis=True
+ )
+
+ assert ak.almost_equal(named_array1, array1, check_named_axis=False)
+ assert ak.almost_equal(named_array1, array1, check_named_axis=True)
+
+ named_array3 = ak.with_named_axis(array1, named_axis=("x", "muons"))
+ assert ak.almost_equal(named_array1, named_array3, check_named_axis=False)
+ assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True)
+
+
+def test_negative_named_axis_ak_almost_equal():
+ array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array1 = named_array2 = ak.with_named_axis(
+ array1, named_axis={"x": -2, "y": -1}
+ )
+
+ assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal(
+ named_array1, named_array2, check_named_axis=False
+ )
+ assert ak.almost_equal(array1, array2, check_named_axis=True) == ak.almost_equal(
+ named_array1, named_array2, check_named_axis=True
+ )
+
+ assert ak.almost_equal(named_array1, array1, check_named_axis=False)
+ assert ak.almost_equal(named_array1, array1, check_named_axis=True)
+
+ named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1})
+ assert ak.almost_equal(named_array1, named_array3, check_named_axis=False)
+ assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True)
+
+
+def test_named_axis_ak_angle():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.angle(array) == ak.angle(named_array))
+
+ # check that result axis names are correctly propagated
+ assert ak.angle(named_array).named_axis == {"x": 0, "y": 1}
+
+
+def test_negative_named_axis_ak_angle():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.angle(array) == ak.angle(named_array))
+
+ # check that result axis names are correctly propagated
+ assert ak.angle(named_array).named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_any():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.any(array < 4, axis=0) == ak.any(named_array < 4, axis="x"))
+ assert ak.all(ak.any(array < 4, axis=1) == ak.any(named_array < 4, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.any(named_array < 4, axis=0).named_axis
+ == ak.any(named_array < 4, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.any(named_array < 4, axis=1).named_axis
+ == ak.any(named_array < 4, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.any(named_array < 4, axis=0, keepdims=True).named_axis
+ == ak.any(named_array < 4, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.any(named_array < 4, axis=1, keepdims=True).named_axis
+ == ak.any(named_array < 4, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.all(named_array < 4, axis=None))
+
+
+def test_negative_named_axis_ak_any():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.any(array < 4, axis=-2) == ak.any(named_array < 4, axis="x"))
+ assert ak.all(ak.any(array < 4, axis=-1) == ak.any(named_array < 4, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.any(named_array < 4, axis=-2).named_axis
+ == ak.any(named_array < 4, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.any(named_array < 4, axis=-1).named_axis
+ == ak.any(named_array < 4, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.any(named_array < 4, axis=-2, keepdims=True).named_axis
+ == ak.any(named_array < 4, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.any(named_array < 4, axis=-1, keepdims=True).named_axis
+ == ak.any(named_array < 4, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.all(named_array < 4, axis=None))
+
+
+def test_named_axis_ak_argcartesian():
+ one = ak.Array([[1], [2], [3]])
+ two = ak.Array([[4, 5]])
+ three = ak.Array([[6, 7]])
+
+ named_one = ak.with_named_axis(one, named_axis=("x", "y"))
+ named_two = ak.with_named_axis(two, named_axis=("x", "y"))
+ named_three = ak.with_named_axis(three, named_axis=("x", "y"))
+
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="x", nested=False
+ ).named_axis == {"x": 0, "y": 1}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="x", nested=True
+ ).named_axis == {"x": 1, "y": 2}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="x", nested=[0]
+ ).named_axis == {"x": 1, "y": 2}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="x", nested=[1]
+ ).named_axis == {"x": 0, "y": 2}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="x", nested=[0, 1]
+ ).named_axis == {"x": 2, "y": 3}
+
+
+def test_negative_named_axis_ak_argcartesian():
+ one = ak.Array([[1], [2], [3]])
+ two = ak.Array([[4, 5]])
+ three = ak.Array([[6, 7]])
+
+ named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1})
+ named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1})
+ named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1})
+
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="y", nested=False
+ ).named_axis == {"x": -1}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="y", nested=True
+ ).named_axis == {"x": -2, "y": -1}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="y", nested=[0]
+ ).named_axis == {"x": -1}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="y", nested=[1]
+ ).named_axis == {"y": -1}
+ assert ak.argcartesian(
+ [named_one, named_two, named_three], axis="y", nested=[0, 1]
+ ).named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_argcombinations():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ assert (
+ ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis
+ )
+ assert (
+ ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis
+ )
+
+
+def test_negative_named_axis_ak_argcombinations():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ assert (
+ ak.argcombinations(named_array, 2, axis=0).named_axis == named_array.named_axis
+ )
+ assert (
+ ak.argcombinations(named_array, 2, axis=1).named_axis == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_argmax():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.argmax(array, axis=0) == ak.argmax(named_array, axis="x"))
+ assert ak.all(ak.argmax(array, axis=1) == ak.argmax(named_array, axis="y"))
+ assert ak.all(
+ ak.argmax(array, axis=0, keepdims=True)
+ == ak.argmax(named_array, axis="x", keepdims=True)
+ )
+ assert ak.all(
+ ak.argmax(array, axis=1, keepdims=True)
+ == ak.argmax(named_array, axis="y", keepdims=True)
+ )
+ assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None)
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argmax(named_array, axis=0).named_axis
+ == ak.argmax(named_array, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.argmax(named_array, axis=1).named_axis
+ == ak.argmax(named_array, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.argmax(named_array, axis=0, keepdims=True).named_axis
+ == ak.argmax(named_array, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.argmax(named_array, axis=1, keepdims=True).named_axis
+ == ak.argmax(named_array, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.argmax(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_argmax():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.argmax(array, axis=-2) == ak.argmax(named_array, axis="x"))
+ assert ak.all(ak.argmax(array, axis=-1) == ak.argmax(named_array, axis="y"))
+ assert ak.all(
+ ak.argmax(array, axis=-2, keepdims=True)
+ == ak.argmax(named_array, axis="x", keepdims=True)
+ )
+ assert ak.all(
+ ak.argmax(array, axis=-1, keepdims=True)
+ == ak.argmax(named_array, axis="y", keepdims=True)
+ )
+ assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None)
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argmax(named_array, axis=-2).named_axis
+ == ak.argmax(named_array, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.argmax(named_array, axis=-1).named_axis
+ == ak.argmax(named_array, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.argmax(named_array, axis=-2, keepdims=True).named_axis
+ == ak.argmax(named_array, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.argmax(named_array, axis=-1, keepdims=True).named_axis
+ == ak.argmax(named_array, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.argmax(named_array, axis=None))
+
+
+def test_named_axis_ak_argmin():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.argmin(array, axis=0) == ak.argmin(named_array, axis="x"))
+ assert ak.all(ak.argmin(array, axis=1) == ak.argmin(named_array, axis="y"))
+ assert ak.all(
+ ak.argmin(array, axis=0, keepdims=True)
+ == ak.argmin(named_array, axis="x", keepdims=True)
+ )
+ assert ak.all(
+ ak.argmin(array, axis=1, keepdims=True)
+ == ak.argmin(named_array, axis="y", keepdims=True)
+ )
+ assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None)
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argmin(named_array, axis=0).named_axis
+ == ak.argmin(named_array, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.argmin(named_array, axis=1).named_axis
+ == ak.argmin(named_array, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.argmin(named_array, axis=0, keepdims=True).named_axis
+ == ak.argmin(named_array, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.argmin(named_array, axis=1, keepdims=True).named_axis
+ == ak.argmin(named_array, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.argmin(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_argmin():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.argmin(array, axis=-2) == ak.argmin(named_array, axis="x"))
+ assert ak.all(ak.argmin(array, axis=-1) == ak.argmin(named_array, axis="y"))
+ assert ak.all(
+ ak.argmin(array, axis=-2, keepdims=True)
+ == ak.argmin(named_array, axis="x", keepdims=True)
+ )
+ assert ak.all(
+ ak.argmin(array, axis=-1, keepdims=True)
+ == ak.argmin(named_array, axis="y", keepdims=True)
+ )
+ assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None)
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argmin(named_array, axis=-2).named_axis
+ == ak.argmin(named_array, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.argmin(named_array, axis=-1).named_axis
+ == ak.argmin(named_array, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.argmin(named_array, axis=-2, keepdims=True).named_axis
+ == ak.argmin(named_array, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.argmin(named_array, axis=-1, keepdims=True).named_axis
+ == ak.argmin(named_array, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.argmin(named_array, axis=None))
+
+
+def test_named_axis_ak_argsort():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.argsort(array, axis=0) == ak.argsort(named_array, axis="x"))
+ assert ak.all(ak.argsort(array, axis=1) == ak.argsort(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argsort(named_array, axis=0).named_axis
+ == ak.argsort(named_array, axis="x").named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.argsort(named_array, axis=1).named_axis
+ == ak.argsort(named_array, axis="y").named_axis
+ == {"x": 0, "y": 1}
+ )
+
+
+def test_negative_named_axis_ak_argsort():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.argsort(array, axis=-2) == ak.argsort(named_array, axis="x"))
+ assert ak.all(ak.argsort(array, axis=-1) == ak.argsort(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.argsort(named_array, axis=-2).named_axis
+ == ak.argsort(named_array, axis="x").named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.argsort(named_array, axis=-1).named_axis
+ == ak.argsort(named_array, axis="y").named_axis
+ == {"x": -2, "y": -1}
+ )
+
+
+def test_named_axis_ak_array_equal():
+ array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y"))
+
+ assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal(
+ named_array1, named_array2, check_named_axis=False
+ )
+ assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal(
+ named_array1, named_array2, check_named_axis=True
+ )
+
+ assert ak.array_equal(named_array1, array1, check_named_axis=False)
+ assert ak.array_equal(named_array1, array1, check_named_axis=True)
+
+ named_array3 = ak.with_named_axis(array1, named_axis=("x", "z"))
+ assert ak.array_equal(named_array1, named_array3, check_named_axis=False)
+ assert not ak.array_equal(named_array1, named_array3, check_named_axis=True)
+
+
+def test_negative_named_axis_ak_array_equal():
+ array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array1 = named_array2 = ak.with_named_axis(
+ array1, named_axis={"x": -2, "y": -1}
+ )
+
+ assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal(
+ named_array1, named_array2, check_named_axis=False
+ )
+ assert ak.array_equal(array1, array2, check_named_axis=True) == ak.array_equal(
+ named_array1, named_array2, check_named_axis=True
+ )
+
+ assert ak.array_equal(named_array1, array1, check_named_axis=False)
+ assert ak.array_equal(named_array1, array1, check_named_axis=True)
+
+ named_array3 = ak.with_named_axis(array1, named_axis={"x": -2, "z": -1})
+ assert ak.array_equal(named_array1, named_array3, check_named_axis=False)
+ assert not ak.array_equal(named_array1, named_array3, check_named_axis=True)
+
+
+def test_named_axis_ak_backend():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ assert ak.backend(array) == ak.backend(named_array)
+
+
+def test_named_axis_ak_broadcast_fields():
+ x = ak.Array([{"x": {"y": 1, "z": 2, "w": [1]}}])
+ y = ak.Array([{"x": [{"y": 1}]}])
+
+ nx = ak.with_named_axis(x, named_axis=("x", "y"))
+ ny = ak.with_named_axis(y, named_axis=("a", "b"))
+
+ na, nb = ak.broadcast_fields(nx, ny)
+ assert na.named_axis == {"x": 0, "y": 1}
+ assert nb.named_axis == {"a": 0, "b": 1}
+
+
+def test_named_axis_ak_cartesian():
+ one = ak.Array([[1], [2], [3]])
+ two = ak.Array([[4, 5]])
+ three = ak.Array([[6, 7]])
+
+ named_one = ak.with_named_axis(one, named_axis=("x", "y"))
+ named_two = ak.with_named_axis(two, named_axis=("x", "y"))
+ named_three = ak.with_named_axis(three, named_axis=("x", "y"))
+
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="x", nested=False
+ ).named_axis == {"x": 0, "y": 1}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="x", nested=True
+ ).named_axis == {"x": 1, "y": 2}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="x", nested=[0]
+ ).named_axis == {"x": 1, "y": 2}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="x", nested=[1]
+ ).named_axis == {"x": 0, "y": 2}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="x", nested=[0, 1]
+ ).named_axis == {"x": 2, "y": 3}
+
+
+def test_negative_named_axis_ak_cartesian():
+ one = ak.Array([[1], [2], [3]])
+ two = ak.Array([[4, 5]])
+ three = ak.Array([[6, 7]])
+
+ named_one = ak.with_named_axis(one, named_axis={"x": -2, "y": -1})
+ named_two = ak.with_named_axis(two, named_axis={"x": -2, "y": -1})
+ named_three = ak.with_named_axis(three, named_axis={"x": -2, "y": -1})
+
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="y", nested=False
+ ).named_axis == {"x": -1}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="y", nested=True
+ ).named_axis == {"x": -2, "y": -1}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="y", nested=[0]
+ ).named_axis == {"x": -1}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="y", nested=[1]
+ ).named_axis == {"y": -1}
+ assert ak.cartesian(
+ [named_one, named_two, named_three], axis="y", nested=[0, 1]
+ ).named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_categories():
+ pyarrow = pytest.importorskip("pyarrow") # noqa: F841
+
+ array = ak.str.to_categorical([["one", "two"], ["one", "three"], ["one", "four"]])
+
+ named_array = ak.with_named_axis(array, named_axis=("a", "b"))
+
+ assert ak.all(ak.categories(array) == ak.categories(named_array)) # FIX: ufuncs
+ assert (
+ ak.categories(array).named_axis == ak.categories(named_array).named_axis == {}
+ )
+
+
+def test_named_axis_ak_combinations():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ assert ak.combinations(named_array, 2, axis=0).named_axis == named_array.named_axis
+ assert ak.combinations(named_array, 2, axis=1).named_axis == named_array.named_axis
+
+
+def test_negative_named_axis_ak_combinations():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ assert ak.combinations(named_array, 2, axis=-2).named_axis == named_array.named_axis
+ assert ak.combinations(named_array, 2, axis=-1).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_concatenate():
+ array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ all_arrays = [array1, array2, array3, array4]
+
+ named_array1 = ak.with_named_axis(array1, named_axis=(None, None))
+ named_array2 = ak.with_named_axis(array1, named_axis=(None, "y"))
+ named_array3 = ak.with_named_axis(array1, named_axis=("x", None))
+ named_array4 = ak.with_named_axis(array1, named_axis=("x", "y"))
+
+ all_named_arrays = [named_array1, named_array2, named_array3, named_array4]
+
+ assert ak.all(
+ ak.concatenate(all_arrays, axis=0) == ak.concatenate(all_named_arrays, axis="x")
+ )
+ assert ak.all(
+ ak.concatenate(all_arrays, axis=1) == ak.concatenate(all_named_arrays, axis="y")
+ )
+
+ assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": 0, "y": 1}
+ assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": 0, "y": 1}
+
+ with pytest.raises(
+ ValueError,
+ match="The named axes are incompatible. Got: x and y for positional axis 0",
+ ):
+ ak.concatenate(
+ [
+ ak.with_named_axis(array1, named_axis=("x", None)),
+ ak.with_named_axis(array2, named_axis=("y", None)),
+ ],
+ axis=0,
+ )
+
+ with pytest.raises(
+ ValueError,
+ match="The named axes are incompatible. Got: x and y for positional axis 1",
+ ):
+ ak.concatenate(
+ [
+ ak.with_named_axis(array1, named_axis=(None, "x")),
+ ak.with_named_axis(array2, named_axis=(None, "y")),
+ ],
+ axis=1,
+ )
+
+
+def test_negative_named_axis_ak_concatenate():
+ array1 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array3 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ array4 = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ all_arrays = [array1, array2, array3, array4]
+
+ named_array1 = ak.with_named_axis(array1, named_axis={})
+ named_array2 = ak.with_named_axis(array1, named_axis={"y": -1})
+ named_array3 = ak.with_named_axis(array1, named_axis={"x": -2})
+ named_array4 = ak.with_named_axis(array1, named_axis={"x": -2, "y": -1})
+
+ all_named_arrays = [named_array1, named_array2, named_array3, named_array4]
+
+ assert ak.all(
+ ak.concatenate(all_arrays, axis=-2)
+ == ak.concatenate(all_named_arrays, axis="x")
+ )
+ assert ak.all(
+ ak.concatenate(all_arrays, axis=-1)
+ == ak.concatenate(all_named_arrays, axis="y")
+ )
+
+ assert ak.concatenate(all_named_arrays, axis="x").named_axis == {"x": -2, "y": -1}
+ assert ak.concatenate(all_named_arrays, axis="y").named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_copy():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ assert ak.copy(named_array).named_axis == {"x": 0, "y": 1}
+
+
+# def test_named_axis_ak_corr():
+# array_x = ak.Array([[0, 1.1], [3.3, 4.4]])
+# array_y = ak.Array([[0, 1], [3, 4]])
+
+# named_array_x = ak.with_named_axis(array_x, ("x", "y"))
+# named_array_y = ak.with_named_axis(array_y, ("x", "y"))
+
+# assert ak.all(
+# ak.corr(array_x, array_y, axis=0)
+# == ak.corr(named_array_x, named_array_y, axis="x")
+# )
+# assert ak.all(
+# ak.corr(array_x, array_y, axis=1)
+# == ak.corr(named_array_x, named_array_y, axis="y")
+# )
+# assert ak.corr(array_x, array_y, axis=None) == ak.corr(
+# named_array_x, named_array_y, axis=None
+# )
+
+# assert ak.corr(named_array_x, named_array_y, axis="x").named_axis == {"y": 0}
+# assert ak.corr(named_array_x, named_array_y, axis="y").named_axis == {"x": 0}
+# assert not _get_named_axis(ak.corr(named_array_x, named_array_y, axis=None))
+
+
+def test_named_axis_ak_count():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.count(array, axis=0) == ak.count(named_array, axis="x"))
+ assert ak.all(ak.count(array, axis=1) == ak.count(named_array, axis="y"))
+ assert ak.count(array, axis=None) == ak.count(named_array, axis=None)
+
+ assert ak.count(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.count(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.count(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_count():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.count(array, axis=-2) == ak.count(named_array, axis="x"))
+ assert ak.all(ak.count(array, axis=-1) == ak.count(named_array, axis="y"))
+ assert ak.count(array, axis=None) == ak.count(named_array, axis=None)
+
+ assert ak.count(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.count(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.count(named_array, axis=None))
+
+
+def test_named_axis_ak_count_nonzero():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(
+ ak.count_nonzero(array, axis=0) == ak.count_nonzero(named_array, axis="x")
+ )
+ assert ak.all(
+ ak.count_nonzero(array, axis=1) == ak.count_nonzero(named_array, axis="y")
+ )
+ assert ak.count_nonzero(array, axis=None) == ak.count_nonzero(
+ named_array, axis=None
+ )
+
+ assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.count_nonzero(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_count_nonzero():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(
+ ak.count_nonzero(array, axis=-2) == ak.count_nonzero(named_array, axis="x")
+ )
+ assert ak.all(
+ ak.count_nonzero(array, axis=-1) == ak.count_nonzero(named_array, axis="y")
+ )
+ assert ak.count_nonzero(array, axis=None) == ak.count_nonzero(
+ named_array, axis=None
+ )
+
+ assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.count_nonzero(named_array, axis=None))
+
+
+# def test_named_axis_ak_covar():
+# array_x = ak.Array([[0, 1.1], [3.3, 4.4]])
+# array_y = ak.Array([[0, 1], [3, 4]])
+
+# named_array_x = ak.with_named_axis(array_x, ("x", "y"))
+# named_array_y = ak.with_named_axis(array_y, ("x", "y"))
+
+# assert ak.all(
+# ak.covar(array_x, array_y, axis=0)
+# == ak.covar(named_array_x, named_array_y, axis="x")
+# )
+# assert ak.all(
+# ak.covar(array_x, array_y, axis=1)
+# == ak.covar(named_array_x, named_array_y, axis="y")
+# )
+# assert ak.covar(array_x, array_y, axis=None) == ak.covar(
+# named_array_x, named_array_y, axis=None
+# )
+
+# assert ak.covar(named_array_x, named_array_y, axis="x").named_axis == {"y": 0}
+# assert ak.covar(named_array_x, named_array_y, axis="y").named_axis == {"x": 0}
+# assert not _get_named_axis(ak.covar(named_array_x, named_array_y, axis=None))
+
+
+def test_named_axis_ak_drop_none():
+ array = ak.Array([[1, None], [3], [None], [4, None, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.drop_none(array, axis=0) == ak.drop_none(named_array, axis="x"))
+ assert ak.all(ak.drop_none(array, axis=1) == ak.drop_none(named_array, axis="y"))
+ assert ak.all(
+ ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None)
+ )
+
+ assert ak.drop_none(named_array, axis="x").named_axis == {"x": 0, "y": 1}
+ assert ak.drop_none(named_array, axis="y").named_axis == {"x": 0, "y": 1}
+ assert ak.drop_none(named_array, axis=None).named_axis == {"x": 0, "y": 1}
+
+
+def test_negative_named_axis_ak_drop_none():
+ array = ak.Array([[1, None], [3], [None], [4, None, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.drop_none(array, axis=-2) == ak.drop_none(named_array, axis="x"))
+ assert ak.all(ak.drop_none(array, axis=-1) == ak.drop_none(named_array, axis="y"))
+ assert ak.all(
+ ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None)
+ )
+
+ assert ak.drop_none(named_array, axis="x").named_axis == {"x": -2, "y": -1}
+ assert ak.drop_none(named_array, axis="y").named_axis == {"x": -2, "y": -1}
+ assert ak.drop_none(named_array, axis=None).named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_enforce_type():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.enforce_type(named_array, "var * ?int64").named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_fill_none():
+ array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(
+ ak.fill_none(array, 0, axis=0) == ak.fill_none(named_array, 0, axis="x")
+ )
+ assert ak.all(
+ ak.fill_none(array, 0, axis=1) == ak.fill_none(named_array, 0, axis="y")
+ )
+ assert ak.all(
+ ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None)
+ )
+
+ assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": 0, "y": 1}
+ assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": 0, "y": 1}
+ assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": 0, "y": 1}
+
+
+def test_negative_named_axis_ak_fill_none():
+ array = ak.Array([[1.1, None, 2.2], [], [None, 3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(
+ ak.fill_none(array, 0, axis=-2) == ak.fill_none(named_array, 0, axis="x")
+ )
+ assert ak.all(
+ ak.fill_none(array, 0, axis=-1) == ak.fill_none(named_array, 0, axis="y")
+ )
+ assert ak.all(
+ ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None)
+ )
+
+ assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": -2, "y": -1}
+ assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": -2, "y": -1}
+ assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": -2, "y": -1}
+
+
+def test_named_axis_ak_firsts():
+ array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.firsts(array, axis=0) == ak.firsts(named_array, axis="x"))
+ assert ak.all(ak.firsts(array, axis=1) == ak.firsts(named_array, axis="y"))
+
+ assert ak.firsts(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.firsts(named_array, axis="y").named_axis == {"x": 0}
+
+
+def test_negative_named_axis_ak_firsts():
+ array = ak.Array([[1.1], [2.2], [], [3.3], [], [], [4.4], [5.5]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.firsts(array, axis=-2) == ak.firsts(named_array, axis="x"))
+ assert ak.all(ak.firsts(array, axis=-1) == ak.firsts(named_array, axis="y"))
+
+ assert ak.firsts(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.firsts(named_array, axis="y").named_axis == {"x": -1}
+
+
+def test_named_axis_ak_flatten():
+ array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]])
+
+ named_array = ak.with_named_axis(array, ("x", "y", "z"))
+
+ assert ak.all(ak.flatten(array, axis=0) == ak.flatten(named_array, axis="x"))
+ assert ak.all(ak.flatten(array, axis=1) == ak.flatten(named_array, axis="y"))
+ assert ak.all(ak.flatten(array, axis=2) == ak.flatten(named_array, axis="z"))
+ assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None))
+
+ assert ak.flatten(named_array, axis="x").named_axis == {"x": 0, "y": 1, "z": 2}
+ assert ak.flatten(named_array, axis="y").named_axis == {"x": 0, "z": 1}
+ assert ak.flatten(named_array, axis="z").named_axis == {"x": 0, "y": 1}
+ assert not _get_named_axis(ak.flatten(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_flatten():
+ array = ak.Array([[[1.1, 2.2]], [[]], [[3.3]], [[]], [[]], [[4.4, 5.5]]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1})
+
+ assert ak.all(ak.flatten(array, axis=-3) == ak.flatten(named_array, axis="x"))
+ assert ak.all(ak.flatten(array, axis=-2) == ak.flatten(named_array, axis="y"))
+ assert ak.all(ak.flatten(array, axis=-1) == ak.flatten(named_array, axis="z"))
+ assert ak.all(ak.flatten(array, axis=None) == ak.flatten(named_array, axis=None))
+
+ assert ak.flatten(named_array, axis="x").named_axis == {"x": -3, "y": -2, "z": -1}
+ assert ak.flatten(named_array, axis="y").named_axis == {"x": -2, "z": -1}
+ assert ak.flatten(named_array, axis="z").named_axis == {"x": -2, "y": -1}
+ assert not _get_named_axis(ak.flatten(named_array, axis=None))
+
+
+def test_named_axis_ak_imag():
+ array = ak.Array([[1 + 2j], [2 + 1j], []])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.imag(array) == ak.imag(named_array))
+ assert ak.imag(named_array).named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_is_none():
+ array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]])
+
+ named_array = ak.with_named_axis(array, ("x", "y", "z"))
+
+ assert ak.all(ak.is_none(array, axis=0) == ak.is_none(named_array, axis="x"))
+ assert ak.all(ak.is_none(array, axis=1) == ak.is_none(named_array, axis="y"))
+ assert ak.all(ak.is_none(array, axis=2) == ak.is_none(named_array, axis="z"))
+
+ assert ak.is_none(named_array, axis="x").named_axis == {"x": 0}
+ assert ak.is_none(named_array, axis="y").named_axis == {"x": 0, "y": 1}
+ assert ak.is_none(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2}
+
+
+def test_negative_named_axis_ak_is_none():
+ array = ak.Array([[[1, None]], [[3]], [[None]], [[4, None, 6]]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -3, "y": -2, "z": -1})
+
+ assert ak.all(ak.is_none(array, axis=-3) == ak.is_none(named_array, axis="x"))
+ assert ak.all(ak.is_none(array, axis=-2) == ak.is_none(named_array, axis="y"))
+ assert ak.all(ak.is_none(array, axis=-1) == ak.is_none(named_array, axis="z"))
+
+ assert ak.is_none(named_array, axis="x").named_axis == {"z": -1}
+ assert ak.is_none(named_array, axis="y").named_axis == {"y": -2, "z": -1}
+ assert ak.is_none(named_array, axis="z").named_axis == {"x": -3, "y": -2, "z": -1}
+
+
+def test_named_axis_ak_isclose():
+ a = b = ak.Array(
+ [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]]
+ )
+
+ na = ak.with_named_axis(a, ("x", "y", "z"))
+ nb = ak.with_named_axis(b, ("x", "y", "z"))
+ assert ak.all(ak.isclose(a, b) == ak.isclose(na, nb))
+
+ na = ak.with_named_axis(a, (None, "y", "z"))
+ nb = ak.with_named_axis(b, ("x", "y", None))
+ assert ak.isclose(na, nb).named_axis == {"x": 0, "y": 1, "z": 2}
+
+
+def test_named_axis_ak_local_index():
+ array = ak.Array(
+ [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]]
+ )
+
+ named_array = ak.with_named_axis(array, ("x", "y", "z"))
+
+ assert ak.all(
+ ak.local_index(array, axis=0) == ak.local_index(named_array, axis="x")
+ )
+ assert ak.all(
+ ak.local_index(array, axis=1) == ak.local_index(named_array, axis="y")
+ )
+ assert ak.all(
+ ak.local_index(array, axis=2) == ak.local_index(named_array, axis="z")
+ )
+
+ assert ak.local_index(named_array, axis="x").named_axis == {"x": 0}
+ assert ak.local_index(named_array, axis="y").named_axis == {"x": 0, "y": 1}
+ assert ak.local_index(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2}
+
+
+def test_negative_named_axis_ak_local_index():
+ array = ak.Array(
+ [[[0.0, 1.1, 2.2], []], [[3.3, 4.4]], [], [[5.5], [], [6.6, 7.7, 8.8, 9.9]]]
+ )
+ named_array = ak.with_named_axis(array, {"x": -3, "y": -2, "z": -1})
+
+ assert ak.local_index(named_array, axis="x").named_axis == {"z": -1}
+ assert ak.local_index(named_array, axis="y").named_axis == {"y": -2, "z": -1}
+ assert ak.local_index(named_array, axis="z").named_axis == {
+ "x": -3,
+ "y": -2,
+ "z": -1,
+ }
+
+
+def test_named_axis_ak_mask():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+ mask = array > 3
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+ named_mask = named_array > 3
+
+ assert ak.all(ak.mask(array, mask) == ak.mask(named_array, mask))
+ assert ak.all(ak.mask(array, mask) == ak.mask(named_array, named_mask))
+
+ assert ak.mask(named_array, mask).named_axis == named_array.named_axis
+ assert ak.mask(named_array, named_mask).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_max():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.max(array, axis=0) == ak.max(named_array, axis="x"))
+ assert ak.all(ak.max(array, axis=1) == ak.max(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.max(named_array, axis=0).named_axis
+ == ak.max(named_array, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.max(named_array, axis=1).named_axis
+ == ak.max(named_array, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.max(named_array, axis=0, keepdims=True).named_axis
+ == ak.max(named_array, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.max(named_array, axis=1, keepdims=True).named_axis
+ == ak.max(named_array, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.max(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_max():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.max(array, axis=-2) == ak.max(named_array, axis="x"))
+ assert ak.all(ak.max(array, axis=-1) == ak.max(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.max(named_array, axis=-2).named_axis
+ == ak.max(named_array, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.max(named_array, axis=-1).named_axis
+ == ak.max(named_array, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.max(named_array, axis=-2, keepdims=True).named_axis
+ == ak.max(named_array, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.max(named_array, axis=-1, keepdims=True).named_axis
+ == ak.max(named_array, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.max(named_array, axis=None))
+
+
+def test_named_axis_ak_mean():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.mean(array, axis=0) == ak.mean(named_array, axis="x"))
+ assert ak.all(ak.mean(array, axis=1) == ak.mean(named_array, axis="y"))
+ assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None)
+
+ assert ak.mean(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.mean(named_array, axis="y").named_axis == {"x": 0}
+ assert ak.mean(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1}
+ assert ak.mean(named_array, axis="y", keepdims=True).named_axis == {"x": 0, "y": 1}
+ assert not _get_named_axis(ak.mean(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_mean():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ assert ak.all(ak.mean(array, axis=-2) == ak.mean(named_array, axis="x"))
+ assert ak.all(ak.mean(array, axis=-1) == ak.mean(named_array, axis="y"))
+ assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None)
+
+ assert ak.mean(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.mean(named_array, axis="y").named_axis == {"x": -1}
+ assert ak.mean(named_array, axis="x", keepdims=True).named_axis == {
+ "x": -2,
+ "y": -1,
+ }
+ assert ak.mean(named_array, axis="y", keepdims=True).named_axis == {
+ "x": -2,
+ "y": -1,
+ }
+ assert not _get_named_axis(ak.mean(named_array, axis=None))
+
+
+def test_named_axis_ak_merge_option_of_records():
+ array = ak.Array([None, {"a": 1}, {"a": 2}])
+
+ named_array = ak.with_named_axis(array, named_axis=("x",))
+
+ assert (
+ ak.merge_option_of_records(named_array, axis="x").named_axis
+ == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_merge_union_of_records():
+ array = ak.concatenate(([{"a": 1}], [{"b": 2}]))
+
+ named_array = ak.with_named_axis(array, named_axis=("x",))
+
+ assert (
+ ak.merge_union_of_records(named_array, axis="x").named_axis
+ == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_min():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.min(array, axis=0) == ak.min(named_array, axis="x"))
+ assert ak.all(ak.min(array, axis=1) == ak.min(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.min(named_array, axis=0).named_axis
+ == ak.min(named_array, axis="x").named_axis
+ == {"y": 0}
+ )
+ assert (
+ ak.min(named_array, axis=1).named_axis
+ == ak.min(named_array, axis="y").named_axis
+ == {"x": 0}
+ )
+ assert (
+ ak.min(named_array, axis=0, keepdims=True).named_axis
+ == ak.min(named_array, axis="x", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.min(named_array, axis=1, keepdims=True).named_axis
+ == ak.min(named_array, axis="y", keepdims=True).named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert not _get_named_axis(ak.min(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_min():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis={"x": -2, "y": -1})
+
+ # first check that they work the same
+ assert ak.all(ak.min(array, axis=-2) == ak.min(named_array, axis="x"))
+ assert ak.all(ak.min(array, axis=-1) == ak.min(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.min(named_array, axis=-2).named_axis
+ == ak.min(named_array, axis="x").named_axis
+ == {"y": -1}
+ )
+ assert (
+ ak.min(named_array, axis=-1).named_axis
+ == ak.min(named_array, axis="y").named_axis
+ == {"x": -1}
+ )
+ assert (
+ ak.min(named_array, axis=-2, keepdims=True).named_axis
+ == ak.min(named_array, axis="x", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert (
+ ak.min(named_array, axis=-1, keepdims=True).named_axis
+ == ak.min(named_array, axis="y", keepdims=True).named_axis
+ == {"x": -2, "y": -1}
+ )
+ assert not _get_named_axis(ak.min(named_array, axis=None))
+
+
+def test_named_axis_ak_moment():
+ array = ak.Array([[0, 1.1], [3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.moment(array, 0, axis=0) == ak.moment(named_array, 0, axis="x"))
+ assert ak.all(ak.moment(array, 0, axis=1) == ak.moment(named_array, 0, axis="y"))
+ assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None)
+
+ assert ak.moment(named_array, 0, axis="x").named_axis == {"y": 0}
+ assert ak.moment(named_array, 0, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.moment(named_array, 0, axis=None))
+
+
+def test_negative_named_axis_ak_moment():
+ array = ak.Array([[0, 1.1], [3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.moment(array, 0, axis=-2) == ak.moment(named_array, 0, axis="x"))
+ assert ak.all(ak.moment(array, 0, axis=-1) == ak.moment(named_array, 0, axis="y"))
+ assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None)
+
+ assert ak.moment(named_array, 0, axis="x").named_axis == {"y": -1}
+ assert ak.moment(named_array, 0, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.moment(named_array, 0, axis=None))
+
+
+def test_named_axis_ak_nan_to_none():
+ array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.nan_to_none(array) == ak.nan_to_none(named_array))
+ assert ak.nan_to_none(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_nan_to_num():
+ array = ak.Array([[0, np.nan], [np.nan], [3.3, 4.4]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.nan_to_num(array, nan=0.0) == ak.nan_to_num(named_array, nan=0.0))
+ assert ak.nan_to_num(named_array, nan=0.0).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_num():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.num(array, axis=0) == ak.num(named_array, axis="x")
+ assert ak.all(ak.num(array, axis=1) == ak.num(named_array, axis="y"))
+
+ assert ak.num(named_array, axis="y").named_axis == {"y": 0}
+ assert not _get_named_axis(ak.num(named_array, axis="x"))
+
+
+def test_negative_named_axis_ak_num():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.num(array, axis=-2) == ak.num(named_array, axis="x")
+ assert ak.all(ak.num(array, axis=-1) == ak.num(named_array, axis="y"))
+
+ assert ak.num(named_array, axis="y").named_axis == {"y": 0}
+ assert not _get_named_axis(ak.num(named_array, axis="x"))
+
+
+def test_named_axis_ak_ones_like():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.ones_like(array) == ak.ones_like(named_array))
+
+ assert ak.ones_like(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_pad_none():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.pad_none(array, 3, axis=0) == ak.pad_none(named_array, 3, axis=0))
+ assert ak.all(ak.pad_none(array, 3, axis=1) == ak.pad_none(named_array, 3, axis=1))
+
+ assert ak.pad_none(named_array, 3, axis=0).named_axis == named_array.named_axis
+ assert ak.pad_none(named_array, 3, axis=1).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_prod():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.prod(array, axis=0) == ak.prod(named_array, axis="x"))
+ assert ak.all(ak.prod(array, axis=1) == ak.prod(named_array, axis="y"))
+ assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None)
+
+ assert ak.prod(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.prod(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.prod(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_prod():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.prod(array, axis=-2) == ak.prod(named_array, axis="x"))
+ assert ak.all(ak.prod(array, axis=-1) == ak.prod(named_array, axis="y"))
+ assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None)
+
+ assert ak.prod(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.prod(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.prod(named_array, axis=None))
+
+
+def test_named_axis_ak_ptp():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.ptp(array, axis=0) == ak.ptp(named_array, axis="x"))
+ assert ak.all(ak.ptp(array, axis=1) == ak.ptp(named_array, axis="y"))
+ assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None)
+
+ assert ak.ptp(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.ptp(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.ptp(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_ptp():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.ptp(array, axis=-2) == ak.ptp(named_array, axis="x"))
+ assert ak.all(ak.ptp(array, axis=-1) == ak.ptp(named_array, axis="y"))
+ assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None)
+
+ assert ak.ptp(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.ptp(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.ptp(named_array, axis=None))
+
+
+def test_named_axis_ak_ravel():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.ravel(array) == ak.ravel(named_array))
+
+ assert not _get_named_axis(ak.ravel(named_array))
+
+
+def test_named_axis_ak_real():
+ array = ak.Array([[1 + 2j], [2 + 1j], []])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.real(array) == ak.real(named_array))
+ assert ak.real(named_array).named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_round():
+ array = ak.Array([[1.234], [2.345, 3.456], []])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.round(array) == ak.round(named_array))
+ assert ak.round(named_array).named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_run_lengths():
+ array = ak.Array([[1.1, 1.1, 1.1, 2.2, 3.3], [3.3, 4.4], [4.4, 5.5]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.run_lengths(array) == ak.run_lengths(named_array))
+
+ assert ak.run_lengths(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_singletons():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.singletons(array, axis=0) == ak.singletons(named_array, axis="x"))
+ assert ak.all(ak.singletons(array, axis=1) == ak.singletons(named_array, axis="y"))
+
+ assert ak.singletons(named_array, axis=0).named_axis == {"x": 0, "y": 2}
+ assert ak.singletons(named_array, axis=1).named_axis == {"x": 0, "y": 1}
+
+
+def test_negative_named_axis_ak_singletons():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.singletons(array, axis=-2) == ak.singletons(named_array, axis="x"))
+ assert ak.all(ak.singletons(array, axis=-1) == ak.singletons(named_array, axis="y"))
+
+ assert ak.singletons(named_array, axis=-2).named_axis == {"x": -3, "y": -1}
+ assert ak.singletons(named_array, axis=-1).named_axis == {"x": -3, "y": -2}
+
+
+def test_named_axis_ak_softmax():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.softmax(array, axis=-1) == ak.softmax(named_array, axis="y"))
+
+ assert ak.softmax(named_array, axis="y").named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_sort():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, named_axis=("x", "y"))
+
+ # first check that they work the same
+ assert ak.all(ak.sort(array, axis=0) == ak.sort(named_array, axis="x"))
+ assert ak.all(ak.sort(array, axis=1) == ak.sort(named_array, axis="y"))
+
+ # check that result axis names are correctly propagated
+ assert (
+ ak.sort(named_array, axis=0).named_axis
+ == ak.sort(named_array, axis="x").named_axis
+ == {"x": 0, "y": 1}
+ )
+ assert (
+ ak.sort(named_array, axis=1).named_axis
+ == ak.sort(named_array, axis="y").named_axis
+ == {"x": 0, "y": 1}
+ )
+
+
+def test_named_axis_ak_std():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.std(array, axis=0) == ak.std(named_array, axis="x"))
+ assert ak.all(ak.std(array, axis=1) == ak.std(named_array, axis="y"))
+ assert ak.std(array, axis=None) == ak.std(named_array, axis=None)
+
+ assert ak.std(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.std(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.std(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_std():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.std(array, axis=-2) == ak.std(named_array, axis="x"))
+ assert ak.all(ak.std(array, axis=-1) == ak.std(named_array, axis="y"))
+ assert ak.std(array, axis=None) == ak.std(named_array, axis=None)
+
+ assert ak.std(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.std(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.std(named_array, axis=None))
+
+
+def test_named_axis_ak_strings_astype():
+ array = ak.Array([["1", "2"], ["3"], ["4", "5", "6"]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(
+ ak.strings_astype(array, np.int32) == ak.strings_astype(named_array, np.int32)
+ )
+
+ assert ak.strings_astype(named_array, np.int32).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_sum():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.sum(array, axis=0) == ak.sum(named_array, axis="x"))
+ assert ak.all(ak.sum(array, axis=1) == ak.sum(named_array, axis="y"))
+ assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None)
+
+ assert ak.sum(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.sum(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.sum(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_sum():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.sum(array, axis=-2) == ak.sum(named_array, axis="x"))
+ assert ak.all(ak.sum(array, axis=-1) == ak.sum(named_array, axis="y"))
+ assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None)
+
+ assert ak.sum(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.sum(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.sum(named_array, axis=None))
+
+
+def test_named_axis_ak_to_backend():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.to_backend(named_array, "typetracer").named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_to_packed():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.to_packed(array) == ak.to_packed(named_array))
+
+ assert ak.to_packed(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_unflatten():
+ array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ counts = ak.Array([2, 2, 1, 2, 1, 1])
+
+ assert ak.all(
+ ak.unflatten(array, counts, axis=1)
+ == ak.unflatten(named_array, counts, axis="y")
+ )
+ assert not _get_named_axis(ak.unflatten(named_array, counts, axis="y"))
+
+
+def test_named_axis_ak_unzip():
+ array = ak.Array(
+ [
+ {"x": 1.1, "y": [1]},
+ {"x": 2.2, "y": [2, 2]},
+ {"x": 3.3, "y": [3, 3, 3]},
+ ]
+ )
+ named_array = ak.with_named_axis(array, ("x", "y"))
+ x, y = ak.unzip(named_array)
+ assert x.named_axis == y.named_axis == {"x": 0, "y": 1}
+
+
+def test_named_axis_ak_values_astype():
+ array = ak.Array([[1, 2, 3, 4], [], [5, 6, 7], [8, 9]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(
+ ak.values_astype(array, np.float32) == ak.values_astype(named_array, np.float32)
+ )
+
+ assert (
+ ak.values_astype(named_array, np.float32).named_axis == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_var():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.var(array, axis=0) == ak.var(named_array, axis="x"))
+ assert ak.all(ak.var(array, axis=1) == ak.var(named_array, axis="y"))
+ assert ak.var(array, axis=None) == ak.var(named_array, axis=None)
+
+ assert ak.var(named_array, axis="x").named_axis == {"y": 0}
+ assert ak.var(named_array, axis="y").named_axis == {"x": 0}
+ assert not _get_named_axis(ak.var(named_array, axis=None))
+
+
+def test_negative_named_axis_ak_var():
+ array = ak.Array([[1, 2], [3], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, {"x": -2, "y": -1})
+
+ assert ak.all(ak.var(array, axis=-2) == ak.var(named_array, axis="x"))
+ assert ak.all(ak.var(array, axis=-1) == ak.var(named_array, axis="y"))
+ assert ak.var(array, axis=None) == ak.var(named_array, axis=None)
+
+ assert ak.var(named_array, axis="x").named_axis == {"y": -1}
+ assert ak.var(named_array, axis="y").named_axis == {"x": -1}
+ assert not _get_named_axis(ak.var(named_array, axis=None))
+
+
+def test_named_axis_ak_where():
+ a = ak.Array([[1, 2], [3, 4]])
+ na = ak.with_named_axis(a, ("x", "y"))
+
+ assert ak.all(ak.where(a > 2, 0, 1) == ak.where(na > 2, 0, 1))
+ assert ak.where(na > 2, 0, 1).named_axis == {"x": 0, "y": 1}
+ assert ak.where(na > 2, na, 1).named_axis == {"x": 0, "y": 1}
+
+ nb = ak.with_named_axis(a, ("a", "b"))
+ with pytest.raises(ValueError):
+ _ = ak.where(na > 2, nb, 1)
+
+
+def test_named_axis_ak_with_field():
+ array = ak.Array(
+ [
+ {"x": 1.1, "y": [1]},
+ {"x": 2.2, "y": [2, 2]},
+ {"x": 3.3, "y": [3, 3, 3]},
+ ]
+ )
+ named_array = ak.with_named_axis(array, ("x", "y"))
+ xyz = ak.with_field(named_array, ak.Array([[1], [2], [3]]), "z")
+ x, y, z = ak.unzip(xyz)
+ assert x.named_axis == y.named_axis == z.named_axis == {"x": 0, "y": 1}
+
+ named_z = ak.with_named_axis(ak.Array([[1], [2], [3]]), ("a", "b"))
+ with pytest.raises(ValueError):
+ ak.with_field(named_array, named_z, "z")
+
+
+def test_named_axis_ak_with_name():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.with_name(named_array, "new_name").named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_with_named_axis():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ # tuple
+ named_array = ak.with_named_axis(array, ("x", "y"))
+ assert named_array.named_axis == {"x": 0, "y": 1}
+
+ # dict
+ named_array = ak.with_named_axis(array, {"x": 0, "y": -1})
+ assert named_array.named_axis == {"x": 0, "y": -1}
+
+
+def test_named_axis_ak_with_parameter():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert (
+ ak.with_parameter(named_array, "param", 1.0).named_axis
+ == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_without_parameters():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0)
+
+ assert (
+ ak.without_parameters(named_array_with_parameteter).named_axis
+ == named_array.named_axis
+ )
+
+
+def test_named_axis_ak_zeros_like():
+ array = ak.Array([[1, 2], [3], [], [4, 5, 6]])
+
+ named_array = ak.with_named_axis(array, ("x", "y"))
+
+ assert ak.all(ak.zeros_like(array) == ak.zeros_like(named_array))
+
+ assert ak.zeros_like(named_array).named_axis == named_array.named_axis
+
+
+def test_named_axis_ak_zip():
+ named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("x",))
+ named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y"))
+
+ assert ak.zip({"x": named_array1, "y": named_array2}).named_axis == {"x": 0, "y": 1}
+
+ named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",))
+ named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y"))
+
+ with pytest.raises(ValueError):
+ _ = ak.zip({"x": named_array1, "y": named_array2})