diff --git a/e3nn_jax/_src/s2grid.py b/e3nn_jax/_src/s2grid.py index 1904cf5..2ab90d2 100644 --- a/e3nn_jax/_src/s2grid.py +++ b/e3nn_jax/_src/s2grid.py @@ -912,9 +912,9 @@ def _from_s2grid_s2fft( """An S2FFT powered version of e3nn_jax.from_s2grid.""" lmax = irreps.lmax expected_grid_resolution = get_s2fft_grid_resolution(lmax) - if sig.shape != expected_grid_resolution: + if sig.grid_resolution != expected_grid_resolution: raise ValueError( - f"Input signal shape {sig.shape} does not match the required shape {expected_grid_resolution}" + f"Input signal resolution {sig.grid_resolution} does not match the required resolution {expected_grid_resolution}." ) with jax.ensure_compile_time_eval():