JIT (just-in-time) compilation issue with radius_graph function #74
VinaySingh561 asked this question in Q&A
Thanks for the question. You need to pass a static size:

    import jax
    import e3nn_jax as e3nn

    key = jax.random.PRNGKey(0)
    pos = jax.random.normal(key, (20, 3))
    jax.jit(e3nn.radius_graph, static_argnames=["size"])(pos, r_max=0.8, batch=None, size=10)

But right now it seems like the …
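To illustrate why a static size makes jit work, here is a minimal sketch in plain JAX. It assumes a hypothetical helper, padded_radius_graph, which is NOT e3nn's implementation; the -1 padding convention and the same-graph masking for batch are also assumptions made for illustration. Because size is a static Python int, jnp.nonzero returns index arrays with a fixed shape that can be traced, and any unused slots have to be masked out afterwards.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper for illustration only; it is NOT e3nn's radius_graph.
# `size` is static, so jnp.nonzero knows its output shape at trace time.
@partial(jax.jit, static_argnames=["size"])
def padded_radius_graph(pos, r_max, size, batch=None):
    n = pos.shape[0]
    # Pairwise squared distances, shape (n, n).
    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = jnp.sum(diff**2, axis=-1)
    # Keep pairs closer than r_max and drop self-loops.
    mask = (dist2 < r_max**2) & ~jnp.eye(n, dtype=bool)
    if batch is not None:
        # Only connect nodes that belong to the same graph.
        mask = mask & (batch[:, None] == batch[None, :])
    # Exactly `size` index pairs come back: if there are fewer edges the rest
    # is filled with -1, if there are more the extra edges are dropped.
    src, dst = jnp.nonzero(mask, size=size, fill_value=-1)
    return src, dst


key = jax.random.PRNGKey(0)
pos = jax.random.normal(key, (20, 3))
batch = jnp.where(jnp.arange(20) < 10, 0, 1)  # two graphs of 10 nodes each
src, dst = padded_radius_graph(pos, 0.8, size=64, batch=batch)
valid = src >= 0  # padding slots have src == -1
print(src.shape, int(valid.sum()))
```

Choosing size is a trade-off: too small silently drops edges, because jnp.nonzero truncates, while too large wastes memory and compute. Each distinct size value also triggers a recompilation, so a generous fixed upper bound (or a few bucketed sizes) is usually preferable to passing the exact edge count on every call.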
Hi, I am trying to JIT-compile radius_graph for my research. Below is a minimal example that reproduces the problem:
    import jax
    import jax.numpy as jnp
    from e3nn_jax import radius_graph
    key = jax.random.PRNGKey(0)
    pos = jax.random.normal(key, (20, 3))
    batch = jnp.arange(20) < 10
    jax.jit(radius_graph)(pos, 0.8, batch=batch)
I get the following error:
    ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
    The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
    The error occurred while tracing the function radius_graph at /usr/local/lib/python3.10/dist-packages/e3nn_jax/_src/radius_graph.py:9 for jit. This concrete value was not available in Python because it depends on the values of the arguments pos and r_max.
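The message points at jnp.nonzero: its output length depends on the values in its input, so jit cannot know the output shape while tracing. A minimal pure-JAX sketch (no e3nn involved; neighbors_dynamic and neighbors_static are made-up names for illustration) that reproduces the same ConcretizationTypeError and shows how a static size avoids it:

```python
import jax
import jax.numpy as jnp


def neighbors_dynamic(x):
    # Output length depends on the *values* of x, so it is unknown at trace time.
    return jnp.nonzero(x > 0)[0]


def neighbors_static(x):
    # With a static size the output shape is fixed; missing entries become -1.
    return jnp.nonzero(x > 0, size=4, fill_value=-1)[0]


x = jnp.array([1.0, -2.0, 3.0, -4.0])

# Eagerly (without jit) the dynamic version works, since x is concrete.
print(neighbors_dynamic(x))

# Under jit, x is an abstract tracer and the dynamic version fails.
try:
    jax.jit(neighbors_dynamic)(x)
except jax.errors.ConcretizationTypeError as err:
    print("fails under jit:", type(err).__name__)

# The static-size version compiles: indices 0 and 2, then -1 padding.
print(jax.jit(neighbors_static)(x))
```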