diff --git a/e3nn_jax/_src/so3grid.py b/e3nn_jax/_src/so3grid.py index 6c94c77..f02c45e 100644 --- a/e3nn_jax/_src/so3grid.py +++ b/e3nn_jax/_src/so3grid.py @@ -125,7 +125,21 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal": return SO3Signal(self.s2_signals * other) - def __truediv__(self, other: float) -> "SO3Signal": + def __rmul__(self, other: float) -> "SO3Signal": + return self * other + + def __neg__(self) -> "SO3Signal": + return self * -1 + + def __truediv__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal": + if isinstance(other, SO3Signal): + if self.shape != other.shape: + raise ValueError( + f"Shapes of the two signals do not match: {self.shape} != {other.shape}" + ) + + return self.replace_values(self.grid_values / other.grid_values) + return self * (1 / other) def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":