Releases: e3nn/e3nn-jax

2023-06-22

22 Jun 22:07

Highlight

e3nn.utils.vmap works around the safeguard that drops .zero_flags, for the case of vmap.

Consider an IrrepsArray whose 0o chunk is set to None; its zero_flags attribute is (False, True):

x = e3nn.from_chunks("0e + 0o", [jnp.ones((100, 1, 1)), None], (100,))
x.zero_flags  # (False, True)

Now if we vmap a function using jax.vmap, the internal function will not get the info that the 0o entry is actually zero.

jax.vmap(e3nn.norm)(x).zero_flags  # (False, False)

This is a safeguard, because not all transformations preserve the validity of zero_flags. Take for instance:

jax.tree_util.tree_map(lambda x: x + 1.0, x).zero_flags  # (False, False)

However, vectorization does preserve the validity of zero_flags, so in the case of vmap we can let it propagate into and out of the vectorized function:

e3nn.utils.vmap(e3nn.norm)(x).zero_flags  # (False, True)
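
The distinction can be illustrated without e3nn. The sketch below uses a plain JAX array as a stand-in for a chunk that is exactly zero (an assumption made for illustration; it does not involve zero_flags itself). Mapping a norm over the batch axis leaves the zero chunk zero, whereas an arbitrary leaf-wise transformation does not.

```python
import jax
import jax.numpy as jnp

# Stand-in for a chunk that is known to be exactly zero (batch of 100).
zero_chunk = jnp.zeros((100, 1))

# vmap applies the function sample by sample; the norm of a zero
# sample is zero, so the result is zero for the whole batch.
norms = jax.vmap(jnp.linalg.norm)(zero_chunk)

# An arbitrary leaf-wise transformation can make the data non-zero,
# so a "this chunk is zero" flag would no longer be valid afterwards.
shifted = jax.tree_util.tree_map(lambda a: a + 1.0, zero_chunk)
```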

Changelog

Added

  • e3nn.utils.vmap to propagate zero_flags in the vectorized function.

Changed

  • Simplify the tetris examples

Fixed

  • Example of what is fixed: assume x.ndim = 2, allow x[:, None] but prevent x[:, :, None] and x[..., None]
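
One way to read this rule: the last axis of an IrrepsArray holds the irreps data, so an index may insert a new axis among the leading axes but not after the last one. The helper below is a hypothetical plain-Python sketch of such a check; its name and logic are illustrative and are not taken from the library.

```python
def inserts_axis_after_last(index, ndim):
    """Return True if `index` would place a new axis (None) after the last axis."""
    if not isinstance(index, tuple):
        index = (index,)
    # Number of existing axes explicitly addressed by the index.
    addressed = sum(1 for i in index if i is not None and i is not Ellipsis)
    expanded = []
    for i in index:
        if i is Ellipsis:
            # An ellipsis stands for all axes that are not addressed explicitly.
            expanded.extend([slice(None)] * (ndim - addressed))
        else:
            expanded.append(i)
    # Trailing axes that are not mentioned are implicitly kept whole.
    mentioned = sum(1 for i in expanded if i is not None)
    expanded.extend([slice(None)] * (ndim - mentioned))
    seen = 0  # existing axes encountered so far
    for i in expanded:
        if i is None:
            if seen == ndim:
                return True
        else:
            seen += 1
    return False

print(inserts_axis_after_last((slice(None), None), ndim=2))               # x[:, None]
print(inserts_axis_after_last((slice(None), slice(None), None), ndim=2))  # x[:, :, None]
print(inserts_axis_after_last((Ellipsis, None), ndim=2))                  # x[..., None]
```

For ndim = 2 this prints False, True, True, matching the allowed and prevented cases listed above.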

2023-06-21

22 Jun 02:47

Highlight

There are two main changes in this release: Linear and IrrepsArray.

Change in Linear

The classes e3nn.flax.Linear and e3nn.haiku.Linear now discard, by default, the output irreps that are not reachable from the input.

linear = e3nn.flax.Linear("2x0e + 1o + 2e")
x = e3nn.normal("0e + 1o")  # an input without 2e
w = linear.init(jax.random.PRNGKey(0), x)
linear.apply(w, x).irreps
# 2x0e+1x1o, no 2e because it is not reachable from the input
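
An equivariant linear layer only mixes channels of the same irrep type, so an output irrep is reachable only if its type also occurs in the input. The sketch below expresses that rule in plain Python, representing irreps as hypothetical (multiplicity, type) pairs rather than using the library's Irreps class.

```python
def reachable(irreps_out, irreps_in):
    """Keep the requested output irreps whose type also appears in the input."""
    input_types = {ir for _, ir in irreps_in}
    return [(mul, ir) for mul, ir in irreps_out if ir in input_types]

requested = [(2, "0e"), (1, "1o"), (1, "2e")]  # "2x0e + 1o + 2e"
available = [(1, "0e"), (1, "1o")]              # "0e + 1o"
kept = reachable(requested, available)          # [(2, "0e"), (1, "1o")]
```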

Change in IrrepsArray

Before this release

IrrepsArray had its data stored twice, both in .array and in .list.

x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
x.list  # a list with the scalars and the vectors in two separate arrays
x.array  # an array with all the data contiguous

The motivation was to allow chunks to be stored as None, meaning strictly zero:

e3nn.IrrepsArray.from_list("0e + 1o", [jnp.array([[1.0]]), None], ())

This led to confusing situations with the jax.tree_util module:

jax.tree_util.tree_leaves(x)
# a list of 3 arrays with repeated data

Because of that, jax.vmap was not possible with negative axis

jax.vmap(lambda x: x, in_axes=-2)(x)  # ERROR!!

And the gradient was only propagated through one of the two attributes

g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # is zero
g.list  # is not zero

In this release

We refactored IrrepsArray. It has now only .array as a data attribute. The .list attribute is gone.
Therefore jax.tree_util.tree_leaves returns just one array.

x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
jax.tree_util.tree_leaves(x)  # [jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])]

Instead, we have a new attribute .zero_flags, a tuple of booleans indicating whether each chunk is zero.

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks  # [jnp.array([[1.0]]), None]
y.zero_flags  # (False, True)

.chunks is the new attribute that replaces .list (now deprecated).
It has a better name because we already have .slice_by_chunk.

x.chunks  # list of the two chunks
x.slice_by_chunk[:1]  # get the first chunk
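
To make the relation between the contiguous array and its chunks concrete, here is a minimal NumPy sketch. It assumes that a chunk has shape leading_shape + (mul, 2l+1), which is what the from_chunks examples in these notes suggest; the helper name and the (mul, l) layout are hypothetical.

```python
import numpy as np

def split_into_chunks(array, layout):
    """Split the last axis into one chunk per (mul, l) entry, shaped (..., mul, 2l+1)."""
    chunks, start = [], 0
    for mul, l in layout:
        dim = mul * (2 * l + 1)
        chunk = array[..., start:start + dim]
        chunks.append(chunk.reshape(array.shape[:-1] + (mul, 2 * l + 1)))
        start += dim
    return chunks

array = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
scalars, vectors = split_into_chunks(array, [(1, 0), (1, 1)])  # "0e + 1o"
```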

jax.vmap can now be used with a negative axis

jax.vmap(lambda x: x, in_axes=-2)(x)

And the gradient behaves as expected

g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array  # expected value
g.chunks  # expected value

Because .zero_flags may become invalid under an arbitrary jax transformation, we drop it whenever an IrrepsArray goes through one.

y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags)  # (False, True)

z = jax.jit(lambda x: x)(y)
print(z.zero_flags)  # (False, False)

z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags)  # (False, False)

z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags)  # (False, False)
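
A toy version of this behaviour can be written with a custom pytree. The class below is hypothetical and much simpler than IrrepsArray: it deliberately leaves its flag out of the flattened representation, so every transformation that rebuilds the object resets the flag.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Flagged:
    """Array with an 'is zero' flag that does not survive transformations."""

    def __init__(self, array, is_zero=False):
        self.array = array
        self.is_zero = is_zero

    def tree_flatten(self):
        # Only the array is exposed; the flag is intentionally not stored.
        return (self.array,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilding from leaves resets the flag to its default (False).
        return cls(children[0])

y = Flagged(jnp.zeros(3), is_zero=True)
z_jit = jax.jit(lambda v: v)(y)
z_map = jax.tree_util.tree_map(lambda a: a, y)
```

Here y.is_zero is True, while z_jit.is_zero and z_map.is_zero are both False even though the data is unchanged.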

Changelog

Changed

  • [BREAKING] e3nn.flax.Linear and e3nn.haiku.Linear no longer output irreps that are unreachable from the input. To force the output of all irreps, use force_irreps_out = True. For instance e3nn.flax.Linear("0e + 1o")("0e") will now return "0e" instead of "0e + 1o".
  • [BREAKING] e3nn.utils.assert_equivariant has the same signature as e3nn.utils.equivariance_test
  • [BREAKING] Move as_irreps_array, zeros and zeros_like from e3nn.IrrepsArray to e3nn
  • [BREAKING] Move IrrepsArray.from_list to e3nn.from_chunks
  • [BREAKING] Rename IrrepsArray.list into IrrepsArray.chunks
  • [BREAKING] Rename IrrepsArray.remove_nones into IrrepsArray.remove_zero_chunks
  • e3nn.IrrepsArray has now only .array as data attribute.

Added

  • e3nn.IrrepsArray.rechunk
  • e3nn.IrrepsArray.zero_flags a tuple of bools that indicates which chunks are zero

2023-06-19

19 Jun 15:51

Highlight

Add set_mul to Irreps, note that it's not an in-place operation.

irreps = e3nn.Irreps("0e + 1o")
irreps = irreps.set_mul(2)
# 2x0e+2x1o

Add the option lmax to Irreps.filter and IrrepsArray.filter.

irreps = irreps.filter(lmax=0)
irreps
# 2x0e
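
The semantics of both operations can be sketched on a plain list of hypothetical (mul, l, parity) tuples instead of the library's Irreps class. Both helpers return a new list, mirroring the fact that set_mul is not an in-place operation.

```python
def set_mul(irreps, mul):
    """Return a new list in which every irrep has multiplicity `mul`."""
    return [(mul, l, parity) for _, l, parity in irreps]

def filter_lmax(irreps, lmax):
    """Return a new list that keeps only the irreps with l <= lmax."""
    return [(mul, l, parity) for mul, l, parity in irreps if l <= lmax]

original = [(1, 0, "e"), (1, 1, "o")]   # "0e + 1o"
doubled = set_mul(original, 2)          # "2x0e + 2x1o"
scalars_only = filter_lmax(doubled, 0)  # "2x0e"
```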

e3nn.utils is now directly accessible as a submodule and is documented.

x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 3.0, 4.0]))
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 4.0]))
y1, y2 = e3nn.utils.equivariance_test(
    e3nn.tensor_product, jax.random.PRNGKey(0), x1, x2
)
# y1 = R x1 otimes R x2
# y2 = R (x1 otimes x2)
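
The idea behind such a test is to compare "rotate, then apply" with "apply, then rotate". The NumPy sketch below does this by hand for the cross product of two vectors under a rotation about the z-axis; it is an illustration of the principle and does not use e3nn.utils.equivariance_test.

```python
import numpy as np

def rotation_z(angle):
    """Rotation matrix about the z-axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

R = rotation_z(0.7)
x1 = np.array([1.0, 3.0, 4.0])
x2 = np.array([0.0, 1.0, 4.0])

y1 = np.cross(R @ x1, R @ x2)  # rotate the inputs, then apply the operation
y2 = R @ np.cross(x1, x2)      # apply the operation, then rotate the output

is_equivariant = np.allclose(y1, y2)
```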

Changelog

Changed

  • [BREAKING] Renamed e3nn.util to e3nn.utils

Added

  • Irreps.set_mul(int) to set the multiplicity of all irreps
  • Irreps.filter(lmax=int) to filter out irreps with l > lmax
  • IrrepsArray.filter(lmax=int) to filter out irreps with l > lmax
  • IrrepsArray.__radd__ and IrrepsArray.__rsub__ to support scalar + IrrepsArray and scalar - IrrepsArray
  • 0 + IrrepsArray and 0 - IrrepsArray are now always accepted as special cases.
  • Support for IrrepsArray / array
  • Add utils as a submodule

Fixed

  • e3nn.scatter now handles indices with ndim > 1

2023-05-10

10 May 22:45

Highlight

We can now compute the trispectrum for l=4 in about a minute.

irreps = e3nn.Irreps.spherical_harmonics(lmax=4)
q = e3nn.reduced_symmetric_tensor_product_basis(irreps, 4, keep_ir="0e + 0o")

Changelog

Added

  • e3nn.cross for completeness

Changed

  • Optimize e3nn.reduced_symmetric_tensor_product_basis, especially for the keep_ir argument

2023-05-02

02 May 19:46

Highlight

Finally implemented the eSCN optimization. That is the optimization of the tensor product with spherical harmonics followed by a linear mix.

old:

sh = e3nn.spherical_harmonics(range(lmax), vector, True)
features = e3nn.tensor_product(features, sh)
linear = e3nn.flax.Linear(irreps_out)
features = linear.apply(w, features)

new: (equivalent but faster)

from e3nn_jax.experimental.linear_shtp import LinearSHTP
conv = LinearSHTP(irreps_out)
features = conv.apply(w, features, vector)

Changelog

Added

  • LinearSHTP module implementing the optimized linear mixing of inputs tensor product with spherical harmonics
  • D_from_axis_angle
  • to_s2grid: quadrature="gausslegendre" by default
  • soft_odd activation function for odd scalars
  • more support for arrays implicitly converted into IrrepsArray as scalars (i.e. added a few IrrepsArray.as_irreps_array calls)

Changed

  • scalar_activation simpler to use with default activation functions (a bit like gate)

2023-04-06

07 Apr 02:06
x = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 1.0, 1.0, 2.0]))
e3nn.norm_activation(x, [None, jnp.tanh])
# 1x0e+1x1o [1.        0.8883856 0.8883856 1.7767712]
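
The printed numbers are consistent with leaving the scalar untouched and rescaling the vector by tanh of the root-mean-square of its components. The NumPy sketch below reproduces them under that reading; this is inferred from the output shown, not from the library's implementation.

```python
import numpy as np

x = np.array([1.0, 1.0, 1.0, 2.0])
scalar, vector = x[:1], x[1:]

# Root-mean-square of the vector components: sqrt((1 + 1 + 4) / 3) = sqrt(2).
rms = np.sqrt(np.mean(vector ** 2))

# The scalar is left unchanged (its activation is None);
# the vector is rescaled by tanh of its RMS norm.
result = np.concatenate([scalar, np.tanh(rms) * vector])
```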

Changed

  • e3nn.normalize_function now uses a deterministic (not pseudorandom) algorithm to compute the normalization factor.

Added

  • normalize_act option to e3nn.scalar_activation and e3nn.gate. We can now turn the normalization off if we want to.
  • e3nn.norm_activation as a new activation function.

2023-03-26

26 Mar 11:04

This release fixes the gradients of xyz_to_angles: instead of returning NaN at the poles, they are now zero.

[Figure: angles alpha and beta on the sphere; the x-axis points to the right, the y-axis to the top, and the z-axis comes out of the image]

[Figure: gradients of alpha with respect to (x, y, z), respectively]

[Figure: gradients of beta with respect to (x, y, z), respectively]

Fixed

  • Fix NaN in the gradients of e3nn.xyz_to_angles. The gradients are now 0 when the input is on the poles.
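
A common JAX pattern for this kind of fix is the "double where" trick: replace the singular input by a harmless value before calling the function, then select a constant at the singular point. The sketch below applies it to a generic arctan2 at the origin; it is an illustration of the pattern and may differ from what e3nn.xyz_to_angles actually does.

```python
import jax
import jax.numpy as jnp

def naive_angle(x, y):
    return jnp.arctan2(y, x)

def safe_angle(x, y):
    at_origin = (x == 0.0) & (y == 0.0)
    # Replace the singular input by a harmless one *before* calling arctan2,
    # so that the backward pass never evaluates 0 / 0.
    x_safe = jnp.where(at_origin, 1.0, x)
    angle = jnp.arctan2(y, x_safe)
    # Return a constant at the singular point, which has zero gradient.
    return jnp.where(at_origin, 0.0, angle)

# Gradients with respect to both arguments, evaluated at the singular point.
grad_naive = jax.grad(naive_angle, argnums=(0, 1))(0.0, 0.0)
grad_safe = jax.grad(safe_angle, argnums=(0, 1))(0.0, 0.0)
```

Away from the origin both functions agree; at the origin only safe_angle is guaranteed to give finite (zero) gradients.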

2023-03-15

15 Mar 17:27

⚠️ e3nn.tensor_product_with_spherical_harmonics backward propagation is not yet working.

Added

  • e3nn.dot: compute the dot product between two IrrepsArray
  • per_irrep argument to e3nn.norm: compute the norm of each irrep independently if per_irrep=True
  • e3nn.tensor_product_with_spherical_harmonics from https://arxiv.org/pdf/2302.03655.pdf

Changed

  • __repr__(Irreps()) has been changed from "" to "Irreps()"

Fixed

  • spherical harmonics edge case when output_irreps=Irreps()

2023-02-20

20 Feb 19:53

The function e3nn.SphericalSignal.sample samples points on the sphere according to an arbitrary distribution defined on a grid.

Added

  • e3nn.SphericalSignal.sample to sample a point on the sphere
  • e3nn.scatter_max

Changed

  • [BREAKING] Removed e3nn.s2_sum_of_diracs in favor of e3nn.s2_dirac
  • [BREAKING] e3nn.grad now regroups the output by default. It can be disabled with regroup_output=False

Fixed

  • e3nn.SphericalSignal arithmetic operations
  • e3nn.Irreps.D_from_angles computes (again!) the Wigner D matrices using the J matrices for L <= 11. This is faster and more accurate than using the expm.

2022-02-01

01 Feb 07:01
import jax.numpy as jnp
import e3nn_jax as e3nn
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1, 2, 0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")

import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])

Added

  • e3nn.SphericalSignal class to represent signals on the sphere
  • Signal on the Sphere section in the documentation
  • e3nn.Irreps.D_from_log_coordinates
  • rotation_angle_from_* functions
  • e3nn.to_s2point function

Changed

  • Wigner D matrices are computed from the log coordinates which makes 1 instead of 3 calls to expm.
  • [BREAKING] e3nn.util.assert_output_dtype renamed to e3nn.util.assert_output_dtype_matches_input_dtype
  • [BREAKING] Update experimental.point_convolution to use the last changes.
  • [BREAKING] changed the e3nn.to_s2grid and e3nn.from_s2grid signature and default normalization.
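
To illustrate the first item of this list: in log coordinates a rotation is described by a single vector w = angle * axis, and the rotation matrix is obtained with one matrix exponential. The sketch below shows this for ordinary 3D vectors in Cartesian (x, y, z) order; the helper name is hypothetical and e3nn's own basis conventions are not used here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

def rotation_from_log_coordinates(w):
    """Rotation matrix exp(K), with K the skew-symmetric matrix of w = angle * axis."""
    wx, wy, wz = w
    K = jnp.array([[0.0, -wz, wy],
                   [wz, 0.0, -wx],
                   [-wy, wx, 0.0]])
    return expm(K)  # a single call to expm

# Quarter turn about the z-axis: the x-axis is mapped onto the y-axis.
R = rotation_from_log_coordinates(jnp.array([0.0, 0.0, jnp.pi / 2]))
rotated_x = R @ jnp.array([1.0, 0.0, 0.0])
```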

Removed

  • [BREAKING] All the haiku modules from the main module. They are now in the e3nn.haiku submodule.
  • [BREAKING] e3nn.wigner_D in favor of e3nn.Irrep.D_from_*

Fixed

  • Removed the jax.jit decorator from Irreps.D_from_*, which was causing a bug.