[BUG] JAX JIT errors out when using a tracer-based PRNG key #6054
Comments
@trbromley do you have more details on this?
This PR, which adds a `seed` kwarg to `qjit`, might be relevant: PennyLaneAI/catalyst#936
When I run this:

```python
import jax
import pennylane as qml

def circuit(key, param):
    dev = qml.device("lightning.qubit", wires=2, shots=10, seed=key)

    @qml.qnode(dev, interface="jax")
    def my_circuit():
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))

    return my_circuit()

key = jax.random.PRNGKey(1967)
```

I get this error:

```pycon
>>> qml.qjit(circuit)(key, 0.5)
TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=1/0)>
```
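The `TypeError` comes from NumPy rather than from PennyLane itself: under tracing, `key` is an abstract `uint32[2]` tracer with no concrete value, and NumPy's `SeedSequence` only accepts concrete ints or sequences of ints. Below is a minimal sketch that reproduces the same error without PennyLane, assuming a simulator that seeds a NumPy RNG from its `seed` argument. The helper `numpy_samples` is hypothetical and only mimics that behaviour; it is not PennyLane or Lightning code.

```python
import jax
import numpy as np


def numpy_samples(seed, shots=10):
    """Hypothetical helper: draw PauliZ-like samples from a NumPy RNG seeded
    with `seed`, the way a NumPy-backed simulator might."""
    rng = np.random.default_rng(seed)  # builds a SeedSequence from `seed`
    # Map {0, 1} draws to PauliZ eigenvalues {+1, -1}
    return 1 - 2 * rng.integers(0, 2, size=shots)


key = jax.random.PRNGKey(1967)  # concrete uint32[2] array outside of jit

# Eagerly, the key can be converted to plain Python ints, which NumPy accepts.
eager = numpy_samples([int(k) for k in key])


@jax.jit
def traced_samples(key):
    # Inside jit, `key` is a Tracer, so SeedSequence raises a TypeError.
    return numpy_samples(key)


raised, message = False, ""
try:
    traced_samples(key)
except TypeError as err:
    raised, message = True, str(err)
```

The same key works eagerly and fails only once it becomes a tracer, which suggests the fix has to happen where the device consumes `seed`, not in how the key is created.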
Oh interesting, so the seed should live with
The PR was motivated by a technical problem: we have tests in the Catalyst codebase that are flaky, but no `qjit`-compatible way of specifying a seed. So passing it through the
Ah yes, I guess this is to be expected, since JAX seeds are not supported by MLIR/LLVM.
I'm reviving this issue. The latest version of the JAX interface docs (Randomness section) shows an example using See the example here:
However when you remove Full traceback below:
This error still occurs when the measurement is The error can be avoided by using a fixed seed within the jitted circuit or by commenting out all of the gates that take parameters, which is of course not ideal. The error was brought up in this forum thread.
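The fixed-seed workaround works because a concrete Python int never becomes a tracer. A more general form of the same idea in plain JAX is to mark an integer seed as static with `jax.jit(..., static_argnums=...)`, so it stays a concrete int while `param` is still traced. This is a minimal sketch under that assumption; `sample_with_static_seed` is a hypothetical stand-in for a NumPy-seeded sampler, not PennyLane's implementation, and it is not necessarily the approach referred to in the next comment.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=0)
def sample_with_static_seed(seed, param):
    # `seed` is static, so it is a plain Python int at trace time and
    # NumPy's SeedSequence accepts it. `param` remains a tracer.
    rng = np.random.default_rng(seed)
    signs = 1 - 2 * rng.integers(0, 2, size=10)  # drawn once, during tracing
    return param * jnp.asarray(signs)


samples = sample_with_static_seed(1967, 0.5)
```

One design consequence to keep in mind: the NumPy draw runs while JAX traces the function, so the samples are baked into the compiled program as constants. Repeated calls with the same seed reuse the cached compilation and return identical samples, and every new seed value triggers a retrace.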
There's a way to pass an integer key as an argument by using
Expected behavior
It should be possible to JAX-JIT a function where `qml.device` has `seed` set to a tracer for the PRNG key.
Actual behavior
An error is raised in many cases. I'm not exactly sure when, but it appears to happen when other tracers are active, such as the `param` variable below.
There is also an error when using `qml.qjit` and `lightning.qubit`, which doesn't expect a JAX tracer for the `seed` argument.
Additional information
Although there is a draft PR to fix this problem, it has gone stale, and I decided it would be more useful to track the problem in this issue.
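For comparison with the expected behavior, JAX's own `jax.random` functions accept a traced key alongside other tracers such as `param`, because they never need a concrete seed. The sketch below is a toy, JAX-only stand-in for `my_circuit` (it is not how PennyLane samples): after `RX(param)` on |0>, the probability of measuring |1> on wire 0 is sin²(param/2), and the `CNOT` leaves that marginal unchanged. `jax_native_samples` and `SHOTS` are hypothetical names.

```python
import jax
import jax.numpy as jnp

SHOTS = 10  # mirrors shots=10 in the reproducer


@jax.jit
def jax_native_samples(key, param):
    # Probability of |1> on wire 0 after RX(param) on |0>
    p1 = jnp.sin(param / 2) ** 2
    # Both `key` and `p1` are tracers here; jax.random handles that fine.
    ones = jax.random.bernoulli(key, p1, shape=(SHOTS,))
    # Map outcomes to PauliZ eigenvalues: |0> -> +1, |1> -> -1
    return 1 - 2 * ones.astype(jnp.int32)


key = jax.random.PRNGKey(1967)
samples = jax_native_samples(key, 0.5)
```

Unlike the static-seed workaround, the randomness here is computed inside the compiled program, so passing a new key does not trigger recompilation.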
Source code
Tracebacks
System information
Existing GitHub issues