Releases: e3nn/e3nn-jax
2023-06-22
Highlight
e3nn.utils.vmap makes it possible to bypass the safety mechanism that drops .zero_flags, in the specific case of vmap.
Consider an irreps array whose 0o entry is set to None. Its zero_flags attribute is (False, True):
x = e3nn.from_chunks("0e + 0o", [jnp.ones((100, 1, 1)), None], (100,))
x.zero_flags # (False, True)
If we now vectorize a function with jax.vmap, the inner function does not receive the information that the 0o entry is zero:
jax.vmap(e3nn.norm)(x).zero_flags # (False, False)
This is a safety measure, because not all transformations preserve the validity of zero_flags. Take for instance:
jax.tree_util.tree_map(lambda x: x + 1.0, x).zero_flags # (False, False)
However, vectorization does preserve the validity of zero_flags, so in the case of vmap we can let them propagate into and out of the vectorized function:
e3nn.utils.vmap(e3nn.norm)(x).zero_flags # (False, True)
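Vectorization can keep zero_flags valid because they are static metadata: a chunk that is known to be zero in the batched input is zero in every slice. The loop below is a minimal NumPy sketch of that idea, with a known-zero chunk represented by None. vmap_sketch and norm_chunks are hypothetical helpers for illustration, not the e3nn.utils.vmap implementation. The sketch assumes the function returns None at the same positions for every slice.

```python
import numpy as np


def vmap_sketch(f):
    """Apply f to every slice of the leading axis, keeping None (known-zero) chunks.

    Hypothetical illustration only; assumes f returns None at the same positions
    for every slice (true when f decides this from the None pattern alone).
    """

    def wrapped(chunks):
        batch = next(c.shape[0] for c in chunks if c is not None)
        outs = [f([None if c is None else c[i] for c in chunks]) for i in range(batch)]
        stacked = []
        for j, first in enumerate(outs[0]):
            if first is None:
                stacked.append(None)  # still known to be zero after vectorization
            else:
                stacked.append(np.stack([out[j] for out in outs]))
        return stacked

    return wrapped


def norm_chunks(chunks):
    # Norm of each chunk along its last axis; a known-zero chunk stays None.
    return [None if c is None else np.linalg.norm(c, axis=-1, keepdims=True) for c in chunks]


x = [np.ones((100, 1, 1)), None]  # chunks of "0e + 0o"; the 0o chunk is zero
y = vmap_sketch(norm_chunks)(x)
zero_flags = tuple(c is None for c in y)  # (False, True)
```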
Changelog
Added
- e3nn.utils.vmap to propagate zero_flags through the vectorized function.
Changed
- Simplify the tetris examples
Fixed
- Example of what is fixed: assuming x.ndim = 2, allow x[:, None] but prevent x[:, :, None] and x[..., None] (the last axis is the irreps axis, so no new axis may be inserted after it).
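The examples above follow one rule: an index may not insert a new axis (None) after the irreps axis, which is the last axis. Below is a minimal pure-Python sketch of that check. The helper name is hypothetical, and this reading is inferred from the examples rather than taken from e3nn's indexing code.

```python
def inserts_axis_after_irreps(index, ndim):
    """Return True if `index` would insert a new axis after the irreps axis.

    Hypothetical helper: the last of the `ndim` axes is treated as the irreps axis.
    """
    if not isinstance(index, tuple):
        index = (index,)
    # Entries that consume an existing axis (everything except None and ...)
    n_consuming = sum(1 for i in index if i is not None and i is not Ellipsis)
    consumed = 0
    for i in index:
        if i is Ellipsis:
            consumed += ndim - n_consuming  # the Ellipsis covers the remaining axes
        elif i is None:
            if consumed >= ndim:
                return True  # the irreps axis was already consumed
        else:
            consumed += 1
    return False


# With ndim = 2: x[:, None] is fine, x[:, :, None] and x[..., None] are not
print(inserts_axis_after_irreps((slice(None), None), 2))  # False
print(inserts_axis_after_irreps((Ellipsis, None), 2))  # True
```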
2023-06-21
Highlight
There are two main changes in this release: Linear and IrrepsArray.
Change in Linear
The classes e3nn.flax.Linear and e3nn.haiku.Linear now discard, by default, the output irreps that are not reachable from the input.
linear = Linear("2x0e + 1o + 2e")
x = e3nn.normal("0e + 1o") # an input without 2e
w = linear.init(jax.random.PRNGKey(0), x)
linear.apply(w, x).irreps
# 2x0e+1x1o, no 2e because it is not reachable from the input
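An equivariant linear map can only mix copies of the same irrep. So, assuming the layer only connects identical irreps (no bias terms), an output irrep is reachable exactly when it also appears in the input. Below is a pure-Python sketch of that filtering rule. parse_irreps and reachable_irreps_out are hypothetical helpers, not e3nn functions.

```python
def parse_irreps(irreps):
    """Parse a string like "2x0e + 1o + 2e" into [(mul, "0e"), ...] (hypothetical helper)."""
    out = []
    for term in irreps.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), ir))
    return out


def reachable_irreps_out(irreps_out, irreps_in):
    # Keep an output irrep only if the same irrep is present in the input.
    present = {ir for _, ir in parse_irreps(irreps_in)}
    kept = [f"{mul}x{ir}" for mul, ir in parse_irreps(irreps_out) if ir in present]
    return "+".join(kept)


print(reachable_irreps_out("2x0e + 1o + 2e", "0e + 1o"))  # 2x0e+1x1o
```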
Change in IrrepsArray
Before this release, IrrepsArray had its data stored twice, both in .array and in .list.
x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
x.list # a list with the scalars and the vectors in two separate arrays
x.array # an array with all the data contiguous
The motivation was to allow chunks to be stored as None, meaning they are strictly zero:
e3nn.IrrepsArray.from_list("0e + 1o", [jnp.array([[1.0]]), None], ())
This led to confusing situations with the jax.tree_util module:
jax.tree_util.tree_leaves(x)
# a list of 3 arrays with repeated data
Because of that, jax.vmap could not be used with a negative axis:
jax.vmap(lambda x: x, in_axes=-2)(x) # ERROR!!
And the gradient was only propagated through one of the two attributes:
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array # is zero
g.list # is not zero
In this release
We refactored IrrepsArray. It now has only .array as a data attribute; the .list attribute is gone.
Therefore jax.tree_util.tree_leaves returns just one array:
x = e3nn.IrrepsArray("0e + 1o", jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]))
jax.tree_util.tree_leaves(x) # [jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])]
Instead, there is a new attribute .zero_flags, a tuple of booleans indicating whether each chunk is zero:
y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
y.chunks # [jnp.array([[1.0]]), None]
y.zero_flags # (False, True)
The new attribute .chunks replaces .list (now deprecated). The new name is consistent with the existing .slice_by_chunk:
x.chunks # list of the two chunks
x.slice_by_chunk[:1] # get the first chunk
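The relation between .array, .chunks and .zero_flags can be pictured as slicing one contiguous last axis into a (mul, 2l+1) block per irrep. A None chunk is stored as zeros and recorded in the flags. Below is a minimal NumPy sketch under that assumption. The helpers are hypothetical and are not e3nn's implementation.

```python
import numpy as np


def irreps_layout(irreps):
    """Return [(mul, 2l+1), ...] for a string like "0e + 1o" (hypothetical helper)."""
    layout = []
    for term in irreps.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        l = int(ir[:-1])  # drop the parity letter "e" or "o"
        layout.append((int(mul), 2 * l + 1))
    return layout


def to_chunks_sketch(irreps, array):
    # Slice the contiguous last axis into one (mul, 2l+1) block per irrep.
    chunks, start = [], 0
    for mul, dim in irreps_layout(irreps):
        block = array[..., start : start + mul * dim]
        chunks.append(block.reshape(array.shape[:-1] + (mul, dim)))
        start += mul * dim
    return chunks


def from_chunks_sketch(irreps, chunks, leading_shape):
    # Build the contiguous array; None chunks become zeros and are flagged.
    blocks = []
    for (mul, dim), c in zip(irreps_layout(irreps), chunks):
        if c is None:
            c = np.zeros(leading_shape + (mul, dim))
        blocks.append(c.reshape(leading_shape + (mul * dim,)))
    zero_flags = tuple(c is None for c in chunks)
    return np.concatenate(blocks, axis=-1), zero_flags


x = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]])
chunks = to_chunks_sketch("0e + 1o", x)  # shapes (2, 1, 1) and (2, 1, 3)
array, zero_flags = from_chunks_sketch("0e + 1o", [np.array([[1.0]]), None], ())
# array holds [1, 0, 0, 0] and zero_flags is (False, True)
```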
jax.vmap can now be used with a negative axis:
jax.vmap(lambda x: x, in_axes=-2)(x)
And the gradient behaves as expected:
g = jax.grad(lambda x: e3nn.sum(x)["0e"].array.squeeze())(x)
g.array # expected value
g.chunks # expected value
To avoid any trouble that .zero_flags might cause in JAX transformations, it is dropped whenever a transformation is applied:
y = e3nn.from_chunks("0e + 1o", [jnp.array([[1.0]]), None], ())
print(y.zero_flags) # (False, True)
z = jax.jit(lambda x: x)(y)
print(z.zero_flags) # (False, False)
z = jax.tree_util.tree_map(lambda x: x, z)
print(z.zero_flags) # (False, False)
z = jax.vmap(lambda x: x)(z[None, ...])
print(z.zero_flags) # (False, False)
Changelog
Changed
- [BREAKING] e3nn.flax.Linear and e3nn.haiku.Linear no longer output the irreps that cannot be reached from the input. To force the output of all irreps, use force_irreps_out = True. For instance, e3nn.flax.Linear("0e + 1o")("0e") will now return "0e" instead of "0e + 1o".
- [BREAKING] e3nn.utils.assert_equivariant has the same signature as e3nn.utils.equivariance_test
- [BREAKING] Move as_irreps_array, zeros and zeros_like from e3nn.IrrepsArray to e3nn
- [BREAKING] Move IrrepsArray.from_list to e3nn.from_chunks
- [BREAKING] Rename IrrepsArray.list to IrrepsArray.chunks
- [BREAKING] Rename IrrepsArray.remove_nones to IrrepsArray.remove_zero_chunks
- e3nn.IrrepsArray now has only .array as a data attribute.
Added
- e3nn.IrrepsArray.rechunk
- e3nn.IrrepsArray.zero_flags: a tuple of bools that indicates which chunks are zero
2023-06-19
Highlight
Added set_mul to Irreps. Note that it is not an in-place operation: it returns a new Irreps.
irreps = e3nn.Irreps("0e + 1o")
irreps = irreps.set_mul(2)
# 2x0e+2x1o
Added the option lmax to Irreps.filter and IrrepsArray.filter.
irreps = irreps.filter(lmax=0)
irreps
# 2x0e
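Both operations are simple rewrites of the list of (multiplicity, l, parity) entries, and each returns a new list instead of mutating the input. Below is a pure-Python sketch of that behavior. The helpers are hypothetical stand-ins, not the Irreps implementation.

```python
def parse_irreps(irreps):
    """Parse "0e + 1o" into [(mul, l, parity), ...] (hypothetical helper)."""
    out = []
    for term in irreps.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), ir[-1]))
    return out


def format_irreps(irreps):
    return "+".join(f"{mul}x{l}{p}" for mul, l, p in irreps)


def set_mul(irreps, mul):
    # Returns a new list with every multiplicity replaced; the input is unchanged.
    return [(mul, l, p) for _, l, p in irreps]


def filter_lmax(irreps, lmax):
    # Keep only the irreps with l <= lmax.
    return [(mul, l, p) for mul, l, p in irreps if l <= lmax]


irreps = parse_irreps("0e + 1o")
doubled = set_mul(irreps, 2)
print(format_irreps(doubled))  # 2x0e+2x1o
print(format_irreps(filter_lmax(doubled, 0)))  # 2x0e
```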
e3nn.utils is now directly accessible as a submodule and has its own documentation.
x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 3.0, 4.0]))
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 4.0]))
y1, y2 = e3nn.utils.equivariance_test(
e3nn.tensor_product, jax.random.PRNGKey(0), x1, x2
)
# y1 = R x1 otimes R x2
# y2 = R (x1 otimes x2)
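The property that equivariance_test probes can be checked by hand for a simple bilinear operation: the cross product of two vectors commutes with proper rotations, so R x1 x R x2 = R (x1 x x2). Below is a NumPy sketch of that manual check, which is not e3nn's implementation. random_rotation is a hypothetical helper that draws a rotation with det(R) = +1.

```python
import numpy as np


def random_rotation(rng):
    # QR decomposition of a Gaussian matrix, with signs fixed so that det(R) = +1.
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
R = random_rotation(rng)
x1 = np.array([1.0, 3.0, 4.0])
x2 = np.array([0.0, 1.0, 4.0])

y1 = np.cross(R @ x1, R @ x2)  # rotate the inputs, then apply the operation
y2 = R @ np.cross(x1, x2)  # apply the operation, then rotate the output
print(np.allclose(y1, y2))  # True
```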
Changelog
Changed
- [BREAKING] Renamed e3nn.util to e3nn.utils
Added
- Irreps.set_mul(int) to set the multiplicity of all irreps
- Irreps.filter(lmax=int) to filter out irreps with l > lmax
- IrrepsArray.filter(lmax=int) to filter out irreps with l > lmax
- IrrepsArray.__radd__ and IrrepsArray.__rsub__ to support scalar + IrrepsArray and scalar - IrrepsArray
- 0 + IrrepsArray and 0 - IrrepsArray are now always accepted as special cases.
- Support for IrrepsArray / array
- Add utils as a submodule
Fixed
- e3nn.scatter operations handle indices with ndim > 1
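For intuition, a scatter-sum with a multi-dimensional index array behaves like NumPy's np.add.at. Each entry of the index array selects a destination row, the data has shape index.shape + (features,), and repeated destinations accumulate. Below is a NumPy analogue for illustration. It is not e3nn.scatter_sum itself, and the shapes are assumptions.

```python
import numpy as np

# Messages laid out on a 2D grid of indices: data has shape dst.shape + (features,)
data = np.arange(12.0).reshape(2, 3, 2)
dst = np.array([[0, 1, 1], [2, 0, 2]])  # destination of each message, ndim = 2

out = np.zeros((3, 2))
np.add.at(out, dst, data)  # unbuffered: repeated destinations accumulate
print(out)  # rows: [8, 10], [6, 8], [16, 18]
```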
2023-05-10
Highlight
We can now compute the trispectrum for l=4 in about a minute.
irreps = e3nn.Irreps.spherical_harmonics(lmax=4)
q = e3nn.reduced_symmetric_tensor_product_basis(irreps, 4, keep_ir="0e + 0o")
ChangeLog
Added
- e3nn.cross for completeness
Changed
- Optimize e3nn.reduced_symmetric_tensor_product_basis, especially for the keep_ir argument
2023-05-02
Highlight
We finally implemented the eSCN optimization, that is, the optimization of a tensor product with spherical harmonics followed by a linear mixing.
old:
sh = e3nn.spherical_harmonics(range(lmax + 1), vector, True)  # irreps_out first: l = 0, ..., lmax
features = e3nn.tensor_product(features, sh)
linear = e3nn.flax.Linear(irreps_out)
features = linear.apply(w, features)
new: (equivalent but faster)
from e3nn_jax.experimental.linear_shtp import LinearSHTP
conv = LinearSHTP(irreps_out)
features = conv.apply(w, features, vector)
ChangeLog
Added
- LinearSHTP module implementing the optimized linear mixing of the inputs' tensor product with spherical harmonics
- D_from_axis_angle
- to_s2grid: quadrature="gausslegendre" by default
- soft_odd activation function for odd scalars
- More support for arrays implicitly converted into IrrepsArray as scalars (i.e., added a few IrrepsArray.as_irreps_array)
Changed
- scalar_activation is simpler to use with default activation functions (a bit like gate)
2023-04-06
x = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 1.0, 1.0, 2.0]))
e3nn.norm_activation(x, [None, jnp.tanh])
# 1x0e+1x1o [1. 0.8883856 0.8883856 1.7767712]
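The printed numbers are consistent with the following reading: each vector is multiplied by act applied to its component-normalized norm, here tanh(sqrt(6 / 3)) ≈ 0.888, and a chunk with act=None is left unchanged. Below is a NumPy sketch of that reading. It is an assumption inferred from the output above, not e3nn's source.

```python
import numpy as np

x = np.array([1.0, 1.0, 1.0, 2.0])  # "0e + 1o": one scalar, then one vector
scalar, vector = x[:1], x[1:]

# Component-normalized norm of the vector: sqrt(sum(v**2) / dim) = sqrt(6 / 3)
n = np.sqrt(np.sum(vector**2) / vector.size)
out = np.concatenate([scalar, np.tanh(n) * vector])  # act=None leaves the scalar unchanged
print(out)
```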
Changed
- e3nn.normalize_function now uses a deterministic (not pseudorandom) algorithm to compute the normalization factor.
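One way to compute such a normalization factor deterministically is Gauss-Hermite quadrature. It estimates E[f(z)^2] for z ~ N(0, 1) without random samples, then rescales f so that the second moment is 1. The release note does not say which algorithm e3nn uses, so the NumPy sketch below swaps in this technique for illustration. second_moment and normalize are hypothetical helpers.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Probabilists' Gauss-Hermite nodes/weights: weight exp(-z**2 / 2), total weight sqrt(2*pi)
nodes, weights = hermegauss(100)


def second_moment(f):
    # Deterministic estimate of E[f(z)**2] for z ~ N(0, 1)
    return np.sum(weights * f(nodes) ** 2) / np.sqrt(2 * np.pi)


def normalize(f):
    # Rescale f so that E[f(z)**2] = 1 (hypothetical helper, not e3nn's code)
    c = 1.0 / np.sqrt(second_moment(f))
    return lambda z: c * f(z)


act = normalize(np.tanh)
print(second_moment(act))  # close to 1.0
```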
Added
- normalize_act option to e3nn.scalar_activation and e3nn.gate. The normalization can now be turned off.
- e3nn.norm_activation as a new activation function.
2023-03-26
This release fixes the gradients of xyz_to_angles: instead of NaN at the poles, they are now zero.
[Figures: angles alpha and beta on the sphere (x-axis to the right, y-axis to the top, z-axis out of the image); gradients of alpha and of beta with respect to (x, y, z).]
Fixed
- Fix NaN in the gradients of e3nn.xyz_to_angles. The gradients are now 0 when the input is on the poles.
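The pole is a problem for the azimuthal angle. For alpha = atan2(a, b), where a and b are the two coordinates orthogonal to the pole axis, both partial derivatives are 0/0 on the pole. Below is a NumPy sketch of the usual "double where" guard, which substitutes a safe denominator before dividing and then masks the result to 0. The sketch does not assume which coordinates play a and b in e3nn's convention, and it is not the library's code.

```python
import numpy as np


def atan2_grad_safe(a, b):
    """Gradient of alpha = atan2(a, b) with respect to (a, b), set to 0 where a = b = 0.

    d alpha/da = b / (a**2 + b**2) and d alpha/db = -a / (a**2 + b**2) are 0/0 on
    the pole; dividing by a safe denominator avoids NaN, then the result is masked.
    """
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    den = a**2 + b**2
    safe = np.where(den > 0, den, 1.0)
    grad_a = np.where(den > 0, b / safe, 0.0)
    grad_b = np.where(den > 0, -a / safe, 0.0)
    return grad_a, grad_b


ga, gb = atan2_grad_safe(0.0, 0.0)  # both 0 instead of NaN
```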
2023-03-15
Backward propagation through e3nn.tensor_product_with_spherical_harmonics is not working yet.
Added
- e3nn.dot: compute the dot product between two IrrepsArray
- per_irrep argument to e3nn.norm: compute the norm of each irrep independently if per_irrep=True
- e3nn.tensor_product_with_spherical_harmonics from https://arxiv.org/pdf/2302.03655.pdf
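To make per_irrep concrete: with per_irrep=True, the norm is taken separately for each copy of each irrep instead of once over all components. Below is a NumPy sketch for "2x0e + 1o". The layout and split_irreps helper are hypothetical. The dot product is assumed here to be the plain sum of x_i * y_i, since the note does not state e3nn.dot's normalization.

```python
import numpy as np

# Hypothetical layout for "2x0e + 1o": two scalars followed by one 3D vector
layout = [(2, 1), (1, 3)]  # (mul, 2l+1) per chunk
x = np.array([3.0, 4.0, 1.0, 2.0, 2.0])
y = np.array([1.0, 0.0, 0.0, 1.0, 1.0])


def split_irreps(v):
    # One entry per irrep copy: [3.], [4.], [1., 2., 2.]
    out, start = [], 0
    for mul, dim in layout:
        for _ in range(mul):
            out.append(v[start : start + dim])
            start += dim
    return out


norm_all = np.linalg.norm(x)  # a single norm over all components: sqrt(34)
norm_per_irrep = np.array([np.linalg.norm(c) for c in split_irreps(x)])  # [3. 4. 3.]
dot = np.dot(x, y)  # 7.0 (plain Euclidean dot product, an assumption)
```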
Changed
- __repr__(Irreps()) has been changed from "" to "Irreps()"
Fixed
- Spherical harmonics edge case when output_irreps=Irreps()
2023-02-20
The function e3nn.SphericalSignal.sample makes it possible to sample points on the sphere according to an arbitrary distribution defined on a grid.
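A common way to sample from a density given on a grid has three steps. First, weight each grid value by the quadrature weight (area) of its cell. Then normalize the weights. Finally, draw a cell index from the resulting categorical distribution. Below is a NumPy sketch under that assumption. sample_from_grid and its arguments are hypothetical and do not mirror the signature of e3nn.SphericalSignal.sample.

```python
import numpy as np


def sample_from_grid(probs, quad_weights, betas, alphas, rng):
    """Draw one (beta, alpha) point from an unnormalized density on a grid.

    Hypothetical helper: probs has shape (len(betas), len(alphas)) and
    quad_weights has shape (len(betas),), the weight of each beta row.
    """
    mass = probs * quad_weights[:, None]  # probability mass of each grid cell
    mass = mass.ravel() / mass.sum()
    flat = rng.choice(mass.size, p=mass)
    i_beta, i_alpha = np.unravel_index(flat, probs.shape)
    return betas[i_beta], alphas[i_alpha]


betas = np.linspace(0.1, np.pi - 0.1, 5)
alphas = np.linspace(0.0, 2 * np.pi, 8, endpoint=False)
probs = np.zeros((5, 8))
probs[2, 3] = 1.0  # all the mass on one cell
rng = np.random.default_rng(0)
beta, alpha = sample_from_grid(probs, np.sin(betas), betas, alphas, rng)
```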
Added
- e3nn.SphericalSignal.sample to sample a point on the sphere
- e3nn.scatter_max
Changed
- [BREAKING] Removed e3nn.s2_sum_of_diracs in favor of e3nn.s2_dirac
- [BREAKING] e3nn.grad now regroups the output by default. It can be disabled with regroup_output=False
Fixed
- e3nn.SphericalSignal arithmetic operations
- e3nn.Irreps.D_from_angles computes (again!) the Wigner D matrices using the J matrices for L <= 11. This is faster and more accurate than using expm.
2022-02-01
import jax.numpy as jnp
import e3nn_jax as e3nn
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1, 2, 0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")
import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])
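For intuition about quadrature="gausslegendre": a grid of this kind places the beta samples at Gauss-Legendre nodes in cos(beta) and the alpha samples uniformly. The beta integral is then exact for polynomials in cos(beta) up to degree 2 * res_beta - 1. Below is a NumPy sketch of such a grid and its integration weights. It illustrates the idea, not necessarily e3nn's exact grid layout.

```python
import numpy as np

res_beta, res_alpha = 50, 69

# Gauss-Legendre nodes/weights in cos(beta); uniform samples in alpha
cos_beta, w_beta = np.polynomial.legendre.leggauss(res_beta)
beta = np.arccos(cos_beta)
alpha = np.linspace(0.0, 2 * np.pi, res_alpha, endpoint=False)
w_alpha = 2 * np.pi / res_alpha


def integrate(values):
    # values has shape (res_beta, res_alpha); returns the integral over the sphere
    return np.sum(w_beta[:, None] * values) * w_alpha


ones = np.ones((res_beta, res_alpha))
area = integrate(ones)  # close to 4 * pi, the area of the unit sphere
```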
Added
- e3nn.SphericalSignal class to represent signals on the sphere
- Signal on the Sphere section in the documentation
- e3nn.Irreps.D_from_log_coordinates
- rotation_angle_from_* functions
- e3nn.to_s2point function
Changed
- Wigner D matrices are computed from the log coordinates, which requires 1 call to expm instead of 3.
- [BREAKING] e3nn.util.assert_output_dtype renamed to e3nn.util.assert_output_dtype_matches_input_dtype
- [BREAKING] Update experimental.point_convolution to use the latest changes.
- [BREAKING] Changed the e3nn.to_s2grid and e3nn.from_s2grid signature and default normalization.
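The "1 instead of 3 calls to expm" follows from the parameterization. With log coordinates omega (rotation angle times axis), a rotation is a single exponential of omega · L, whereas Euler angles need one exponential per angle. Below is a NumPy sketch for l = 1, where the generators are 3x3 cross-product matrices. expm_taylor is a simple illustrative stand-in for a library expm, and the Y-X-Y Euler order is an arbitrary choice here.

```python
import numpy as np

# so(3) generators for l = 1: L[i] @ v equals cross(e_i, v)
L = np.array([
    [[0, 0, 0], [0, 0, -1], [0, 1, 0]],
    [[0, 0, 1], [0, 0, 0], [-1, 0, 0]],
    [[0, -1, 0], [1, 0, 0], [0, 0, 0]],
], dtype=float)


def expm_taylor(a, terms=30):
    """Matrix exponential by scaling and squaring with a Taylor series (illustrative)."""
    norm = np.linalg.norm(a, ord=1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    a = a / 2**s
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms):
        term = term @ a / k
        result = result + term
    for _ in range(s):
        result = result @ result
    return result


def rotation_from_log(omega):
    # One exponential: log coordinates omega = angle * axis
    return expm_taylor(np.einsum("i,ijk->jk", omega, L))


def rotation_from_euler(alpha, beta, gamma):
    # Three exponentials, one per Euler angle (Y-X-Y order chosen for illustration)
    return expm_taylor(alpha * L[1]) @ expm_taylor(beta * L[0]) @ expm_taylor(gamma * L[1])


omega = np.array([0.3, -1.2, 0.8])
R = rotation_from_log(omega)

# Rodrigues' formula gives the same rotation in closed form
theta = np.linalg.norm(omega)
K = np.einsum("i,ijk->jk", omega / theta, L)
R_rodrigues = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * K @ K
```

For higher l, the same construction would use the (2l+1)-dimensional generators in place of L.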
Removed
- [BREAKING] All the haiku modules from the main module. They are now in the e3nn.haiku submodule.
- [BREAKING] e3nn.wigner_D in favor of e3nn.Irrep.D_from_*
Fixed
- Removed the jax.jit decorator from Irreps.D_from_* that was causing a bug.