Skip to content

Commit

Permalink
fix: control attrs better as described in issue #3277 (#3344)
Browse files Browse the repository at this point in the history
* control attrs better as described in issue #3277

* break cyclic ref with weakref

* ensure transients are strings

* style: pre-commit fixes

* fix doc string

Co-authored-by: Angus Hollands <[email protected]>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Angus Hollands <[email protected]>
  • Loading branch information
3 people authored Dec 18, 2024
1 parent 564126d commit 27faa82
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 57 deletions.
42 changes: 41 additions & 1 deletion src/awkward/_attrs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

import weakref
from collections.abc import Mapping
from types import MappingProxyType

from awkward._typing import Any, JSONMapping

Expand Down Expand Up @@ -41,4 +43,42 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping:


def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping:
return {k: v for k, v in attrs.items() if not k.startswith("@")}
return {
k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@"))
}


class Attrs(Mapping):
def __init__(self, ref, data: Mapping[str, Any]):
self._ref = weakref.ref(ref)
self._data = _freeze_attrs(data)

def __getitem__(self, key: str):
return self._data[key]

def __setitem__(self, key: str, value: Any):
ref = self._ref()
if ref is None:
msg = "The reference array has been deleted. If you still need to set attributes, convert this 'Attrs' instance to a dict with '.to_dict()'."
raise ValueError(msg)
ref._attrs = _unfreeze_attrs(self._data) | {key: value}

def __iter__(self):
return iter(self._data)

def __len__(self):
return len(self._data)

def __repr__(self):
return f"Attrs({_unfreeze_attrs(self._data)!r})"

def to_dict(self):
return _unfreeze_attrs(self._data)


def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]:
return MappingProxyType(attrs)


def _unfreeze_attrs(attrs: Mapping[str, Any]) -> dict[str, Any]:
return dict(attrs)
30 changes: 15 additions & 15 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import awkward as ak
import awkward._connect.hist
from awkward._attrs import attrs_of, without_transient_attrs
from awkward._attrs import Attrs, attrs_of, without_transient_attrs
from awkward._backends.dispatch import register_backend_lookup_factory
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of, get_array_class, get_record_class
Expand All @@ -42,7 +42,7 @@
unpickle_record_schema_1,
)
from awkward._regularize import is_non_string_like_iterable
from awkward._typing import Any, MutableMapping, TypeVar
from awkward._typing import Any, TypeVar
from awkward._util import STDOUT
from awkward.prettyprint import Formatter
from awkward.prettyprint import valuestr as prettyprint_valuestr
Expand Down Expand Up @@ -337,7 +337,7 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or a mapping")

if attrs is not None and not isinstance(attrs, MutableMapping):
if attrs is not None and not isinstance(attrs, Mapping):
raise TypeError("attrs must be None or a mapping")

if named_axis:
Expand Down Expand Up @@ -379,9 +379,9 @@ def _update_class(self):
self.__class__ = get_array_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping:
def attrs(self) -> Attrs:
"""
The mutable mapping containing top-level metadata, which is serialised
The mapping containing top-level metadata, which is serialised
with the array during pickling.
Keys prefixed with `@` are identified as "transient" attributes
Expand All @@ -390,14 +390,14 @@ def attrs(self) -> Mapping:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")
raise TypeError("attrs must be a 'Attrs' mapping")

@property
def layout(self):
Expand Down Expand Up @@ -1846,7 +1846,7 @@ def __init__(
if behavior is not None and not isinstance(behavior, Mapping):
raise TypeError("behavior must be None or mapping")

if attrs is not None and not isinstance(attrs, MutableMapping):
if attrs is not None and not isinstance(attrs, Mapping):
raise TypeError("attrs must be None or a mapping")

if named_axis:
Expand Down Expand Up @@ -1883,7 +1883,7 @@ def _update_class(self):
self.__class__ = get_record_class(self._layout, self._behavior)

@property
def attrs(self) -> Mapping[str, Any]:
def attrs(self) -> Attrs:
"""
The mapping containing top-level metadata, which is serialised
with the record during pickling.
Expand All @@ -1894,12 +1894,12 @@ def attrs(self) -> Mapping[str, Any]:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")

Expand Down Expand Up @@ -2672,7 +2672,7 @@ def _wrap(cls, layout, behavior=None, attrs=None):
return out

@property
def attrs(self) -> Mapping[str, Any]:
def attrs(self) -> Attrs:
"""
The mapping containing top-level metadata, which is serialised
with the array during pickling.
Expand All @@ -2683,12 +2683,12 @@ def attrs(self) -> Mapping[str, Any]:
"""
if self._attrs is None:
self._attrs = {}
return self._attrs
return Attrs(self, self._attrs)

@attrs.setter
def attrs(self, value: Mapping[str, Any]):
if isinstance(value, Mapping):
self._attrs = value
self._attrs = dict(value)
else:
raise TypeError("attrs must be a mapping")

Expand Down
52 changes: 26 additions & 26 deletions tests/test_2757_attrs_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_set_attrs():
assert array.attrs == {}

array.attrs = OTHER_ATTRS
assert array.attrs is OTHER_ATTRS
assert array.attrs == OTHER_ATTRS

with pytest.raises(TypeError):
array.attrs = "Hello world!"
Expand All @@ -52,7 +52,7 @@ def test_transient_metadata_persists():
attrs = {**SOME_ATTRS, "@transient_key": lambda: None}
array = ak.Array([[1, 2, 3]], attrs=attrs)
num = ak.num(array)
assert num.attrs is attrs
assert num.attrs == attrs


@pytest.mark.parametrize(
Expand All @@ -79,13 +79,13 @@ def test_single_arg_ops(func):
# Carry from argument
assert (
func([[1, 2, 3, 4], [5], [10]], axis=-1, highlevel=True, attrs=SOME_ATTRS).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array([[1, 2, 3, 4], [5], [10]], attrs=SOME_ATTRS)
assert func(array, axis=-1, highlevel=True).attrs is SOME_ATTRS
assert func(array, axis=-1, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, axis=-1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


@pytest.mark.parametrize(
Expand Down Expand Up @@ -134,15 +134,15 @@ def test_string_operations_unary(func):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, highlevel=True).attrs is SOME_ATTRS
assert func(array, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


@pytest.mark.parametrize(
Expand Down Expand Up @@ -188,15 +188,15 @@ def test_string_operations_unary_with_arg(func, arg):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, arg, highlevel=True).attrs is SOME_ATTRS
assert func(array, arg, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
assert func(array, arg, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS


def test_string_operations_unary_with_arg_slice():
Expand All @@ -220,16 +220,16 @@ def test_string_operations_unary_with_arg_slice():
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from outer array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert ak.str.slice(array, 1, highlevel=True).attrs is SOME_ATTRS
assert ak.str.slice(array, 1, highlevel=True).attrs == SOME_ATTRS
# Carry from argument exclusively
assert (
ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
ak.str.slice(array, 1, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS
)


Expand Down Expand Up @@ -262,13 +262,13 @@ def test_string_operations_binary(func):
highlevel=True,
attrs=SOME_ATTRS,
).attrs
is SOME_ATTRS
== SOME_ATTRS
)
# Carry from first array
array = ak.Array(
[["hello", "world!"], [], ["it's a beautiful day!"]], attrs=SOME_ATTRS
)
assert func(array, ["hello"], highlevel=True).attrs is SOME_ATTRS
assert func(array, ["hello"], highlevel=True).attrs == SOME_ATTRS

# Carry from second array
value_array = ak.Array(["hello"], attrs=OTHER_ATTRS)
Expand All @@ -278,7 +278,7 @@ def test_string_operations_binary(func):
value_array,
highlevel=True,
).attrs
is OTHER_ATTRS
== OTHER_ATTRS
)
# Carry from both arrays
assert func(
Expand All @@ -289,7 +289,7 @@ def test_string_operations_binary(func):

# Carry from argument
assert (
func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs is OTHER_ATTRS
func(array, value_array, highlevel=True, attrs=OTHER_ATTRS).attrs == OTHER_ATTRS
)


Expand All @@ -298,38 +298,38 @@ def test_broadcasting_arrays():
right = ak.Array([1], attrs=OTHER_ATTRS)

left_result, right_result = ak.broadcast_arrays(left, right)
assert left_result.attrs is SOME_ATTRS
assert right_result.attrs is OTHER_ATTRS
assert left_result.attrs == SOME_ATTRS
assert right_result.attrs == OTHER_ATTRS


def test_broadcasting_fields():
left = ak.Array([{"x": 1}, {"x": 2}], attrs=SOME_ATTRS)
right = ak.Array([{"y": 1}, {"y": 2}], attrs=OTHER_ATTRS)

left_result, right_result = ak.broadcast_fields(left, right)
assert left_result.attrs is SOME_ATTRS
assert right_result.attrs is OTHER_ATTRS
assert left_result.attrs == SOME_ATTRS
assert right_result.attrs == OTHER_ATTRS


def test_numba_arraybuilder():
numba = pytest.importorskip("numba")
builder = ak.ArrayBuilder(attrs=SOME_ATTRS)
assert builder.attrs is SOME_ATTRS
assert builder.attrs == SOME_ATTRS

@numba.njit
def func(array):
return array

assert func(builder).attrs is SOME_ATTRS
assert func(builder).attrs == SOME_ATTRS


def test_numba_array():
numba = pytest.importorskip("numba")
array = ak.Array([1, 2, 3], attrs=SOME_ATTRS)
assert array.attrs is SOME_ATTRS
assert array.attrs == SOME_ATTRS

@numba.njit
def func(array):
return array

assert func(array).attrs is SOME_ATTRS
assert func(array).attrs == SOME_ATTRS
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_ArrayBuilder_behavior():
SOME_ATTRS = {"FOO": "BAR"}
builder = ak.ArrayBuilder(behavior=SOME_ATTRS)

assert builder.behavior is SOME_ATTRS
assert builder.behavior == SOME_ATTRS
assert func(builder).behavior == SOME_ATTRS


Expand Down
4 changes: 2 additions & 2 deletions tests/test_2806_attrs_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_typetracer_with_report():
form = layout.form_with_key("node{id}")

meta, report = typetracer_with_report(form, highlevel=True, attrs=SOME_ATTRS)
assert meta.attrs is SOME_ATTRS
assert meta.attrs == SOME_ATTRS

meta, report = typetracer_with_report(form, highlevel=True, attrs=None)
assert meta._attrs is None
Expand All @@ -44,5 +44,5 @@ def test_function(function):
"z": [[0.1, 0.1, 0.2], [3, 1, 2], [2, 1, 2]],
}
)
assert function(array, attrs=SOME_ATTRS).attrs is SOME_ATTRS
assert function(array, attrs=SOME_ATTRS).attrs == SOME_ATTRS
assert function(array)._attrs is None
6 changes: 3 additions & 3 deletions tests/test_2837_ufunc_attrs_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def test():
def test_unary():
x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"})
y = -x
assert y.attrs is x.attrs
assert y.attrs == x.attrs
assert x.behavior is y.behavior


def test_two_return():
x = ak.Array([1, 2, 3], behavior={"foo": "BAR"}, attrs={"hello": "world"})
y, y_ret = divmod(x, 2)
assert y.attrs is y_ret.attrs
assert y.attrs is x.attrs
assert y.attrs == y_ret.attrs
assert y.attrs == x.attrs

assert y.behavior is y_ret.behavior
assert y.behavior is x.behavior
Loading

0 comments on commit 27faa82

Please sign in to comment.