Skip to content

Commit

Permalink
sh
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 26, 2022
1 parent f4e750f commit 2a8dfde
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 12 deletions.
3 changes: 2 additions & 1 deletion ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Added
- add the `sh` function that does not `IrrepsData` as input/output
- `legendre` algorithm to compute spherical harmonics
- add flag `algorithm` to specify the algorithm to use for computing spherical harmonics
- add flag `algorithm` to specify the algorithm to use for computing spherical harmonics, use `legendre` for large L.
- `experimental.voxel_convolution`: add optional dynamic steps (not static for jit)

## [0.6.1] - 2022-06-09
Expand Down
4 changes: 3 additions & 1 deletion e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ._so3 import clebsch_gordan, wigner_D, generators
from ._instruction import Instruction
from ._irreps import Irrep, MulIrrep, Irreps, IrrepsData
from ._spherical_harmonics import spherical_harmonics, set_default_spherical_harmonics_algorithm
from ._spherical_harmonics import spherical_harmonics, sh, set_default_spherical_harmonics_algorithm, legendre
from ._soft_one_hot_linspace import sus, soft_one_hot_linspace
from ._linear import FunctionalLinear, Linear
from ._core_tensor_product import FunctionalTensorProduct
Expand Down Expand Up @@ -95,7 +95,9 @@
"Irreps",
"IrrepsData",
"spherical_harmonics",
"sh",
"set_default_spherical_harmonics_algorithm",
"legendre",
"sus",
"soft_one_hot_linspace",
"FunctionalLinear",
Expand Down
82 changes: 74 additions & 8 deletions e3nn_jax/_spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import fractions
import math
from functools import partial
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand All @@ -13,16 +13,70 @@
from e3nn_jax import Irreps, IrrepsData, clebsch_gordan
from e3nn_jax.util.sympy import sqrtQarray_to_sympy

DEFAULT_SPHERICAL_HARMONICS_ALGORITHM = ("recursive", "sparse", "custom_vjp")
DEFAULT_SPHERICAL_HARMONICS_ALGORITHM = None


def set_default_spherical_harmonics_algorithm(algorithm: Tuple[str]):
global DEFAULT_SPHERICAL_HARMONICS_ALGORITHM
DEFAULT_SPHERICAL_HARMONICS_ALGORITHM = algorithm


def sh(
irreps_out: Union[Irreps, int, Sequence[int]],
input: jnp.ndarray,
normalize: bool,
normalization: str = "integral",
*,
algorithm: Tuple[str] = None,
) -> jnp.ndarray:
r"""Spherical harmonics
.. image:: https://user-images.githubusercontent.com/333780/79220728-dbe82c00-7e54-11ea-82c7-b3acbd9b2246.gif
| Polynomials defined on the 3d space :math:`Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}`
| Usually restricted on the sphere (with ``normalize=True``) :math:`Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}`
| who satisfies the following properties:
* are polynomials of the cartesian coordinates ``x, y, z``
* is equivariant :math:`Y^l(R x) = D^l(R) Y^l(x)`
* are orthogonal :math:`\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}`
The value of the constant depends on the choice of normalization.
It obeys the following property:
.. math::
Y^{l+1}_i(x) &= \text{cste}(l) \; & C_{ijk} Y^l_j(x) x_k
\partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) & C_{ijk} Y^l_j(x)
Where :math:`C` are the `clebsch_gordan`.
.. note::
This function match with this table of standard real spherical harmonics from Wikipedia_
when ``normalize=True``, ``normalization='integral'`` and is called with the argument in the order ``y,z,x``
(instead of ``x,y,z``).
.. _Wikipedia: https://en.wikipedia.org/wiki/Table_of_spherical_harmonics#Real_spherical_harmonics
Args:
irreps_out (`Irreps` or int or Sequence[int]): the output irreps
input (`jnp.ndarray`): cartesian coordinates
normalize (bool): if True, the polynomials are restricted to the sphere
normalization (str): normalization of the constant :math:`\text{cste}`. Default is 'integral'
algorithm (Tuple[str]): algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_vjp])
Returns:
`jnp.ndarray`: polynomials of the spherical harmonics
"""
input = IrrepsData.from_contiguous("1e", input)
return spherical_harmonics(irreps_out, input, normalize, normalization, algorithm=algorithm).contiguous


def spherical_harmonics(
irreps_out: Union[Irreps, int],
irreps_out: Union[Irreps, int, Sequence[int]],
input: Union[IrrepsData, jnp.ndarray],
normalize: bool,
normalization: str = "integral",
Expand Down Expand Up @@ -69,25 +123,37 @@ def spherical_harmonics(
algorithm (Tuple[str]): algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_vjp])
Returns:
`jnp.ndarray`: polynomials of the spherical harmonics
`IrrepsData`: polynomials of the spherical harmonics
"""
assert normalization in ["integral", "component", "norm"]

if algorithm is None:
algorithm = DEFAULT_SPHERICAL_HARMONICS_ALGORITHM
assert all(keyword in ["legendre", "recursive", "dense", "sparse", "custom_vjp"] for keyword in algorithm)

if isinstance(irreps_out, int):
l = irreps_out
assert isinstance(input, IrrepsData)
[(mul, ir)] = input.irreps
irreps_out = Irreps([(1, (l, ir.p**l))])

if all(isinstance(l, int) for l in irreps_out):
assert isinstance(input, IrrepsData)
[(mul, ir)] = input.irreps
irreps_out = Irreps([(1, (l, ir.p**l)) for l in irreps_out])

irreps_out = Irreps(irreps_out)

assert all([l % 2 == 1 or p == 1 for _, (l, p) in irreps_out])
assert len(set([p for _, (l, p) in irreps_out if l % 2 == 1])) <= 1

if algorithm is None:
if DEFAULT_SPHERICAL_HARMONICS_ALGORITHM is None:
if irreps_out.lmax <= 8:
algorithm = ("recursive", "sparse", "custom_vjp")
else:
algorithm = ("legendre", "sparse", "custom_vjp")
else:
algorithm = DEFAULT_SPHERICAL_HARMONICS_ALGORITHM

assert all(keyword in ["legendre", "recursive", "dense", "sparse", "custom_vjp"] for keyword in algorithm)

if isinstance(input, IrrepsData):
[(mul, ir)] = input.irreps
assert mul == 1
Expand Down
7 changes: 5 additions & 2 deletions tests/spherical_harmonics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def test_check_grads(keys, irreps, normalization):
@pytest.mark.parametrize("l", range(7 + 1))
def test_normalize(keys, l):
x = jax.random.normal(keys[0], (10, 3))
y1 = e3nn.spherical_harmonics([l], x, normalize=True).contiguous * jnp.linalg.norm(x, axis=1, keepdims=True) ** l
y2 = e3nn.spherical_harmonics([l], x, normalize=False).contiguous
y1 = (
e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True).contiguous
* jnp.linalg.norm(x, axis=1, keepdims=True) ** l
)
y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False).contiguous
np.testing.assert_allclose(y1, y2, atol=1e-6, rtol=1e-5)

0 comments on commit 2a8dfde

Please sign in to comment.