Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/github_actions/actions-0359ee9315
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna authored Nov 18, 2024
2 parents 101ac38 + f0b4f40 commit 32a4f02
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 30 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
additional_dependencies: [pyyaml]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.7.3
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
Expand Down Expand Up @@ -76,6 +76,6 @@ repos:
- numpy>=1.24

- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.22
rev: v0.23
hooks:
- id: validate-pyproject
6 changes: 3 additions & 3 deletions src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def apply_step(
named_axis = _add_named_axis(named_axis, depth, ndim)
depth_context[NAMED_AXIS_KEY][i] = (
_unify_named_axis(named_axis, seen_named_axis),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if o.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down Expand Up @@ -645,7 +645,7 @@ def broadcast_any_list():
# rightbroadcasting adds a new first(!) dimension as depth
depth_context[NAMED_AXIS_KEY][i] = (
_add_named_axis(named_axis, depth, ndim),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if x.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down Expand Up @@ -734,7 +734,7 @@ def broadcast_any_list():
# leftbroadcasting adds a new last dimension at depth + 1
depth_context[NAMED_AXIS_KEY][i] = (
_add_named_axis(named_axis, depth + 1, ndim),
ndim + 1,
ndim + 1 if ndim is not None else ndim,
)
if x.is_leaf:
_export_named_axis_from_depth_to_lateral(
Expand Down
13 changes: 9 additions & 4 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _prettify(ax: AxisName) -> str:
return delimiter.join(items)


def _get_named_axis(ctx: tp.Any) -> AxisMapping:
def _get_named_axis(ctx: tp.Any, allow_any: bool = False) -> AxisMapping:
"""
Retrieves the named axis from the provided context.
Expand All @@ -103,9 +103,14 @@ def _get_named_axis(ctx: tp.Any) -> AxisMapping:
>>> _get_named_axis({"other_key": "other_value"})
{}
"""
if hasattr(ctx, "attrs"):
return _get_named_axis(ctx.attrs)
elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
from awkward._layout import HighLevelContext
from awkward.highlevel import Array, Record

if hasattr(ctx, "attrs") and (
isinstance(ctx, (HighLevelContext, Array, Record)) or allow_any
):
return _get_named_axis(ctx.attrs, allow_any=True)
elif allow_any and isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
return dict(ctx[NAMED_AXIS_KEY])
else:
return {}
Expand Down
29 changes: 13 additions & 16 deletions src/awkward/operations/ak_argcombinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,17 @@ def _impl(

axis = regularize_axis(axis, none_allowed=False)

if axis < 0:
raise ValueError("the 'axis' for argcombinations must be non-negative")
else:
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ak._do.local_index(
ctx.unwrap(array, allow_record=False, primitive_policy="error"),
axis,
)
out = ak._do.combinations(
layout,
n,
replacement=replacement,
axis=axis,
fields=fields,
parameters=parameters,
with HighLevelContext(behavior=behavior, attrs=attrs) as ctx:
layout = ak._do.local_index(
ctx.unwrap(array, allow_record=False, primitive_policy="error"),
axis,
)
return ctx.wrap(out, highlevel=highlevel)
out = ak._do.combinations(
layout,
n,
replacement=replacement,
axis=axis,
fields=fields,
parameters=parameters,
)
return ctx.wrap(out, highlevel=highlevel)
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs):
)
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out) or {}),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def _impl(
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_ptp.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs):
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a
# propagate named axis to output
return ak.operations.ak_with_named_axis._impl(
wrapped,
named_axis=_get_named_axis(attrs_of_obj(out)),
named_axis=_get_named_axis(attrs_of_obj(out), allow_any=True),
highlevel=highlevel,
behavior=None,
attrs=None,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_3300_allow_negative_axis_in_argcombinations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

import awkward as ak


def test():
array = ak.Array([[0.0, 1.1, 2.2], [], [3.3, 4.4], [5.5], [6.6, 7.7, 8.8, 9.9]])

assert ak.combinations(array, 2, axis=-1).tolist() == [
[(0.0, 1.1), (0.0, 2.2), (1.1, 2.2)],
[],
[(3.3, 4.4)],
[],
[(6.6, 7.7), (6.6, 8.8), (6.6, 9.9), (7.7, 8.8), (7.7, 9.9), (8.8, 9.9)],
]
assert ak.argcombinations(array, 2, axis=-1).tolist() == [
[(0, 1), (0, 2), (1, 2)],
[],
[(0, 1)],
[],
[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)],
]

0 comments on commit 32a4f02

Please sign in to comment.