diff --git a/src/awkward/_attrs.py b/src/awkward/_attrs.py index cf7a84f9fc..1a904ba9a9 100644 --- a/src/awkward/_attrs.py +++ b/src/awkward/_attrs.py @@ -43,15 +43,15 @@ def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping: def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping: - return { - k: v for k, v in attrs.items() if not (isinstance(k, str) and k.startswith("@")) - } + return {k: v for k, v in attrs.items() if not k.startswith("@")} class Attrs(Mapping): def __init__(self, ref, data: Mapping[str, Any]): self._ref = weakref.ref(ref) - self._data = _freeze_attrs(data) + self._data = _freeze_attrs( + {_enforce_str_key(k): v for k, v in _unfreeze_attrs(data).items()} + ) def __getitem__(self, key: str): return self._data[key] @@ -61,7 +61,7 @@ def __setitem__(self, key: str, value: Any): if ref is None: msg = "The reference array has been deleted. If you still need to set attributes, convert this 'Attrs' instance to a dict with '.to_dict()'." raise ValueError(msg) - ref._attrs = _unfreeze_attrs(self._data) | {key: value} + ref._attrs = _unfreeze_attrs(self._data) | {_enforce_str_key(key): value} def __iter__(self): return iter(self._data) @@ -76,6 +76,12 @@ def to_dict(self): return _unfreeze_attrs(self._data) +def _enforce_str_key(key: Any) -> str: + if not isinstance(key, str): + raise TypeError(f"'attrs' keys must be strings, got: {key!r}") + return key + + def _freeze_attrs(attrs: Mapping[str, Any]) -> Mapping[str, Any]: return MappingProxyType(attrs) diff --git a/tests/test_3350_enforce_attrs_string_keys.py b/tests/test_3350_enforce_attrs_string_keys.py new file mode 100644 index 0000000000..f07ac6b3a5 --- /dev/null +++ b/tests/test_3350_enforce_attrs_string_keys.py @@ -0,0 +1,18 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import pytest + +import awkward as ak + + +def test(): + arr = ak.Array([1]) + + with pytest.raises( + TypeError, + match="'attrs' keys must be strings, got: 1", + ): + arr.attrs[1] = "foo"