From 08a7e8187e8a9dc6716b3e4ab049b6de91b27ab4 Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Thu, 5 Dec 2024 15:06:33 -0500 Subject: [PATCH] Add apply() for SO3Signal. --- e3nn_jax/_src/so3grid.py | 4 ++++ tests/_src/so3grid_test.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/e3nn_jax/_src/so3grid.py b/e3nn_jax/_src/so3grid.py index ef5b8a7..6c94c77 100644 --- a/e3nn_jax/_src/so3grid.py +++ b/e3nn_jax/_src/so3grid.py @@ -128,6 +128,10 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal": def __truediv__(self, other: float) -> "SO3Signal": return self * (1 / other) + def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal": + """Apply a pointwise function to the signal.""" + return SO3Signal(self.s2_signals.apply(func)) + def vmap_over_batch_dims( self, func: Callable[..., jnp.ndarray] ) -> Callable[..., jnp.ndarray]: diff --git a/tests/_src/so3grid_test.py b/tests/_src/so3grid_test.py index 6a8732f..f973377 100644 --- a/tests/_src/so3grid_test.py +++ b/tests/_src/so3grid_test.py @@ -119,3 +119,22 @@ def test_argmax(seed: int): R_argmax, _ = sig.argmax() assert jnp.allclose(func(R_argmax), func(R_argmax_expected), rtol=1e-2) + + +def test_apply(): + sig = SO3Signal.from_function( + lambda R: jnp.trace(R @ R), + res_beta=40, + res_alpha=39, + res_theta=40, + quadrature="gausslegendre", + ) + sig_applied = sig.apply(jnp.exp) + sig_expected = SO3Signal.from_function( + lambda R: jnp.exp(jnp.trace(R @ R)), + res_beta=40, + res_alpha=39, + res_theta=40, + quadrature="gausslegendre", + ) + assert jnp.allclose(sig_applied.grid_values, sig_expected.grid_values)