Discrepancy in output Irreps of e3nn.haiku.Linear with and without reinitializing IrrepsArray #75
-
I encountered an interesting behavior when working with `IrrepsArray` and the linear layer from the `e3nn.haiku` module. Specifically, the output irreps (`y.irreps`) differ depending on whether I reinitialize the `IrrepsArray` after extending it with zeros. Here is a minimal example that shows the discrepancy:

```python
import jax
import jax.numpy as jnp
import haiku as hk
import e3nn_jax as e3nn


@hk.without_apply_rng
@hk.transform
def linear(x):
    x = x.extend_with_zeros("2x0e+2x1e+2x1o")
    # Uncommenting the next line changes the output irreps
    # x = e3nn.IrrepsArray(x.irreps, x.array)
    return e3nn.haiku.Linear(irreps_out="6x0e+2x1e", biases=True)(x)


x = e3nn.IrrepsArray("2x0e+2x1o", jnp.ones((4, 8)))
params = linear.init(jax.random.PRNGKey(0), x)
y = linear.apply(params, x)
print(y.irreps)
```

The printed `y.irreps` differ with and without that line. Is this behavior expected? Any insights or explanations regarding this behavior would be appreciated. Thanks!
-
Sorry about the inconvenience. Yes, it's expected, though I understand that it can be confusing. The reason is that `IrrepsArray` has an attribute that tags whether a "chunk" is known to be strictly zero. When you uncomment that line, you drop that attribute, which changes the behavior of `Linear`: `Linear` removes the zero chunks before determining the reachable output irreps. I introduced this mechanism because originally `Linear` would sometimes be asked to produce irreps that are not present in the input. To avoid unnecessary computation, I added that tag so that the output carries that information with it. `Linear` has a force argument (`force_irreps_out`) that forces the requested output irreps even if they are unreachable.
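To make the rule in the reply concrete, here is a toy, pure-Python model of it. Everything in it (`ToyIrrepsArray`, `from_dense`, `to_dense`, `toy_output_irreps`, the `force` flag, and the use of `None` to tag a zero chunk) is hypothetical and is not the e3nn-jax API or implementation. It only mirrors the described behavior: chunks tagged as zero are dropped before the reachable output irreps are chosen, and rebuilding the array from its dense data (like `e3nn.IrrepsArray(x.irreps, x.array)` in the question) loses those tags.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Irreps = List[Tuple[int, str]]  # e.g. [(2, "0e"), (2, "1o")] stands for "2x0e+2x1o"
Rows = List[List[float]]        # a batch of feature vectors


def irrep_dim(ir: str) -> int:
    # An irrep with angular momentum l (e.g. "1e" -> l = 1) has dimension 2l + 1.
    return 2 * int(ir[:-1]) + 1


@dataclass
class ToyIrrepsArray:
    irreps: Irreps
    chunks: List[Optional[Rows]]  # None tags a chunk known to be exactly zero

    @staticmethod
    def from_dense(irreps: Irreps, rows: Rows) -> "ToyIrrepsArray":
        # Every chunk becomes concrete data, so "known zero" tags cannot survive.
        chunks, start = [], 0
        for mul, ir in irreps:
            stop = start + mul * irrep_dim(ir)
            chunks.append([row[start:stop] for row in rows])
            start = stop
        return ToyIrrepsArray(irreps, chunks)

    def extend_with_zeros(self, new_irreps: Irreps) -> "ToyIrrepsArray":
        # Entries missing from self.irreps are added as tagged-zero chunks (None).
        chunks, i = [], 0
        for entry in new_irreps:
            if i < len(self.irreps) and self.irreps[i] == entry:
                chunks.append(self.chunks[i])
                i += 1
            else:
                chunks.append(None)
        assert i == len(self.irreps), "old irreps must appear in order in new_irreps"
        return ToyIrrepsArray(new_irreps, chunks)

    def to_dense(self, batch: int) -> Rows:
        # Materialize tagged-zero chunks as explicit zeros.
        rows: Rows = [[] for _ in range(batch)]
        for (mul, ir), chunk in zip(self.irreps, self.chunks):
            width = mul * irrep_dim(ir)
            for b in range(batch):
                rows[b].extend(chunk[b] if chunk is not None else [0.0] * width)
        return rows


def toy_output_irreps(x: ToyIrrepsArray, irreps_out: Irreps, force: bool = False) -> Irreps:
    # Drop tagged-zero chunks, then keep only requested output irreps whose type
    # appears in a remaining input chunk (an equivariant linear map cannot mix
    # different irreps). force=True keeps the requested irreps unconditionally.
    if force:
        return list(irreps_out)
    present = {ir for (_, ir), chunk in zip(x.irreps, x.chunks) if chunk is not None}
    return [(mul, ir) for mul, ir in irreps_out if ir in present]


x = ToyIrrepsArray.from_dense([(2, "0e"), (2, "1o")], [[1.0] * 8 for _ in range(4)])
x_ext = x.extend_with_zeros([(2, "0e"), (2, "1e"), (2, "1o")])
x_rebuilt = ToyIrrepsArray.from_dense(x_ext.irreps, x_ext.to_dense(batch=4))

out = [(6, "0e"), (2, "1e")]
print(toy_output_irreps(x_ext, out))              # [(6, '0e')]
print(toy_output_irreps(x_rebuilt, out))          # [(6, '0e'), (2, '1e')]
print(toy_output_irreps(x_ext, out, force=True))  # [(6, '0e'), (2, '1e')]
```

In this toy, the "1e" chunk of `x_ext` is tagged as zero, so "1e" is not considered reachable; after the dense round trip the same chunk holds explicit zeros and "1e" is kept. The outputs above describe the toy only, not what e3nn-jax prints.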
Later I realized that a better solution was to prevent `Linear` from outputting those zero chunks and…