Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 623, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax_plugins/xla_cuda12/__init__.py", line 83, in initialize
    xla_client.register_custom_call_handler(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jaxlib/xla_client.py", line 633, in register_custom_call_handler
    handler(name, fn, xla_platform_name, api_version)
TypeError: register_custom_call_target(): incompatible function arguments. The following argument types are supported:
    1. register_custom_call_target(c_api: capsule, fn_name: str, fn: capsule, xla_platform_name: str, api_version: int = 0) -> None
Invoked with types: PyCapsule, bytes, PyCapsule, str, int

Traceback (most recent call last):
  File "/home/mila/d/darshan.patil/research/NSRL/test.py", line 106, in <module>
    gym_sync_step()
  File "/home/mila/d/darshan.patil/research/NSRL/test.py", line 56, in gym_sync_step
    run_actor_loop(100, (handle, states))
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 305, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 182, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 2789, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1552, in _pjit_call_impl
    return xc._xla.pjit(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1534, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/pjit.py", line 1464, in _pjit_call_impl_python
    inline=inline, lowering_parameters=mlir.LoweringParameters()).compile()
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2378, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2914, in from_hlo
    xla_executable = _cached_compilation(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2726, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 333, in compile_or_get_cached
    return _compile_and_write_cache(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 504, in _compile_and_write_cache
    executable = backend_compile(
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/network/scratch/d/darshan.patil/envs/conda/temp/lib/python3.10/site-packages/jax/_src/compiler.py", line 238, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNIMPLEMENTED: No registered implementation for custom call to AtariGymEnvPool_140505772519360_send_gpu for platform CUDA
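The first traceback fails inside jax_plugins.xla_cuda12.initialize(): jaxlib's register_custom_call_target binding accepts fn_name only as str, but the plugin passes bytes. One plausible, unconfirmed explanation is that the CUDA plugin and jaxlib come from different releases. The standard-library sketch below reports the installed versions and flags any jax-cuda12-* package whose version differs from jaxlib's. The distribution names and the assumption that these packages should share jaxlib's version number are assumptions, not something this report establishes.

```python
"""Report installed JAX-related package versions to spot a release mismatch.

Assumption: the jax-cuda12-* distributions are expected to carry the same
version number as jaxlib. The distribution names below are also assumptions.
"""
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Hypothetical list of distributions worth reporting for this bug.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt", "envpool")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def plugin_mismatches(versions: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Return installed jax-cuda12-* packages whose version differs from jaxlib's."""
    jaxlib_version = versions.get("jaxlib")
    if jaxlib_version is None:
        return {}  # Nothing to compare against.
    return {
        name: v
        for name, v in versions.items()
        if name.startswith("jax-cuda12") and v is not None and v != jaxlib_version
    }


if __name__ == "__main__":
    versions = installed_versions()
    for name, v in versions.items():
        print(f"{name}: {v or 'not installed'}")
    mismatched = plugin_mismatches(versions)
    if mismatched:
        print(f"Differs from jaxlib {versions['jaxlib']}: {mismatched}")
```

If a mismatch is reported, reinstalling the JAX packages so they come from a single release would be the first thing to try before looking at EnvPool itself; whether that also resolves the second (custom call) error is not confirmed here.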
Describe the bug
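Using the XLA interface crashes the program. JAX first fails to initialize its CUDA plugin (jax_plugins.xla_cuda12), and compiling the jitted actor loop then fails because no implementation of the EnvPool send_gpu custom call is registered for the CUDA platform.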
To Reproduce
Run the example code in envpool/examples/xla_step.py as
This results in the following error:
Expected behavior
The example code to run without issues
System info
This was the extent of my setup:
Above code prints:
Reason and Possible fixes
N/A