Skip to content

Commit

Permalink
fix typetracer lengths for optiontypes in 'ak.where' broadcastings (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey authored Dec 2, 2024
1 parent 7b150aa commit 07ce379
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def broadcast_any_option_akwhere():
if not isinstance(xyc, Content):
unmasked.append(xyc)
masks.append(
NumpyArray(backend.nplike.zeros(len(inputs[2]), dtype=np.int8))
NumpyArray(backend.nplike.zeros(inputs[2].length, dtype=np.int8))
)
elif not xyc.is_option:
unmasked.append(xyc)
Expand Down
139 changes: 139 additions & 0 deletions tests/test_3321_akwhere_typetracer_lengths_optiontypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

import awkward as ak

fromdict = {
"class": "RecordArray",
"fields": ["muon", "jet"],
"contents": [
{
"class": "ListOffsetArray",
"offsets": "i64",
"content": {
"class": "RecordArray",
"fields": ["pt", "eta", "phi", "crossref"],
"contents": [
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "muon_pt!",
},
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "muon_eta!",
},
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "muon_phi!",
},
{
"class": "ListOffsetArray",
"offsets": "i64",
"content": {
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "muon_crossref_content!",
},
"parameters": {},
"form_key": "muon_crossref_index!",
},
],
"parameters": {},
"form_key": "muon_record!",
},
"parameters": {},
"form_key": "muon_list!",
},
{
"class": "ListOffsetArray",
"offsets": "i64",
"content": {
"class": "RecordArray",
"fields": [
"pt",
"eta",
"phi",
"crossref",
"thing1",
],
"contents": [
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "jet_pt!",
},
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "jet_eta!",
},
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "jet_phi!",
},
{
"class": "ListOffsetArray",
"offsets": "i64",
"content": {
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "jet_crossref_content!",
},
"parameters": {},
"form_key": "jet_crossref_index!",
},
{
"class": "NumpyArray",
"primitive": "int64",
"inner_shape": [],
"parameters": {},
"form_key": "jet_thing1!",
},
],
"parameters": {},
"form_key": "jet_record!",
},
"parameters": {},
"form_key": "jet_list!",
},
],
"parameters": {},
"form_key": "outer!",
}

form = ak.forms.from_dict(fromdict)
ttlayout, report = ak.typetracer.typetracer_with_report(form)
ttarray = ak.Array(ttlayout)


def test_where():
ak.where(abs(ttarray.jet.eta) < 1.0, 0.000511, ttarray.jet.thing1)


def test_maybe_where():
maybe = ak.firsts(ttarray)
ak.where(abs(maybe.jet.eta) < 1.0, 0.000511, maybe.jet.thing1)


def test_varmaybe_where():
varmaybe = ak.pad_none(ttarray, 3)
ak.where(abs(varmaybe.jet.eta) < 1.0, 0.000511, varmaybe.jet.thing1)

0 comments on commit 07ce379

Please sign in to comment.