Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update __class__ for both layout and behavior consistently #2759

Merged
merged 7 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 39 additions & 16 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,18 @@ def __init__(
if backend is not None and backend != ak.operations.backend(layout):
layout = ak.operations.to_backend(layout, backend, highlevel=False)

self.layout = layout
self.behavior = behavior
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = layout
self._behavior = behavior

docstr = layout.purelist_parameter("__doc__")
if isinstance(docstr, str):
self.__doc__ = docstr

self._update_class()

if check_valid:
ak.operations.validity_error(self, exception=True)

Expand All @@ -330,6 +335,10 @@ def __init_subclass__(cls, **kwargs):

_histogram_module_ = awkward._connect.hist

def _update_class(self):
self._numbaview = None
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def layout(self):
"""
Expand Down Expand Up @@ -377,7 +386,7 @@ def layout(self):
def layout(self, layout):
if isinstance(layout, ak.contents.Content):
self._layout = layout
self._numbaview = None
self._update_class()
else:
raise TypeError("layout must be a subclass of ak.contents.Content")

Expand All @@ -403,8 +412,8 @@ def behavior(self):
@behavior.setter
def behavior(self, behavior):
if behavior is None or isinstance(behavior, Mapping):
self.__class__ = get_array_class(self._layout, behavior)
self._behavior = behavior
self._update_class()
else:
raise TypeError("behavior must be None or a dict")

Expand Down Expand Up @@ -1516,8 +1525,9 @@ def __setstate__(self, state):
buffer_key="{form_key}-{attribute}",
byteorder="<",
)
self.layout = layout
self.behavior = behavior
self._layout = layout
self._behavior = behavior
self._update_class()

def __copy__(self):
return Array(self._layout, behavior=self._behavior)
Expand Down Expand Up @@ -1556,9 +1566,9 @@ def cpp_type(self):

if self._cpp_type is None:
self._generator = ak._connect.cling.togenerator(
self.layout.form, flatlist_as_rvec=False
self._layout.form, flatlist_as_rvec=False
)
self._lookup = ak._lookup.Lookup(self.layout)
self._lookup = ak._lookup.Lookup(self._layout)
self._generator.generate(cppyy.cppdef)
self._cpp_type = f"awkward::{self._generator.class_type()}"

Expand Down Expand Up @@ -1659,13 +1669,18 @@ def __init__(
if library is not None and library != ak.operations.library(layout):
layout = ak.operations.to_library(layout, library, highlevel=False)

self.layout = layout
self.behavior = behavior
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = layout
self._behavior = behavior

docstr = layout.purelist_parameter("__doc__")
if isinstance(docstr, str):
self.__doc__ = docstr

self._update_class()

if check_valid:
ak.operations.validity_error(self, exception=True)

Expand All @@ -1674,6 +1689,10 @@ def __init_subclass__(cls, **kwargs):

ak.jax.register_behavior_class(cls)

def _update_class(self):
self._numbaview = None
self.__class__ = get_record_class(self._layout, self._behavior)

@property
def layout(self):
"""
Expand Down Expand Up @@ -1715,7 +1734,7 @@ def layout(self):
def layout(self, layout):
if isinstance(layout, ak.record.Record):
self._layout = layout
self._numbaview = None
self._update_class()
else:
raise TypeError("layout must be a subclass of ak.record.Record")

Expand All @@ -1741,8 +1760,8 @@ def behavior(self):
@behavior.setter
def behavior(self, behavior):
if behavior is None or isinstance(behavior, Mapping):
self.__class__ = get_record_class(self._layout, behavior)
self._behavior = behavior
self._update_class()
else:
raise TypeError("behavior must be None or a dict")

Expand Down Expand Up @@ -2177,8 +2196,9 @@ def __setstate__(self, state):
byteorder="<",
)
layout = ak.record.Record(layout, at)
self.layout = layout
self.behavior = behavior
self._layout = layout
self._behavior = behavior
self._update_class()

def __copy__(self):
return Record(self._layout, behavior=self._behavior)
Expand Down Expand Up @@ -2329,8 +2349,11 @@ class ArrayBuilder(Sized):
"""

def __init__(self, *, behavior=None, initial=1024, resize=8):
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

self._layout = _ext.ArrayBuilder(initial=initial, resize=resize)
self.behavior = behavior
self._behavior = behavior

@classmethod
def _wrap(cls, layout, behavior=None):
Expand All @@ -2350,7 +2373,7 @@ def _wrap(cls, layout, behavior=None):
assert isinstance(layout, _ext.ArrayBuilder)
out = cls.__new__(cls)
out._layout = layout
out.behavior = behavior
out._behavior = behavior
return out

@property
Expand Down
53 changes: 53 additions & 0 deletions tests/test_2759_update_class_consistently.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE


import awkward as ak


class ArrayBehavior(ak.Array):
def impl(self):
return True


class RecordBehavior(ak.Record):
def impl(self):
return True


BEHAVIOR = {("*", "impl"): ArrayBehavior, "impl": RecordBehavior}


def test_array_layout():
array = ak.Array([{"x": 1}, {"y": 3}], behavior=BEHAVIOR)
assert not isinstance(array, ArrayBehavior)

array.layout = ak.with_name([{"x": 1}, {"y": 3}], "impl", highlevel=False)
assert isinstance(array, ArrayBehavior)
assert array.impl()


def test_array_behavior():
array = ak.Array([{"x": 1}, {"y": 3}], with_name="impl")
assert not isinstance(array, ArrayBehavior)

array.behavior = BEHAVIOR
assert isinstance(array, ArrayBehavior)
assert array.impl()


def test_record_layout():
record = ak.Record({"x": 1}, behavior=BEHAVIOR)
assert not isinstance(record, RecordBehavior)

record.layout = ak.with_name({"x": 1}, "impl", highlevel=False)
assert isinstance(record, RecordBehavior)
assert record.impl()


def test_record_behavior():
record = ak.Record({"x": 1}, with_name="impl")
assert not isinstance(record, RecordBehavior)

record.behavior = BEHAVIOR
assert isinstance(record, RecordBehavior)
assert record.impl()