Skip to content

Commit

Permalink
Merge pull request #89 from e3nn/so3
Browse files Browse the repository at this point in the history
Minor fixes for SO3Signal
  • Loading branch information
ameya98 authored Dec 2, 2024
2 parents acd916d + 398d835 commit f9ac41f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
9 changes: 6 additions & 3 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,18 @@ def __truediv__(self, other: float) -> "SO3Signal":
def integrate_over_angles(self) -> SphericalSignal:
"""Integrate the signal over the angles in the axis-angle parametrization."""
# Account for angle-dependency in Haar measure.
grid_values = self.s2_signals.grid_values * (1 - jnp.cos(self.grid_theta))[..., None, None]
grid_values = (
self.s2_signals.grid_values
* (1 - jnp.cos(self.grid_theta))[..., None, None]
)

# Trapezoidal rule for integration.
delta_theta = self.grid_theta[1] - self.grid_theta[0]
return self.s2_signals.replace_values(
grid_values=jnp.sum(grid_values, axis=-3) * delta_theta
)

def integrate(self) -> SphericalSignal:
def integrate(self) -> float:
"""Numerically integrate the signal over SO(3)."""
# Integrate over angles.
s2_signal_integrated = self.integrate_over_angles()
Expand All @@ -153,7 +156,7 @@ def integrate(self) -> SphericalSignal:
integral = integral / (8 * jnp.pi**2)
return integral

def sample(self, rng: jax.random.PRNGKey):
def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
"""Sample a random rotation from SO(3) using the given probability distribution."""
# Integrate over angles.
s2_signal_integrated = self.integrate_over_angles()
Expand Down
1 change: 0 additions & 1 deletion tests/_src/so3grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,3 @@ def test_division_scalar():
sig2 = sig1 / 2.7
integral2 = sig2.integrate()
assert jnp.isclose(integral2, integral1 / 2.7)

0 comments on commit f9ac41f

Please sign in to comment.