Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add recursive argument to expected_from_buffers #2724

Merged
merged 5 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/awkward/forms/bitmaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "mask"), index_to_dtype[self._mask])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/bytemaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "mask"), index_to_dtype[self._mask])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
2 changes: 1 addition & 1 deletion src/awkward/forms/emptyform.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,6 @@ def __setstate__(self, state):
self.__init__(form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield from ()
15 changes: 11 additions & 4 deletions src/awkward/forms/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,18 +657,25 @@ def prepare(form, multiplier):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
raise NotImplementedError

def expected_from_buffers(self, buffer_key="{form_key}-{attribute}"):
def expected_from_buffers(
self, buffer_key="{form_key}-{attribute}", recursive=True
):
"""
Args:
buffer_key (str or callable): Python format string containing
`"{form_key}"` and/or `"{attribute}"` or a function that takes these
as keyword arguments and returns a string to use as a key for a buffer
in the `container`.
recursive (bool): If True, recurse into subforms; otherwise, yield
only the (buffer_key, dtype) pairs for this form object.

Yield (buffer_key, dtype) pairs describing the expected buffer keys,
and their corresponding dtypes, that a call to #ak.from_buffers would
be expected to find from the `container` object.
"""
getkey = regularize_buffer_key(buffer_key)

return dict(self._expected_from_buffers(getkey))
return dict(self._expected_from_buffers(getkey, recursive))
5 changes: 3 additions & 2 deletions src/awkward/forms/indexedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def __setstate__(self, state):
self.__init__(index, content, parameters=parameters, form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "index"), index_to_dtype[self._index])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/indexedoptionform.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def __setstate__(self, state):
self.__init__(index, content, parameters=parameters, form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "index"), index_to_dtype[self._index])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/listform.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "starts"), index_to_dtype[self._starts])
yield (getkey(self, "stops"), index_to_dtype[self._stops])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/listoffsetform.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def __setstate__(self, state):
self.__init__(offsets, content, parameters=parameters, form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "offsets"), index_to_dtype[self._offsets])
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
2 changes: 1 addition & 1 deletion src/awkward/forms/numpyform.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
from awkward.types.numpytype import primitive_to_dtype

Expand Down
7 changes: 4 additions & 3 deletions src/awkward/forms/recordform.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
for content in self._contents:
yield from content._expected_from_buffers(getkey)
if recursive:
for content in self._contents:
yield from content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/regularform.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def __setstate__(self, state):
self.__init__(content, size, parameters=parameters, form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
7 changes: 4 additions & 3 deletions src/awkward/forms/unionform.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,10 @@ def __setstate__(self, state):
)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield (getkey(self, "tags"), index_to_dtype[self._tags])
yield (getkey(self, "index"), index_to_dtype[self._index])
for content in self._contents:
yield from content._expected_from_buffers(getkey)
if recursive:
for content in self._contents:
yield from content._expected_from_buffers(getkey, recursive)
5 changes: 3 additions & 2 deletions src/awkward/forms/unmaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def __setstate__(self, state):
self.__init__(content, parameters=parameters, form_key=form_key)

def _expected_from_buffers(
self, getkey: Callable[[Form, str], str]
self, getkey: Callable[[Form, str], str], recursive: bool
) -> Iterator[tuple[str, np.dtype]]:
yield from self._content._expected_from_buffers(getkey)
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)
29 changes: 29 additions & 0 deletions tests/test_2724_expected_from_buffers_recursive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import numpy as np
import pytest # noqa: F401

import awkward as ak

# 6.6 and 7.7 are inaccessible
layout = ak.contents.listoffsetarray.ListOffsetArray(
ak.index.Index(np.array([1, 4, 4, 6], dtype=np.int64)),
ak.contents.numpyarray.NumpyArray(
np.array([6.6, 1.1, 2.2, 3.3, 4.4, 5.5, 7.7], dtype=np.float64)
),
)


def test_recursive():
form, length, container = ak.to_buffers(layout)
assert form.expected_from_buffers(recursive=True) == {
"node0-offsets": np.dtype("int64"),
"node1-data": np.dtype("float64"),
}


def test_non_recursive():
form, length, container = ak.to_buffers(layout)
assert form.expected_from_buffers(recursive=False) == {
"node0-offsets": np.dtype("int64")
}