From 97cf35fccb5da7a0f164e98262ba17467267d282 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Tue, 24 Oct 2023 18:00:10 +0100 Subject: [PATCH] feat!: move ufunc preparation to nplike (#2767) --- .pre-commit-config.yaml | 2 +- src/awkward/_backends/backend.py | 12 +---- src/awkward/_backends/jax.py | 7 +-- src/awkward/_connect/numpy.py | 9 +--- src/awkward/_nplikes/array_module.py | 27 ++++++++++ src/awkward/_nplikes/jax.py | 7 ++- src/awkward/_nplikes/numpylike.py | 24 +++++++++ src/awkward/_nplikes/typetracer.py | 79 ++++++++++++++-------------- src/awkward/_slicing.py | 2 +- 9 files changed, 104 insertions(+), 65 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 569c621083..b17142d4f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: rev: v2.2.6 hooks: - id: codespell - args: ["-L", "ue,subjet,parms,fo,numer,thre"] + args: ["-L", "ue,subjet,parms,fo,numer,thre,nin,nout"] - repo: local hooks: diff --git a/src/awkward/_backends/backend.py b/src/awkward/_backends/backend.py index 5df8283181..f7a927c4ce 100644 --- a/src/awkward/_backends/backend.py +++ b/src/awkward/_backends/backend.py @@ -6,9 +6,9 @@ import awkward as ak from awkward._kernels import KernelError from awkward._nplikes.numpy import Numpy -from awkward._nplikes.numpylike import ArrayLike, NumpyLike, NumpyMetadata +from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata from awkward._singleton import PublicSingleton -from awkward._typing import Callable, Protocol, Tuple, TypeAlias, TypeVar, Unpack +from awkward._typing import Callable, Tuple, TypeAlias, TypeVar, Unpack np = NumpyMetadata.instance() numpy = Numpy.instance() @@ -19,11 +19,6 @@ KernelType: TypeAlias = "Callable[..., KernelError | None]" -class UfuncLike(Protocol): - def __call__(self, *args: ArrayLike, **kwargs) -> ArrayLike: - ... - - class Backend(PublicSingleton, ABC): name: str @@ -43,9 +38,6 @@ def __getitem__(self, key: KernelKeyType) -> KernelType: def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer: return reducer - def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike: - return ufunc - def format_kernel_error( self, error: KernelError, diff --git a/src/awkward/_backends/jax.py b/src/awkward/_backends/jax.py index 1a825d4a9a..ef63dca017 100644 --- a/src/awkward/_backends/jax.py +++ b/src/awkward/_backends/jax.py @@ -4,7 +4,7 @@ import awkward_cpp import awkward as ak -from awkward._backends.backend import Backend, KernelKeyType, UfuncLike +from awkward._backends.backend import Backend, KernelKeyType from awkward._backends.dispatch import register_backend from awkward._kernels import JaxKernel from awkward._nplikes.jax import Jax @@ -43,8 +43,3 @@ def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer from awkward._connect.jax import get_jax_reducer return get_jax_reducer(reducer) - - def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike: - from awkward._connect.jax import get_jax_ufunc - - return get_jax_ufunc(ufunc) diff --git a/src/awkward/_connect/numpy.py b/src/awkward/_connect/numpy.py index 282b3fef6a..004ca3ca44 100644 --- a/src/awkward/_connect/numpy.py +++ b/src/awkward/_connect/numpy.py @@ -385,13 +385,8 @@ def action(inputs, **ignore): parameters_intersect, (c._parameters for c in contents) ) - args = [x.data if isinstance(x, NumpyArray) else x for x in inputs] - - # Give backend a chance to change the ufunc implementation - impl = backend.prepare_ufunc(ufunc) - - # Invoke ufunc - result = impl(*args, **kwargs) + input_args = [x.data if isinstance(x, NumpyArray) else x for x in inputs] + result = backend.nplike.apply_ufunc(ufunc, method, input_args, kwargs) if isinstance(result, tuple): return tuple( diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 291fec4623..5bead34287 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -10,6 +10,7 @@ IndexType, NumpyLike, NumpyMetadata, + UfuncLike, UniqueAllResult, ) from awkward._nplikes.placeholder import PlaceholderArray @@ -22,6 +23,9 @@ class ArrayModuleNumpyLike(NumpyLike): known_data: Final = True + def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike: + return ufunc + ############################ array creation def asarray( @@ -146,6 +150,29 @@ def searchsorted( ############################ manipulation + def apply_ufunc( + self, + ufunc: UfuncLike, + method: str, + args: list[Any], + kwargs: dict[str, Any] | None = None, + ) -> ArrayLike | tuple[ArrayLike]: + # Determine input argument dtypes + input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] + # Resolve these for the given ufunc + arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) + resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) + # Interpret the arguments under these dtypes + resolved_args = [ + self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes) + ] + # Broadcast these resolved arguments + broadcasted_args = self.broadcast_arrays(*resolved_args) + # Allow other nplikes to replace implementation + impl = self.prepare_ufunc(ufunc) + # Compute the result + return impl(*broadcasted_args, **kwargs) + def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]: assert not any(isinstance(x, PlaceholderArray) for x in arrays) return self._module.broadcast_arrays(*arrays) diff --git a/src/awkward/_nplikes/jax.py b/src/awkward/_nplikes/jax.py index 8af5480aca..d1aeeb9ced 100644 --- a/src/awkward/_nplikes/jax.py +++ b/src/awkward/_nplikes/jax.py @@ -4,7 +4,7 @@ import awkward as ak from awkward._nplikes.array_module import ArrayModuleNumpyLike from awkward._nplikes.dispatch import register_nplike -from awkward._nplikes.numpylike import ArrayLike +from awkward._nplikes.numpylike import ArrayLike, UfuncLike from awkward._nplikes.shape import ShapeItem from awkward._typing import Final @@ -18,6 +18,11 @@ def __init__(self): jax = ak.jax.import_jax() self._module = jax.numpy + def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike: + from awkward._connect.jax import get_jax_ufunc + + return get_jax_ufunc(ufunc) + @property def ma(self): raise ValueError( diff --git a/src/awkward/_nplikes/numpylike.py b/src/awkward/_nplikes/numpylike.py index b2bbd29f5f..c77f44011e 100644 --- a/src/awkward/_nplikes/numpylike.py +++ b/src/awkward/_nplikes/numpylike.py @@ -242,9 +242,33 @@ def issubdtype(self): NumpyMetadata.timedelta64 = numpy.timedelta64 +class UfuncLike(Protocol): + nargs: int + nin: int + nout: int + + def resolve_dtypes( + self, dtypes: tuple[numpy.dtype | type, ...] + ) -> tuple[numpy.dtype, ...]: + ... + + def __call__(self, *args: ArrayLike, **kwargs) -> ArrayLike: + ... + + class NumpyLike(PublicSingleton, Protocol): ############################ Awkward features + @abstractmethod + def apply_ufunc( + self, + ufunc: UfuncLike, + method: str, + args: list[Any], + kwargs: dict[str, Any] | None = None, + ) -> ArrayLike | tuple[ArrayLike]: + ... + @property @abstractmethod def supports_structured_dtypes(self) -> bool: diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index 9f4c75bbcb..51fb4b5972 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -13,6 +13,7 @@ IndexType, NumpyLike, NumpyMetadata, + UfuncLike, UniqueAllResult, ) from awkward._nplikes.placeholder import PlaceholderArray @@ -387,8 +388,7 @@ def __getitem__( try_touch_data(item) try_touch_data(self) - if is_unknown_scalar(item): - item = self.nplike.promote_scalar(item) + item = self.nplike.asarray(item) # If this is the first advanced index, insert the location if not advanced_shapes: @@ -415,7 +415,7 @@ def __getitem__( try_touch_data(item) try_touch_data(self) - item = self.nplike.promote_scalar(item) + item = self.nplike.asarray(item) if is_unknown_length(dimension_length) or is_unknown_integer(item): continue @@ -469,7 +469,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): if len(kwargs) > 0: raise ValueError("TypeTracerArray does not support kwargs for ufuncs") - return self.nplike._apply_ufunc(ufunc, *inputs) + return self.nplike.apply_ufunc(ufunc, method, inputs, kwargs) def __bool__(self) -> bool: raise RuntimeError("cannot realise an unknown value") @@ -504,26 +504,39 @@ class TypeTracer(NumpyLike): is_eager: Final = True supports_structured_dtypes: Final = True - def _apply_ufunc(self, ufunc, *inputs): - for x in inputs: - assert not isinstance(x, PlaceholderArray) + def apply_ufunc( + self, + ufunc: UfuncLike, + method: str, + args: list[Any], + kwargs: dict[str, Any] | None = None, + ) -> TypeTracerArray | tuple[TypeTracerArray]: + for x in args: try_touch_data(x) - inputs = [x.content if isinstance(x, MaybeNone) else x for x in inputs] - - broadcasted = self.broadcast_arrays(*inputs) - placeholders = [numpy.empty(0, x.dtype) for x in broadcasted] - - result = ufunc(*placeholders) - if isinstance(result, numpy.ndarray): - return TypeTracerArray._new(result.dtype, shape=broadcasted[0].shape) - elif isinstance(result, tuple): - return ( - TypeTracerArray._new(x.dtype, shape=b.shape) - for x, b in zip(result, broadcasted) + # Unwrap options, assume they don't occur + args = [x.content if isinstance(x, MaybeNone) else x for x in args] + # Determine input argument dtypes + input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] + # Resolve these for the given ufunc + arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) + resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) + # Interpret the arguments under these dtypes + resolved_args = [ + self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes) + ] + # Broadcast these resolved arguments + broadcasted_args = self.broadcast_arrays(*resolved_args) + result_dtypes = resolved_dtypes[ufunc.nin :] + if len(result_dtypes) == 1: + return TypeTracerArray._new( + result_dtypes[0], shape=broadcasted_args[0].shape ) else: - raise TypeError + return ( + TypeTracerArray._new(dtype, shape=b.shape) + for dtype, b in zip(result_dtypes, broadcasted_args) + ) def _axis_is_valid(self, axis: int, ndim: int) -> bool: if axis < 0: @@ -801,18 +814,6 @@ def searchsorted( return TypeTracerArray._new(x.dtype, (values.size,)) ############################ manipulation - - def promote_scalar(self, obj) -> TypeTracerArray: - assert not isinstance(obj, PlaceholderArray) - if is_unknown_scalar(obj): - return obj - elif isinstance(obj, (Number, bool)): - # TODO: statically define these types for all nplikes - as_array = numpy.asarray(obj) - return TypeTracerArray._new(as_array.dtype, ()) - else: - raise TypeError(f"expected scalar type, received {obj}") - def shape_item_as_index(self, x1: ShapeItem) -> IndexType: if x1 is unknown_length: return TypeTracerArray._new(np.int64, shape=()) @@ -1206,7 +1207,7 @@ def add( maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) - return self._apply_ufunc(numpy.add, x1, x2) + return self.apply_ufunc(numpy.add, "__call__", (x1, x2)) def logical_and( self, @@ -1215,7 +1216,7 @@ def logical_and( maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) - return self._apply_ufunc(numpy.logical_and, x1, x2) + return self.apply_ufunc(numpy.logical_and, "__call__", (x1, x2)) def logical_or( self, @@ -1225,21 +1226,21 @@ def logical_or( ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) assert not isinstance(x2, PlaceholderArray) - return self._apply_ufunc(numpy.logical_or, x1, x2) + return self.apply_ufunc(numpy.logical_or, "__call__", (x1, x2)) def logical_not( self, x: ArrayLike, maybe_out: ArrayLike | None = None ) -> TypeTracerArray: assert not isinstance(x, PlaceholderArray) - return self._apply_ufunc(numpy.logical_not, x) + return self.apply_ufunc(numpy.logical_not, "__call__", (x,)) def sqrt(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray: assert not isinstance(x, PlaceholderArray) - return self._apply_ufunc(numpy.sqrt, x) + return self.apply_ufunc(numpy.sqrt, "__call__", (x,)) def exp(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray: assert not isinstance(x, PlaceholderArray) - return self._apply_ufunc(numpy.exp, x) + return self.apply_ufunc(numpy.exp, "__call__", (x,)) def divide( self, @@ -1249,7 +1250,7 @@ def divide( ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) assert not isinstance(x2, PlaceholderArray) - return self._apply_ufunc(numpy.divide, x1, x2) + return self.apply_ufunc(numpy.divide, "__call__", (x1, x2)) ############################ almost-ufuncs diff --git a/src/awkward/_slicing.py b/src/awkward/_slicing.py index e102246275..5ac4ae1b45 100644 --- a/src/awkward/_slicing.py +++ b/src/awkward/_slicing.py @@ -137,7 +137,7 @@ def prepare_advanced_indexing(items, backend: Backend): # Then broadcast the index items nplike = backend.index_nplike - broadcasted = nplike.broadcast_arrays(*broadcastable) + broadcasted = nplike.broadcast_arrays(*[nplike.asarray(x) for x in broadcastable]) # And re-assemble the index with the broadcasted items prepared = []