Skip to content

Commit

Permalink
Merge pull request #95 from e3nn/so3
Browse files Browse the repository at this point in the history
Add division and negation for SO3Signal.
  • Loading branch information
ameya98 authored Dec 15, 2024
2 parents 7ed1831 + 4017473 commit 6f6319d
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion e3nn_jax/_src/so3grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,21 @@ def __mul__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":

return SO3Signal(self.s2_signals * other)

def __truediv__(self, other: float) -> "SO3Signal":
def __rmul__(self, other: float) -> "SO3Signal":
return self * other

def __neg__(self) -> "SO3Signal":
return self * -1

def __truediv__(self, other: Union[float, "SO3Signal"]) -> "SO3Signal":
if isinstance(other, SO3Signal):
if self.shape != other.shape:
raise ValueError(
f"Shapes of the two signals do not match: {self.shape} != {other.shape}"
)

return self.replace_values(self.grid_values / other.grid_values)

return self * (1 / other)

def apply(self, func: Callable[..., jnp.ndarray]) -> "SO3Signal":
Expand Down

0 comments on commit 6f6319d

Please sign in to comment.