Namespace-aware xarray.ufuncs #9776

Open: wants to merge 3 commits into base: main

Conversation

slevang
Contributor

@slevang commented on Nov 13, 2024

Re-implement the old xarray.ufuncs module to allow generic ufunc handling for array types that don't implement __array_ufunc__:

import jax.numpy as jnp
import numpy as np
import xarray as xr
import xarray.ufuncs as xu

x = xr.DataArray(jnp.asarray([1, 2, 3]))
print(type(xu.sin(x).data))
print(type(np.sin(x).data))

# <class 'jaxlib.xla_extension.ArrayImpl'>
# <class 'numpy.ndarray'>
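
The behaviour above can be illustrated without jax. The following is a minimal sketch of namespace-aware lookup, not the PR's implementation: `ToyArray`, `ToyNamespace`, and `namespace_aware` are hypothetical names introduced only for this example, and the fallback to `np` when no `__array_namespace__` exists is an assumption.

```python
import numpy as np


class ToyArray:
    """Hypothetical array type without __array_ufunc__ but with a namespace."""

    def __init__(self, values):
        self.values = list(values)

    def __array_namespace__(self, api_version=None):
        return ToyNamespace


class ToyNamespace:
    """Hypothetical namespace whose `sin` preserves the ToyArray type."""

    @staticmethod
    def sin(x):
        return ToyArray(float(np.sin(v)) for v in x.values)


def namespace_aware(name):
    """Build a function that looks `name` up in the argument's own namespace."""

    def wrapper(x):
        get_namespace = getattr(x, "__array_namespace__", None)
        # Assumption: fall back to numpy when the object exposes no namespace.
        xp = get_namespace() if get_namespace is not None else np
        return getattr(xp, name)(x)

    return wrapper


sin = namespace_aware("sin")
toy_result = sin(ToyArray([1, 2, 3]))
numpy_result = sin(np.array([1.0, 2.0, 3.0]))
print(type(toy_result).__name__)
print(type(numpy_result).__name__)
```

The point of the sketch is only that the array type survives the call when the function is resolved from the array's own namespace rather than from `np` directly.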

elif isinstance(obj, array_type("dask")):
    _walk_array_namespaces(obj._meta, namespaces)
else:
    namespace = getattr(obj, "__array_namespace__", None)
Contributor Author

Should we ever prioritize dispatching with np.func via __array_ufunc__ (if it exists) over the library's __array_namespace__().func?
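
To make the two options in this question concrete, here is a rough sketch of a priority switch. `pick_namespace`, `prefer_array_ufunc`, `Both`, and `ToyNS` are hypothetical and exist only for illustration; the PR may resolve namespaces differently.

```python
import numpy as np


class ToyNS:
    """Hypothetical array-API-style namespace."""

    @staticmethod
    def sin(x):
        return "via __array_namespace__"


class Both:
    """Hypothetical array type implementing both protocols."""

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        return "via __array_ufunc__"

    def __array_namespace__(self, api_version=None):
        return ToyNS


def pick_namespace(obj, prefer_array_ufunc=False):
    """Choose where to look a ufunc up for `obj` (illustrative policy only)."""
    # numpy treats `__array_ufunc__ = None` as an opt-out, so test for None.
    has_ufunc_protocol = getattr(obj, "__array_ufunc__", None) is not None
    get_namespace = getattr(obj, "__array_namespace__", None)
    if prefer_array_ufunc and has_ufunc_protocol:
        return np  # np.func(obj) then dispatches through obj.__array_ufunc__
    if get_namespace is not None:
        return get_namespace()
    return np


obj = Both()
namespace_first = pick_namespace(obj).sin(obj)
ufunc_first = pick_namespace(obj, prefer_array_ufunc=True).sin(obj)
print(namespace_first)
print(ufunc_first)
```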

)
func = getattr(np, self._name)

return xr.apply_ufunc(func, *args, dask="parallelized", **kwargs)
Contributor Author

Is there ever a reason to use dask's ufuncs with dask="allowed" instead of the appropriate _meta array's namespace and dask="parallelized"? With jax, for example, which doesn't implement __array_ufunc__, going through dask's ufuncs ends up converting to numpy, so that case would have to be special-cased.

Member

In user code using xr.apply_ufunc there is: dask='allowed' can be used to rechunk along a core dimension, e.g. by applying a dask reduction ufunc along that dimension. Not sure if that's relevant here, though.

Contributor Author

These are all elementwise, so there are no core dimensions.
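
A small numpy-only illustration of why elementwise functions are safe to apply block by block, while a reduction along the split axis is not. Here `np.array_split` is just a stand-in for dask chunking; this is a sketch of the idea, not of what `xr.apply_ufunc` does internally.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 12).reshape(3, 4)
chunks = np.array_split(x, 2, axis=1)  # two (3, 2) blocks, standing in for dask chunks

# Elementwise: applying the function per block and stitching the blocks
# back together matches applying it to the whole array.
blockwise_sin = np.concatenate([np.sin(c) for c in chunks], axis=1)
same_as_whole = np.allclose(blockwise_sin, np.sin(x))

# Reduction along the split axis: per-block results are partial sums that
# would still need a combine step, so naive blockwise application differs.
blockwise_sum = np.concatenate([c.sum(axis=1) for c in chunks])

print(same_as_whole)
print(blockwise_sum.shape, x.sum(axis=1).shape)
```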



# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
Contributor Author

I can hard-code these if preferred.
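
If a hard-coded list were preferred, one possible middle ground is to keep the list static and check it against the installed numpy. In the sketch below, `HARD_CODED` is a hypothetical, deliberately tiny subset used only for illustration; it is not the list the PR would ship.

```python
import numpy as np

# Same discovery expression as in the diff: every public numpy ufunc by name.
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}

# Hypothetical hard-coded subset, for illustration only.
HARD_CODED = ("cos", "exp", "sin")

# Names in the static list that the installed numpy does not provide as ufuncs.
missing = set(HARD_CODED) - np_ufuncs
print(sorted(missing))
```

A check like this could live in the test suite, so the static list stays readable while drift against numpy is still detected.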


# Auto generate from the public numpy ufuncs
np_ufuncs = {name for name in dir(np) if isinstance(getattr(np, name), np.ufunc)}
excluded_ufuncs = {"divmod", "frexp", "isnat", "matmul", "modf", "vecdot"}
Contributor Author

These are the ones that didn't immediately work. There are also other ufunc-like functions that aren't technically np.ufunc instances that we could add. I saw that angle and iscomplex were special-cased before.
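
For context, numpy's own ufunc metadata hints at why some of these names need extra handling. The snippet below only inspects attributes; reading multiple outputs or a gufunc signature as the cause of the failures is a guess, not something the PR states. `vecdot` is left out because it only exists in newer numpy releases.

```python
import numpy as np

candidates = ("divmod", "frexp", "modf", "matmul", "isnat")

# ufuncs that return more than one array per call
multi_output = {name for name in candidates if getattr(np, name).nout > 1}

# generalized ufuncs: they carry a core-dimension signature, so they are not elementwise
generalized = {name for name in candidates if getattr(np, name).signature is not None}

# ufunc-like helpers that are ordinary Python functions rather than np.ufunc instances
plain_functions = {
    name for name in ("angle", "iscomplex") if not isinstance(getattr(np, name), np.ufunc)
}

print(sorted(multi_output))
print(sorted(generalized))
print(sorted(plain_functions))
```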

@slevang slevang marked this pull request as ready for review November 13, 2024 14:17
@TomNicholas added the "topic-arrays" (related to flexible array support) and "array API standard" (Support for the Python array API standard) labels on Nov 14, 2024
Development

Successfully merging this pull request may close these issues.

Compatibility with the Array API standard
2 participants