Skip to content

Commit

Permalink
fix: support older NumPY
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 28, 2023
1 parent 414f69a commit 6c708c6
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
29 changes: 20 additions & 9 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math

import numpy
import packaging.version

from awkward._nplikes.numpylike import (
ArrayLike,
Expand All @@ -18,6 +19,9 @@
from awkward._typing import Any, Final, Literal

np = NumpyMetadata.instance()
NUMPY_HAS_NEP_50 = packaging.version.Version(
numpy.__version__
) >= packaging.version.Version("1.24")


class ArrayModuleNumpyLike(NumpyLike):
Expand Down Expand Up @@ -157,15 +161,22 @@ def apply_ufunc(
args: list[Any],
kwargs: dict[str, Any] | None = None,
) -> ArrayLike | tuple[ArrayLike]:
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes)
]
if NUMPY_HAS_NEP_50:
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype)
for arg, dtype in zip(args, resolved_dtypes)
]
else:
resolved_args = [
self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg
for arg in args
]
# Broadcast these resolved arguments
broadcasted_args = self.broadcast_arrays(*resolved_args)
# Allow other nplikes to replace implementation
Expand Down
58 changes: 41 additions & 17 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable

import numpy
import packaging.version

import awkward as ak
from awkward._nplikes.dispatch import register_nplike
Expand All @@ -30,6 +31,9 @@
)

np = NumpyMetadata.instance()
NUMPY_HAS_NEP_50 = packaging.version.Version(
numpy.__version__
) >= packaging.version.Version("1.24")


def is_unknown_length(array: Any) -> bool:
Expand Down Expand Up @@ -516,26 +520,46 @@ def apply_ufunc(

# Unwrap options, assume they don't occur
args = [x.content if isinstance(x, MaybeNone) else x for x in args]
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes)
]
# Broadcast these resolved arguments
broadcasted_args = self.broadcast_arrays(*resolved_args)
result_dtypes = resolved_dtypes[ufunc.nin :]
if NUMPY_HAS_NEP_50:
# Determine input argument dtypes
input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args]
# Resolve these for the given ufunc
arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout)
resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes)
# Interpret the arguments under these dtypes
resolved_args = [
self.asarray(arg, dtype=dtype)
for arg, dtype in zip(args, resolved_dtypes)
]
# Broadcast these resolved arguments
broadcasted_args = self.broadcast_arrays(*resolved_args)
broadcasted_shape = broadcasted_args[0].shape
result_dtypes = resolved_dtypes[ufunc.nin :]
else:
array_like_args = [
self.asarray(arg, dtype=arg.dtype)
for arg in args
if hasattr(arg, "dtype")
]
broadcasted_args = self.broadcast_arrays(*array_like_args)
broadcasted_shape = broadcasted_args[0].shape

numpy_args = [
(numpy.empty(0, dtype=x.dtype) if hasattr(x, "dtype") else x)
for x in args
]
numpy_result = ufunc(*numpy_args, **kwargs)
if ufunc.nout == 1:
result_dtypes = [numpy_result.dtype]
else:
result_dtypes = [x.dtype for x in numpy_result]

if len(result_dtypes) == 1:
return TypeTracerArray._new(
result_dtypes[0], shape=broadcasted_args[0].shape
)
return TypeTracerArray._new(result_dtypes[0], shape=broadcasted_shape)
else:
return (
TypeTracerArray._new(dtype, shape=b.shape)
for dtype, b in zip(result_dtypes, broadcasted_args)
TypeTracerArray._new(dtype, shape=broadcasted_shape)
for dtype in result_dtypes
)

def _axis_is_valid(self, axis: int, ndim: int) -> bool:
Expand Down

0 comments on commit 6c708c6

Please sign in to comment.