Skip to content

Commit

Permalink
fix: strictly pass arrays into broadcast_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Nov 1, 2023
1 parent a7b8020 commit 98ed1c1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 2 additions & 3 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,14 +224,13 @@ def apply_ufunc(
) -> ArrayLike | tuple[ArrayLike]:
# Convert np.generic to scalar arrays
resolved_args = [
self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg
self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None)
for arg in args
]
broadcasted_args = self.broadcast_arrays(*resolved_args)
# Choose the broadcasted argument if it wasn't a Python scalar
non_generic_value_promoted_args = [
y if hasattr(x, "ndim") else x
for x, y in zip(resolved_args, broadcasted_args)
y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args)
]
# Allow other nplikes to replace implementation
impl = self.prepare_ufunc(ufunc)
Expand Down
5 changes: 2 additions & 3 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,16 +583,15 @@ def apply_ufunc(
args = [x.content if isinstance(x, MaybeNone) else x for x in args]
# Convert np.generic to scalar arrays
resolved_args = [
self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg
self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None)
for arg in args
]
# Broadcast all inputs together
broadcasted_args = self.broadcast_arrays(*resolved_args)
broadcasted_shape = broadcasted_args[0].shape
# Choose the broadcasted argument if it wasn't a Python scalar
non_generic_value_promoted_args = [
y if hasattr(x, "ndim") else x
for x, y in zip(resolved_args, broadcasted_args)
y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args)
]
# Build proxy (empty) arrays
proxy_args = [
Expand Down

0 comments on commit 98ed1c1

Please sign in to comment.