Skip to content

Commit

Permalink
fix: prevent exponential memory growth in UnionArray (#3119)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski authored May 28, 2024
1 parent 6cff8e9 commit 28a89da
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def simplified(
]

if len(contents) == 1:
next = contents[0]._carry(index, True)
next = contents[0]._carry(index, False)
return next.copy(parameters=parameters_union(next._parameters, parameters))

else:
Expand Down Expand Up @@ -702,7 +702,7 @@ def project(self, index):
nextcarry = ak.index.Index64(
tmpcarry.data[: lenout[0]], nplike=self._backend.index_nplike
)
return self._contents[index]._carry(nextcarry, True)
return self._contents[index]._carry(nextcarry, False)

@staticmethod
def regular_index(
Expand Down
7 changes: 4 additions & 3 deletions tests/test_2713_from_buffers_allow_noncanonical.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ def test_union_simplification():
projected = ak.from_buffers(
projected_form, length, container, allow_noncanonical_form=True
)

assert projected.layout.form.to_dict(verbose=False) == {
"class": "IndexedArray",
"index": "i64",
"content": {"class": "RecordArray", "fields": ["x"], "contents": ["int64"]},
"class": "RecordArray",
"fields": ["x"],
"contents": ["int64"],
}
assert ak.almost_equal(array[["x"]], projected)
34 changes: 34 additions & 0 deletions tests/test_3118_prevent_exponential_memory_growth_in_unionarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak


def test():
one_a = ak.Array([{"x": 1, "y": 2}], with_name="T")
one_b = ak.Array([{"x": 1, "y": 2}], with_name="T")
two_a = ak.Array([{"x": 1, "z": 3}], with_name="T")
two_b = ak.Array([{"x": 1, "z": 3}], with_name="T")
three = ak.Array([{"x": 4}, {"x": 4}], with_name="T")

first = ak.zip({"a": one_a, "b": one_b})
second = ak.zip({"a": two_a, "b": two_b})

cat = ak.concatenate([first, second], axis=0)

cat["another"] = three

def check(layout):
if hasattr(layout, "contents"):
for x in layout.contents:
check(x)
elif hasattr(layout, "content"):
check(layout.content)
else:
assert layout.length <= 2

for _ in range(5):
check(cat.layout)

cat["another", "w"] = three.x

0 comments on commit 28a89da

Please sign in to comment.