Skip to content

Commit

Permalink
Merge pull request #91 from e3nn/so3
Browse files Browse the repository at this point in the history
Add argmax for signals on SO3
  • Loading branch information
ameya98 authored Dec 3, 2024
2 parents 941afa1 + 8be4994 commit c47ace2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 16 deletions.
64 changes: 48 additions & 16 deletions e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def res_beta(self) -> int:
def res_alpha(self) -> int:
return self.s2_signals.res_alpha

@property
def grid_values(self) -> jnp.ndarray:
return self.s2_signals.grid_values

@property
def res_theta(self) -> int:
return self.s2_signals.shape[-3]
Expand Down Expand Up @@ -124,13 +128,43 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
def __truediv__(self, other: float) -> "SO3Signal":
return self * (1 / other)

def vmap_over_batch_dims(
self, func: Callable[..., jnp.ndarray]
) -> Callable[..., jnp.ndarray]:
"""Apply a function to the signal while preserving the batch dimensions."""
for _ in range(len(self.batch_dims)):
func = jax.vmap(func)
return func

def argmax(self) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
"""Find the rotation (and corresponding grid indices) with the maximum value of the signal."""
# Get flattened argmax
flat_index = jnp.argmax(self.grid_values.reshape(*self.shape[:-3], -1), axis=-1)

# Convert flat index back to indices for theta, beta, alpha
theta_idx, beta_idx, alpha_idx = jnp.unravel_index(flat_index, self.shape[-3:])

# Compute axis.
axis = self.s2_signals.grid_vectors[..., beta_idx, alpha_idx, :]
assert axis.shape == (*self.batch_dims, 3)

# Compute angle.
angle = self.grid_theta[theta_idx]
assert angle.shape == (*self.batch_dims,)

Rs = self.vmap_over_batch_dims(e3nn.axis_angle_to_matrix)(axis, angle)
assert Rs.shape == (*self.batch_dims, 3, 3)

return Rs, (theta_idx, beta_idx, alpha_idx)

def replace_values(self, grid_values: jnp.ndarray) -> "SO3Signal":
"""Replace the values of the signal with the given grid_values."""
return SO3Signal(self.s2_signals.replace_values(grid_values))

def integrate_over_angles(self) -> SphericalSignal:
"""Integrate the signal over the angles in the axis-angle parametrization."""
# Account for angle-dependency in Haar measure.
grid_values = (
self.s2_signals.grid_values
* (1 - jnp.cos(self.grid_theta))[..., None, None]
)
grid_values = self.grid_values * (1 - jnp.cos(self.grid_theta))[..., None, None]

# Trapezoidal rule for integration.
delta_theta = self.grid_theta[1] - self.grid_theta[0]
Expand All @@ -141,15 +175,15 @@ def integrate_over_angles(self) -> SphericalSignal:
def integrate(self) -> float:
"""Numerically integrate the signal over SO(3)."""
# Integrate over angles.
s2_signal_integrated = self.integrate_over_angles()
assert s2_signal_integrated.shape == (
sig_integrated = self.integrate_over_angles()
assert sig_integrated.shape == (
*self.batch_dims,
self.res_beta,
self.res_alpha,
)

# Integrate over axes using S2 quadrature.
integral = s2_signal_integrated.integrate().array.squeeze(-1)
integral = sig_integrated.integrate().array.squeeze(-1)
assert integral.shape == self.batch_dims

# Factor of 8pi^2 from the Haar measure.
Expand All @@ -159,22 +193,22 @@ def integrate(self) -> float:
def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
"""Sample a random rotation from SO(3) using the given probability distribution."""
# Integrate over angles.
s2_signal_integrated = self.integrate_over_angles()
assert s2_signal_integrated.shape == (
sig_integrated = self.integrate_over_angles()
assert sig_integrated.shape == (
*self.batch_dims,
self.res_beta,
self.res_alpha,
)

# Sample the axis from the S2 signal (integrated over angles).
axis_rng, rng = jax.random.split(rng)
beta_idx, alpha_idx = s2_signal_integrated.sample(axis_rng)
axis = s2_signal_integrated.grid_vectors[..., beta_idx, alpha_idx, :]
beta_idx, alpha_idx = sig_integrated.sample(axis_rng)
axis = sig_integrated.grid_vectors[..., beta_idx, alpha_idx, :]
assert axis.shape == (*self.batch_dims, 3)

# Choose the angle from the distribution conditioned on the axis.
angle_rng, rng = jax.random.split(rng)
theta_probs = self.s2_signals.grid_values[..., beta_idx, alpha_idx]
theta_probs = self.grid_values[..., beta_idx, alpha_idx]
assert theta_probs.shape == (*self.batch_dims, self.res_theta)

# Avoid log(0) by replacing 0 with a small value.
Expand All @@ -185,8 +219,6 @@ def sample(self, rng: jax.random.PRNGKey) -> jnp.ndarray:
angle = jnp.linspace(0, 2 * jnp.pi, self.res_theta)[theta_idx]
assert angle.shape == (*self.batch_dims,)

axis_angle_to_matrix = e3nn.axis_angle_to_matrix
for _ in range(len(self.batch_dims)):
axis_angle_to_matrix = jax.vmap(axis_angle_to_matrix)
Rs = axis_angle_to_matrix(axis, angle)
Rs = self.vmap_over_batch_dims(e3nn.axis_angle_to_matrix)(axis, angle)
assert Rs.shape == (*self.batch_dims, 3, 3)
return Rs
23 changes: 23 additions & 0 deletions tests/_src/so3grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,26 @@ def test_division_scalar():
sig2 = sig1 / 2.7
integral2 = sig2.integrate()
assert jnp.isclose(integral2, integral1 / 2.7)


@pytest.mark.parametrize("seed", [0, 1, 2])
def test_argmax(seed: int):
rng = jax.random.PRNGKey(seed)
F = jax.random.normal(rng, (3, 3))

func = lambda R: jnp.exp(jnp.trace(F.T @ R))
sig = SO3Signal.from_function(
func,
res_beta=50,
res_alpha=50,
res_theta=50,
quadrature="gausslegendre",
)

U, S, VT = jnp.linalg.svd(F)
R_argmax_expected = (
U @ jnp.diag(jnp.asarray([1.0, 1.0, jnp.linalg.det(U @ VT)])) @ VT
)
R_argmax, _ = sig.argmax()

assert jnp.allclose(func(R_argmax), func(R_argmax_expected), rtol=1e-2)

0 comments on commit c47ace2

Please sign in to comment.