Skip to content

Commit

Permalink
feat: infer unknown lengths from context in from_buffers (#2732)
Browse files Browse the repository at this point in the history
* fix: don't check lengths during typetracer time

* fix: don't touch shape for ndim

* fix: infer unknown lengths from context in from_buffers

* test: update tests

* fix: ensure we touch all shapes if length is unknown

* fix: allow known-buffers in unknown-length contexts

* test: update tests
  • Loading branch information
agoose77 authored Oct 3, 2023
1 parent 5130158 commit 626ff08
Show file tree
Hide file tree
Showing 10 changed files with 393 additions and 265 deletions.
1 change: 0 additions & 1 deletion src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,6 @@ def nplike(self) -> TypeTracer:

@property
def ndim(self) -> int:
self.touch_shape()
return len(self._shape)

@property
Expand Down
19 changes: 11 additions & 8 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,22 @@ def __init__(
type(self).__name__, repr(valid_when)
)
)
if length is not unknown_length:
if not (is_integer(length) and length >= 0):
raise TypeError(
"{} 'length' must be a non-negative integer, not {}".format(
type(self).__name__, length
)
if length is not unknown_length and not (is_integer(length) and length >= 0):
raise TypeError(
"{} 'length' must be a non-negative integer, not {}".format(
type(self).__name__, length
)
)
if not isinstance(lsb_order, bool):
raise TypeError(
"{} 'lsb_order' must be boolean, not {}".format(
type(self).__name__, repr(lsb_order)
)
)
if (
not (length is unknown_length or mask.length is unknown_length)
content.backend.index_nplike.known_data
and length is not unknown_length
and mask.length is not unknown_length
and length > mask.length * 8
):
raise ValueError(
Expand All @@ -170,7 +171,9 @@ def __init__(
)
)
if (
not (length is unknown_length or content.length is unknown_length)
content.backend.index_nplike.known_data
and length is not unknown_length
and mask.length is not unknown_length
and length > content.length * 8
):
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def __init__(self, mask, content, valid_when, *, parameters=None):
)
)
if (
not (mask.length is unknown_length or content.length is unknown_length)
content.backend.index_nplike.known_data
and mask.length is not unknown_length
and content.length is not unknown_length
and mask.length > content.length
):
raise ValueError(
Expand Down
6 changes: 1 addition & 5 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,7 @@ def __init__(self, starts, stops, content, *, parameters=None):
type(self).__name__, repr(content)
)
)
if (
starts.nplike.known_data
and stops.nplike.known_data
and starts.length > stops.length
):
if content.backend.index_nplike.known_data and starts.length > stops.length:
raise ValueError(
"{} len(starts) ({}) must be <= len(stops) ({})".format(
type(self).__name__, starts.length, stops.length
Expand Down
6 changes: 5 additions & 1 deletion src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ def __init__(self, offsets, content, *, parameters=None):
type(self).__name__, repr(content)
)
)
if offsets.length is not unknown_length and offsets.length == 0:
if (
content.backend.index_nplike.known_data
and offsets.length is not unknown_length
and offsets.length == 0
):
raise ValueError(
f"{type(self).__name__} len(offsets) ({offsets.length}) must be >= 1"
)
Expand Down
74 changes: 40 additions & 34 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
parameters_intersect,
type_parameters_equal,
)
from awkward._regularize import is_integer
from awkward._slicing import NO_HEAD
from awkward._typing import TYPE_CHECKING, Callable, Final, Self, SupportsIndex, final
from awkward._util import UNSET
Expand Down Expand Up @@ -145,8 +144,15 @@ def __init__(
parameters=None,
backend=None,
):
if not (length is None or length is unknown_length):
length = int(length) # TODO: this should not happen!
if length is not None and length is not unknown_length:
try:
length = int(length) # TODO: this should not happen!
except TypeError:
raise TypeError(
"{} 'length' must be a non-negative integer or None, not {}".format(
type(self).__name__, repr(length)
)
) from None
if not isinstance(contents, Iterable):
raise TypeError(
"{} 'contents' must be iterable, not {}".format(
Expand All @@ -156,6 +162,7 @@ def __init__(
if not isinstance(contents, list):
contents = list(contents)

# Take backend from contents
for content in contents:
if not isinstance(content, Content):
raise TypeError(
Expand All @@ -179,53 +186,50 @@ def __init__(
backend = NumpyBackend.instance()

if length is None:
# Require a length if we have no contents
if len(contents) == 0:
raise TypeError(
"{} if len(contents) == 0, a 'length' must be specified".format(
type(self).__name__
)
)

if backend.nplike.known_data:
for content in contents:
assert content.length is not unknown_length
# First time we're setting length, and content.length is not unknown_length
if length is None:
length = content.length
# length is not unknown_length, content.length is not unknown_length
else:
length = min(length, content.length)
else:
for content in contents:
# First time we're setting length, and content.length is not unknown_length
if length is None:
length = content.length
# Any unknown_length means all unknown_length
if length is unknown_length:
break
# `length` is set, can't be unknown_length
elif content.length is unknown_length:
length = unknown_length
# Take length as minimum length of contents. This will touch shapes
it_contents = iter(contents)
for content in it_contents:
# First time we're setting length, and content.length is not unknown_length
if length is None:
length = content.length
# Any unknown_length means all unknown_length
if length is unknown_length:
break
# `length` is set, can't be unknown_length
else:
length = min(length, content.length)
# `length` is set, can't be unknown_length
elif content.length is unknown_length:
length = unknown_length
break
# `length` is set, can't be unknown_length
else:
length = min(length, content.length)

# Touch everything else
for content in it_contents:
content._touch_shape(False)

# Otherwise
elif length is not unknown_length:
# Ensure lengths are not smaller than given length.
for content in contents:
if content.length is not unknown_length and content.length < length:
if (
backend.index_nplike.known_data
and content.length is not unknown_length
and content.length < length
):
raise ValueError(
"{} len(content) ({}) must be >= length ({}) for all 'contents'".format(
type(self).__name__, content.length, length
)
)

if not (is_integer(length) and length >= 0):
raise TypeError(
"{} 'length' must be a non-negative integer or None, not {}".format(
type(self).__name__, repr(length)
)
)

if isinstance(fields, Iterable):
if not isinstance(fields, list):
fields = list(fields)
Expand All @@ -250,6 +254,8 @@ def __init__(

self._contents = contents
self._fields = fields
# TODO: maybe need to store original `length` arg separately to the
# computed version (for typetracer conversions)
self._length = length
self._init(parameters, backend)

Expand Down
31 changes: 12 additions & 19 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,28 +120,21 @@ def __init__(self, content, size, zeros_length=0, *, parameters=None):
type(self).__name__
)
)
else:
if not (is_integer(size) and size >= 0):
raise TypeError(
"{} 'size' must be a non-negative integer, not {}".format(
type(self).__name__, size
)
elif not (is_integer(size) and size >= 0):
raise TypeError(
"{} 'size' must be a non-negative integer, not {}".format(
type(self).__name__, size
)
)

if zeros_length is unknown_length:
if content.backend.index_nplike.known_data:
raise TypeError(
"{} 'zeros_length' must be a non-negative integer for backends with known shapes, not None".format(
type(self).__name__
)
)
else:
if not (is_integer(zeros_length) and zeros_length >= 0):
raise TypeError(
"{} 'zeros_length' must be a non-negative integer, not {}".format(
type(self).__name__, zeros_length
)
if zeros_length is not unknown_length and not (
is_integer(zeros_length) and zeros_length >= 0
):
raise TypeError(
"{} 'zeros_length' must be a non-negative integer, not {}".format(
type(self).__name__, zeros_length
)
)

if parameters is not None and parameters.get("__array__") == "string":
if not content.is_numpy or not content.parameter("__array__") == "char":
Expand Down
22 changes: 12 additions & 10 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,6 @@ def __init__(self, tags, index, contents, *, parameters=None):
"try {0}.simplified instead".format(type(self).__name__)
)

if (
not (tags.length is unknown_length or index.length is unknown_length)
and tags.length > index.length
):
raise ValueError(
"{} len(tags) ({}) must be <= len(index) ({})".format(
type(self).__name__, tags.length, index.length
)
)

backend = None
for content in contents:
if backend is None:
Expand All @@ -171,6 +161,18 @@ def __init__(self, tags, index, contents, *, parameters=None):
)
)

if (
backend.index_nplike.known_data
and tags.length is not unknown_length
and index.length is not unknown_length
and tags.length > index.length
):
raise ValueError(
"{} len(tags) ({}) must be <= len(index) ({})".format(
type(self).__name__, tags.length, index.length
)
)

assert tags.nplike is backend.index_nplike
assert index.nplike is backend.index_nplike

Expand Down
14 changes: 6 additions & 8 deletions src/awkward/operations/ak_from_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,18 @@ def _from_buffer(
# for the parent of this node. Thus, this node and its children *must* only
# contain placeholders
if count is unknown_length:
if not isinstance(buffer, PlaceholderArray):
raise AssertionError("Encountered unknown length for concrete buffer")
# We may actually have a known buffer here, but as we do not know the length,
# we cannot safely trim it. Thus, introduce a placeholder anyway
return PlaceholderArray(nplike, (unknown_length,), dtype)
# Known-length information implies that we should have known-length buffers here
# Therefore, placeholders without shape information are not permitted
# We could choose to make this an error, and have the caller re-implement some
# of #ak.from_buffers, or we can just introduce the known lengths where possible
elif isinstance(buffer, PlaceholderArray) and buffer.size is unknown_length:
return PlaceholderArray(nplike, (count,), dtype)
elif isinstance(buffer, PlaceholderArray) or nplike.is_own_array(buffer):
# Require 1D buffers
array = nplike.reshape(buffer.view(dtype), shape=(-1,), copy=False)

# Raise if the buffer we encountered isn't definitely-sized
if array.size is unknown_length:
raise AssertionError(
"Encountered unknown length for placeholder in context where length should be known"
)
if array.size < count:
raise TypeError(
f"size of array ({array.size}) is less than size of form ({count})"
Expand Down
Loading

0 comments on commit 626ff08

Please sign in to comment.