From 98ed1c1ab75b01bf7573da091f114c0c84114201 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Wed, 1 Nov 2023 11:09:02 +0000 Subject: [PATCH] fix: strictly pass arrays into `broadcast_arrays` --- src/awkward/_nplikes/array_module.py | 5 ++--- src/awkward/_nplikes/typetracer.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 8f7295c089..25c3df20b8 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -224,14 +224,13 @@ def apply_ufunc( ) -> ArrayLike | tuple[ArrayLike]: # Convert np.generic to scalar arrays resolved_args = [ - self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg + self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None) for arg in args ] broadcasted_args = self.broadcast_arrays(*resolved_args) # Choose the broadcasted argument if it wasn't a Python scalar non_generic_value_promoted_args = [ - y if hasattr(x, "ndim") else x - for x, y in zip(resolved_args, broadcasted_args) + y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args) ] # Allow other nplikes to replace implementation impl = self.prepare_ufunc(ufunc) diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index b1f9e338a4..3834b3569f 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -583,7 +583,7 @@ def apply_ufunc( args = [x.content if isinstance(x, MaybeNone) else x for x in args] # Convert np.generic to scalar arrays resolved_args = [ - self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg + self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None) for arg in args ] # Broadcast all inputs together @@ -591,8 +591,7 @@ def apply_ufunc( broadcasted_shape = broadcasted_args[0].shape # Choose the broadcasted argument if it wasn't a Python scalar non_generic_value_promoted_args = [ - y if hasattr(x, "ndim") else x - for x, y in zip(resolved_args, broadcasted_args) + y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args) ] # Build proxy (empty) arrays proxy_args = [