Skip to content

Commit

Permalink
feat!: move ufunc preparation to nplike (#2767)
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 authored Oct 24, 2023
1 parent 8927adb commit 97cf35f
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
rev: v2.2.6
hooks:
- id: codespell
args: ["-L", "ue,subjet,parms,fo,numer,thre"]
args: ["-L", "ue,subjet,parms,fo,numer,thre,nin,nout"]

- repo: local
hooks:
Expand Down
12 changes: 2 additions & 10 deletions src/awkward/_backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import awkward as ak
from awkward._kernels import KernelError
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import ArrayLike, NumpyLike, NumpyMetadata
from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata
from awkward._singleton import PublicSingleton
from awkward._typing import Callable, Protocol, Tuple, TypeAlias, TypeVar, Unpack
from awkward._typing import Callable, Tuple, TypeAlias, TypeVar, Unpack

np = NumpyMetadata.instance()
numpy = Numpy.instance()
Expand All @@ -19,11 +19,6 @@
KernelType: TypeAlias = "Callable[..., KernelError | None]"


class UfuncLike(Protocol):
def __call__(self, *args: ArrayLike, **kwargs) -> ArrayLike:
...


class Backend(PublicSingleton, ABC):
name: str

Expand All @@ -43,9 +38,6 @@ def __getitem__(self, key: KernelKeyType) -> KernelType:
def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer:
return reducer

def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
return ufunc

def format_kernel_error(
self,
error: KernelError,
Expand Down
7 changes: 1 addition & 6 deletions src/awkward/_backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import awkward_cpp

import awkward as ak
from awkward._backends.backend import Backend, KernelKeyType, UfuncLike
from awkward._backends.backend import Backend, KernelKeyType
from awkward._backends.dispatch import register_backend
from awkward._kernels import JaxKernel
from awkward._nplikes.jax import Jax
Expand Down Expand Up @@ -43,8 +43,3 @@ def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer
from awkward._connect.jax import get_jax_reducer

return get_jax_reducer(reducer)

def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
from awkward._connect.jax import get_jax_ufunc

return get_jax_ufunc(ufunc)
9 changes: 2 additions & 7 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,8 @@ def action(inputs, **ignore):
parameters_intersect, (c._parameters for c in contents)
)

args = [x.data if isinstance(x, NumpyArray) else x for x in inputs]

# Give backend a chance to change the ufunc implementation
impl = backend.prepare_ufunc(ufunc)

# Invoke ufunc
result = impl(*args, **kwargs)
input_args = [x.data if isinstance(x, NumpyArray) else x for x in inputs]
result = backend.nplike.apply_ufunc(ufunc, method, input_args, kwargs)

if isinstance(result, tuple):
return tuple(
Expand Down
27 changes: 27 additions & 0 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
IndexType,
NumpyLike,
NumpyMetadata,
UfuncLike,
UniqueAllResult,
)
from awkward._nplikes.placeholder import PlaceholderArray
Expand All @@ -22,6 +23,9 @@
class ArrayModuleNumpyLike(NumpyLike):
known_data: Final = True

def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
return ufunc

############################ array creation

def asarray(
Expand Down Expand Up @@ -146,6 +150,29 @@ def searchsorted(

############################ manipulation

def apply_ufunc(
self,
ufunc: UfuncLike,
method: str,
args: list[Any],
kwargs: dict[str, Any] | None = None,
) -> ArrayLike | tuple[ArrayLike]:
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes)
]
# Broadcast these resolved arguments
broadcasted_args = self.broadcast_arrays(*resolved_args)
# Allow other nplikes to replace implementation
impl = self.prepare_ufunc(ufunc)
# Compute the result
return impl(*broadcasted_args, **kwargs)

def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]:
assert not any(isinstance(x, PlaceholderArray) for x in arrays)
return self._module.broadcast_arrays(*arrays)
Expand Down
7 changes: 6 additions & 1 deletion src/awkward/_nplikes/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import awkward as ak
from awkward._nplikes.array_module import ArrayModuleNumpyLike
from awkward._nplikes.dispatch import register_nplike
from awkward._nplikes.numpylike import ArrayLike
from awkward._nplikes.numpylike import ArrayLike, UfuncLike
from awkward._nplikes.shape import ShapeItem
from awkward._typing import Final

Expand All @@ -18,6 +18,11 @@ def __init__(self):
jax = ak.jax.import_jax()
self._module = jax.numpy

def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
from awkward._connect.jax import get_jax_ufunc

return get_jax_ufunc(ufunc)

@property
def ma(self):
raise ValueError(
Expand Down
24 changes: 24 additions & 0 deletions src/awkward/_nplikes/numpylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,33 @@ def issubdtype(self):
NumpyMetadata.timedelta64 = numpy.timedelta64


class UfuncLike(Protocol):
nargs: int
nin: int
nout: int

def resolve_dtypes(
self, dtypes: tuple[numpy.dtype | type, ...]
) -> tuple[numpy.dtype, ...]:
...

def __call__(self, *args: ArrayLike, **kwargs) -> ArrayLike:
...


class NumpyLike(PublicSingleton, Protocol):
############################ Awkward features

@abstractmethod
def apply_ufunc(
self,
ufunc: UfuncLike,
method: str,
args: list[Any],
kwargs: dict[str, Any] | None = None,
) -> ArrayLike | tuple[ArrayLike]:
...

@property
@abstractmethod
def supports_structured_dtypes(self) -> bool:
Expand Down
79 changes: 40 additions & 39 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
IndexType,
NumpyLike,
NumpyMetadata,
UfuncLike,
UniqueAllResult,
)
from awkward._nplikes.placeholder import PlaceholderArray
Expand Down Expand Up @@ -387,8 +388,7 @@ def __getitem__(
try_touch_data(item)
try_touch_data(self)

if is_unknown_scalar(item):
item = self.nplike.promote_scalar(item)
item = self.nplike.asarray(item)

# If this is the first advanced index, insert the location
if not advanced_shapes:
Expand All @@ -415,7 +415,7 @@ def __getitem__(
try_touch_data(item)
try_touch_data(self)

item = self.nplike.promote_scalar(item)
item = self.nplike.asarray(item)

if is_unknown_length(dimension_length) or is_unknown_integer(item):
continue
Expand Down Expand Up @@ -469,7 +469,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

if len(kwargs) > 0:
raise ValueError("TypeTracerArray does not support kwargs for ufuncs")
return self.nplike._apply_ufunc(ufunc, *inputs)
return self.nplike.apply_ufunc(ufunc, method, inputs, kwargs)

def __bool__(self) -> bool:
raise RuntimeError("cannot realise an unknown value")
Expand Down Expand Up @@ -504,26 +504,39 @@ class TypeTracer(NumpyLike):
is_eager: Final = True
supports_structured_dtypes: Final = True

def _apply_ufunc(self, ufunc, *inputs):
for x in inputs:
assert not isinstance(x, PlaceholderArray)
def apply_ufunc(
self,
ufunc: UfuncLike,
method: str,
args: list[Any],
kwargs: dict[str, Any] | None = None,
) -> TypeTracerArray | tuple[TypeTracerArray]:
for x in args:
try_touch_data(x)

inputs = [x.content if isinstance(x, MaybeNone) else x for x in inputs]

broadcasted = self.broadcast_arrays(*inputs)
placeholders = [numpy.empty(0, x.dtype) for x in broadcasted]

result = ufunc(*placeholders)
if isinstance(result, numpy.ndarray):
return TypeTracerArray._new(result.dtype, shape=broadcasted[0].shape)
elif isinstance(result, tuple):
return (
TypeTracerArray._new(x.dtype, shape=b.shape)
for x, b in zip(result, broadcasted)
# Unwrap options, assume they don't occur
args = [x.content if isinstance(x, MaybeNone) else x for x in args]
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes)
]
# Broadcast these resolved arguments
broadcasted_args = self.broadcast_arrays(*resolved_args)
result_dtypes = resolved_dtypes[ufunc.nin :]
if len(result_dtypes) == 1:
return TypeTracerArray._new(
result_dtypes[0], shape=broadcasted_args[0].shape
)
else:
raise TypeError
return (
TypeTracerArray._new(dtype, shape=b.shape)
for dtype, b in zip(result_dtypes, broadcasted_args)
)

def _axis_is_valid(self, axis: int, ndim: int) -> bool:
if axis < 0:
Expand Down Expand Up @@ -801,18 +814,6 @@ def searchsorted(
return TypeTracerArray._new(x.dtype, (values.size,))

############################ manipulation

def promote_scalar(self, obj) -> TypeTracerArray:
assert not isinstance(obj, PlaceholderArray)
if is_unknown_scalar(obj):
return obj
elif isinstance(obj, (Number, bool)):
# TODO: statically define these types for all nplikes
as_array = numpy.asarray(obj)
return TypeTracerArray._new(as_array.dtype, ())
else:
raise TypeError(f"expected scalar type, received {obj}")

def shape_item_as_index(self, x1: ShapeItem) -> IndexType:
if x1 is unknown_length:
return TypeTracerArray._new(np.int64, shape=())
Expand Down Expand Up @@ -1206,7 +1207,7 @@ def add(
maybe_out: ArrayLike | None = None,
) -> TypeTracerArray:
assert not isinstance(x1, PlaceholderArray)
return self._apply_ufunc(numpy.add, x1, x2)
return self.apply_ufunc(numpy.add, "__call__", (x1, x2))

def logical_and(
self,
Expand All @@ -1215,7 +1216,7 @@ def logical_and(
maybe_out: ArrayLike | None = None,
) -> TypeTracerArray:
assert not isinstance(x1, PlaceholderArray)
return self._apply_ufunc(numpy.logical_and, x1, x2)
return self.apply_ufunc(numpy.logical_and, "__call__", (x1, x2))

def logical_or(
self,
Expand All @@ -1225,21 +1226,21 @@ def logical_or(
) -> TypeTracerArray:
assert not isinstance(x1, PlaceholderArray)
assert not isinstance(x2, PlaceholderArray)
return self._apply_ufunc(numpy.logical_or, x1, x2)
return self.apply_ufunc(numpy.logical_or, "__call__", (x1, x2))

def logical_not(
self, x: ArrayLike, maybe_out: ArrayLike | None = None
) -> TypeTracerArray:
assert not isinstance(x, PlaceholderArray)
return self._apply_ufunc(numpy.logical_not, x)
return self.apply_ufunc(numpy.logical_not, "__call__", (x,))

def sqrt(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray:
assert not isinstance(x, PlaceholderArray)
return self._apply_ufunc(numpy.sqrt, x)
return self.apply_ufunc(numpy.sqrt, "__call__", (x,))

def exp(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray:
assert not isinstance(x, PlaceholderArray)
return self._apply_ufunc(numpy.exp, x)
return self.apply_ufunc(numpy.exp, "__call__", (x,))

def divide(
self,
Expand All @@ -1249,7 +1250,7 @@ def divide(
) -> TypeTracerArray:
assert not isinstance(x1, PlaceholderArray)
assert not isinstance(x2, PlaceholderArray)
return self._apply_ufunc(numpy.divide, x1, x2)
return self.apply_ufunc(numpy.divide, "__call__", (x1, x2))

############################ almost-ufuncs

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def prepare_advanced_indexing(items, backend: Backend):

# Then broadcast the index items
nplike = backend.index_nplike
broadcasted = nplike.broadcast_arrays(*broadcastable)
broadcasted = nplike.broadcast_arrays(*[nplike.asarray(x) for x in broadcastable])

# And re-assemble the index with the broadcasted items
prepared = []
Expand Down

0 comments on commit 97cf35f

Please sign in to comment.