Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add .attrs to highlevel objects #2757

Merged
merged 13 commits into from
Nov 8, 2023
44 changes: 44 additions & 0 deletions src/awkward/_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

from collections.abc import Mapping

from awkward._typing import Any, JSONMapping


def attrs_of_obj(obj, attrs: Mapping | None = None) -> Mapping | None:
    """Return the attrs mapping associated with a single object.

    Args:
        obj: Any object. Only high-level ``Array``, ``Record`` and
            ``ArrayBuilder`` instances contribute attrs.
        attrs: An explicit mapping; when not ``None`` it is returned as-is
            and ``obj`` is not inspected.

    Returns:
        ``attrs`` if given, otherwise ``obj._attrs`` for the high-level
        types listed above, otherwise ``None``.
    """
    # Function-scope import; presumably avoids a circular import with
    # awkward.highlevel — TODO confirm.
    from awkward.highlevel import Array, ArrayBuilder, Record

    if attrs is not None:
        return attrs

    highlevel_types = (Array, Record, ArrayBuilder)
    if isinstance(obj, highlevel_types):
        return obj._attrs

    return None


def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping | None:
    """Combine the attrs of several objects into a single mapping.

    Args:
        *arrays: Objects whose attrs (as reported by ``attrs_of_obj``) are
            merged. Objects without attrs are ignored.
        attrs: An explicit mapping. When not ``None`` it always wins and is
            returned unchanged, without looking at ``arrays``.

    Returns:
        The explicit ``attrs`` if given; otherwise the merged attrs of
        ``arrays``; or ``None`` when none of the arguments carries attrs
        (including the case of no arguments at all).

    The arguments are visited right-to-left and merged with ``update``, so
    on a key collision the left-most argument's value is the one kept. The
    result is only copied into a new ``dict`` when two *distinct* attrs
    objects have to be merged; otherwise the original mapping object is
    returned as-is, and no input mapping is ever mutated.
    """
    # An explicit 'attrs' always wins.
    if attrs is not None:
        return attrs

    copied = False
    for x in reversed(arrays):
        x_attrs = attrs_of_obj(x)
        if x_attrs is None:
            continue
        if attrs is None:
            # First attrs found: share the object, don't copy yet.
            attrs = x_attrs
        elif attrs is not x_attrs:
            if not copied:
                # Copy lazily so that the inputs' own mappings are untouched.
                attrs = dict(attrs)
                copied = True
            attrs.update(x_attrs)
    return attrs


def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping:
    """Return a new dict holding the entries of ``attrs`` whose key does not
    start with ``"@"`` (the prefix this module treats as marking a transient
    attribute). The input mapping is left unmodified.
    """
    persistent = {}
    for key, value in attrs.items():
        if key.startswith("@"):
            continue
        persistent[key] = value
    return persistent
8 changes: 4 additions & 4 deletions src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D:


def backend_of(
*objects, default: D | Sentinel = UNSET, coerce_to_common: bool = False
*objects, default: D | Sentinel = UNSET, coerce_to_common: bool = True
) -> Backend | D:
"""
Args:
Expand All @@ -116,9 +116,9 @@ def backend_of(
return common_backend(unique_backends)
else:
raise ValueError(
"could not find singular backend for",
objects,
"and coercion is not permitted",
f"could not find singular backend for "
f"{', '.join(type(t).__name__ for t in objects)} "
f"and coercion is not permitted",
)


Expand Down
69 changes: 48 additions & 21 deletions src/awkward/_connect/numba/arrayview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from numba.core.errors import NumbaTypeError

import awkward as ak
from awkward._behavior import behavior_of, overlay_behavior
from awkward._layout import wrap_layout
from awkward._behavior import overlay_behavior
from awkward._layout import HighLevelContext, wrap_layout
from awkward._lookup import Lookup
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()
Expand Down Expand Up @@ -152,7 +153,17 @@ def to_numbatype(form):
########## Lookup


@numba.extending.typeof_impl.register(ak._lookup.Lookup)
class NumbaLookup(Lookup):
    """A ``Lookup`` that additionally carries an ``attrs`` object.

    The extra state lets code holding only the lookup (e.g. a view being
    converted back to a high-level array) recover the attrs it was created
    with.
    """

    def __init__(self, layout, attrs, generator=None):
        # `layout` and `generator` are forwarded unchanged to Lookup.
        super().__init__(layout, generator=generator)
        self._attrs = attrs

    @property
    def attrs(self):
        # The attrs object given at construction; no setter is defined here.
        return self._attrs
jpivarski marked this conversation as resolved.
Show resolved Hide resolved


@numba.extending.typeof_impl.register(NumbaLookup)
def typeof_Lookup(obj, c):
    # Numba typeof hook registered for NumbaLookup: neither `obj` nor `c` is
    # inspected, so every instance maps to a fresh `LookupType()`.
    return LookupType()

Expand Down Expand Up @@ -192,15 +203,21 @@ def unbox_Lookup(lookuptype, lookupobj, c):
class ArrayView:
@classmethod
def fromarray(cls, array):
behavior = behavior_of(array)
layout = ak.operations.to_layout(
array, allow_record=False, allow_unknown=False, primitive_policy="error"
)
with HighLevelContext() as ctx:
layout = ctx.unwrap(
array,
allow_record=False,
allow_unknown=False,
use_from_iter=False,
primitive_policy="error",
string_policy="error",
none_policy="error",
)

return ArrayView(
to_numbatype(layout.form),
behavior,
ak._lookup.Lookup(layout),
ctx.behavior,
NumbaLookup(layout, ctx.attrs),
0,
0,
len(layout),
Expand All @@ -219,7 +236,7 @@ def __init__(self, type, behavior, lookup, pos, start, stop, fields):
def toarray(self):
layout = self.type.tolayout(self.lookup, self.pos, self.fields)
sliced = layout._getitem_range(self.start, self.stop)
return wrap_layout(sliced, self.behavior)
return wrap_layout(sliced, behavior=self.behavior, attrs=self.lookup.attrs)


@numba.extending.typeof_impl.register(ArrayView)
Expand Down Expand Up @@ -579,20 +596,28 @@ def lower_iternext(context, builder, sig, args, result):
class RecordView:
@classmethod
def fromrecord(cls, record):
behavior = behavior_of(record)
layout = ak.operations.to_layout(
record, allow_record=True, allow_unknown=False, primitive_policy="error"
)
with HighLevelContext() as ctx:
layout = ctx.unwrap(
record,
allow_record=True,
allow_unknown=False,
use_from_iter=False,
primitive_policy="error",
string_policy="error",
none_policy="error",
)
array_layout = layout.array

assert isinstance(layout, ak.record.Record)
arraylayout = layout.array

return RecordView(
ArrayView(
to_numbatype(arraylayout.form),
behavior,
ak._lookup.Lookup(arraylayout),
to_numbatype(array_layout.form),
ctx.behavior,
NumbaLookup(array_layout, ctx.attrs),
0,
0,
len(arraylayout),
len(array_layout),
(),
),
layout.at,
Expand All @@ -603,9 +628,11 @@ def __init__(self, arrayview, at):
self.at = at

def torecord(self):
arraylayout = self.arrayview.toarray().layout
array = self.arrayview.toarray()
return wrap_layout(
ak.record.Record(arraylayout, self.at), self.arrayview.behavior
ak.record.Record(array.layout, self.at),
behavior=self.arrayview.behavior,
attrs=array.attrs,
)


Expand Down
18 changes: 16 additions & 2 deletions src/awkward/_connect/numba/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ def __init__(self, behavior):
@numba.extending.register_model(ArrayBuilderType)
class ArrayBuilderModel(numba.core.datamodel.models.StructModel):
def __init__(self, dmm, fe_type):
members = [("rawptr", numba.types.voidptr), ("pyptr", numba.types.pyobject)]
members = [
("rawptr", numba.types.voidptr),
("pyptr", numba.types.pyobject),
("pyattrs", numba.types.pyobject),
]
super().__init__(dmm, fe_type, members)


@numba.core.imputils.lower_constant(ArrayBuilderType)
def lower_const_ArrayBuilder(context, builder, arraybuildertype, arraybuilder):
layout = arraybuilder._layout
attrs = arraybuilder._attrs
rawptr = context.get_constant(numba.intp, arraybuilder._layout._ptr)
proxyout = context.make_helper(builder, arraybuildertype)
proxyout.rawptr = builder.inttoptr(
Expand All @@ -52,20 +57,26 @@ def lower_const_ArrayBuilder(context, builder, arraybuildertype, arraybuilder):
proxyout.pyptr = context.add_dynamic_addr(
builder, id(layout), info=str(type(layout))
)
proxyout.pyattrs = context.add_dynamic_addr(
builder, id(attrs), info=str(type(attrs))
)
return proxyout._getvalue()


@numba.extending.unbox(ArrayBuilderType)
def unbox_ArrayBuilder(arraybuildertype, arraybuilderobj, c):
    # Unboxing hook for ArrayBuilderType: reads `_attrs`, `_layout` and
    # `_layout._ptr` from the Python object and packs them into the native
    # struct fields `pyattrs`, `pyptr` and `rawptr`.
    attrs_obj = c.pyapi.object_getattr_string(arraybuilderobj, "_attrs")
    inner_obj = c.pyapi.object_getattr_string(arraybuilderobj, "_layout")
    rawptr_obj = c.pyapi.object_getattr_string(inner_obj, "_ptr")

    proxyout = c.context.make_helper(c.builder, arraybuildertype)
    # `_ptr` is converted from a Python integer to a raw void pointer.
    proxyout.rawptr = c.pyapi.long_as_voidptr(rawptr_obj)
    proxyout.pyptr = inner_obj
    proxyout.pyattrs = attrs_obj

    # NOTE(review): the objects stored in `pyptr`/`pyattrs` are decref'd
    # right after being stored, so the struct appears to hold borrowed
    # references that rely on `arraybuilderobj` keeping them alive — confirm.
    c.pyapi.decref(inner_obj)
    c.pyapi.decref(rawptr_obj)
    c.pyapi.decref(attrs_obj)

    # Report a pending Python exception (if any) to Numba via `is_error`.
    is_error = numba.core.cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return numba.extending.NativeValue(proxyout._getvalue(), is_error)
Expand All @@ -90,8 +101,11 @@ def box_ArrayBuilder(arraybuildertype, arraybuilderval, c):

proxyin = c.context.make_helper(c.builder, arraybuildertype, arraybuilderval)
c.pyapi.incref(proxyin.pyptr)
attrs_obj = proxyin.pyattrs

out = c.pyapi.call_method(ArrayBuilder_obj, "_wrap", (proxyin.pyptr, behavior_obj))
out = c.pyapi.call_method(
ArrayBuilder_obj, "_wrap", (proxyin.pyptr, behavior_obj, attrs_obj)
)

c.pyapi.decref(ArrayBuilder_obj)
c.pyapi.decref(behavior_obj)
Expand Down
10 changes: 2 additions & 8 deletions src/awkward/_connect/numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def evaluate(
names, ex_uses_vml = numexpr.necompiler._names_cache[expr_key]
arguments = getArguments(names, local_dict, global_dict)

arrays = [
ak.operations.to_layout(x, allow_record=True, allow_unknown=True)
for x in arguments
]
arrays = [ak.operations.to_layout(x, allow_unknown=True) for x in arguments]

def action(inputs, **ignore):
if all(
Expand Down Expand Up @@ -131,10 +128,7 @@ def re_evaluate(local_dict=None):
names = numexpr.necompiler._numexpr_last["argnames"]
arguments = getArguments(names, local_dict)

arrays = [
ak.operations.to_layout(x, allow_record=True, allow_unknown=True)
for x in arguments
]
arrays = [ak.operations.to_layout(x, allow_unknown=True) for x in arguments]

def action(inputs, **ignore):
if all(
Expand Down
14 changes: 10 additions & 4 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ def _to_rectilinear(arg, backend: Backend):
return arg


def array_function(func, types, args, kwargs: dict[str, Any], behavior: Mapping | None):
def array_function(
func,
types,
args,
kwargs: dict[str, Any],
behavior: Mapping | None,
attrs: Mapping[str, Any] | None = None,
):
function = implemented.get(func)
if function is not None:
return function(*args, **kwargs)
Expand All @@ -106,13 +113,13 @@ def array_function(func, types, args, kwargs: dict[str, Any], behavior: Mapping
result,
allow_record=True,
allow_unknown=True,
allow_none=True,
none_policy="pass-through",
regulararray=True,
use_from_iter=True,
primitive_policy="pass-through",
string_policy="pass-through",
)
return wrap_layout(out, behavior=behavior, allow_other=True)
return wrap_layout(out, behavior=behavior, allow_other=True, attrs=attrs)


def implements(numpy_function):
Expand Down Expand Up @@ -152,7 +159,6 @@ def _array_ufunc_custom_cast(inputs, behavior: Mapping | None, backend):
cast_fcn = find_custom_cast(x, behavior)
maybe_layout = ak.operations.to_layout(
x if cast_fcn is None else cast_fcn(x),
allow_record=True,
allow_unknown=True,
primitive_policy="pass-through",
string_policy="pass-through",
Expand Down
Loading