[BUG] Qiskit devices don't support vmap parameter broadcasting #5240
Comments
Thanks for opening this issue, @lauracappelli. We'll try to get a fix in for the next release, coming out on March 5th. The problematic line of code is this:
This works for … In the short term, we can just update the above line to be:
In the longer term, though, we should rethink …
@albi3ro, would this bug also impact `lightning.qubit`?
@josh146 Yes, it does. Minimal non-working example:
…ng (#5286)

Fixes #5240 [sc-57137] [sc-57848]
Fixes #5289

Basically, when we set `vectorized=True` inside the `pure_callback` call, we assumed that the device natively supports parameter broadcasting. We then only tested with devices that did indeed natively support parameter broadcasting.

This problem was made worse by the fact that our `vmap` tests included a `Hamiltonian` expectation value, which caused us to skip many of the test cases we really should have been running. So I removed the `Hamiltonian` from the test so that we could actually test more situations. I also added more `lightning.qubit` tests to the test configuration, which forced one or two other changes.

The major problem with `jax.vmap` is that it adds a parameter-broadcasting dimension *after* we have already handled all of our preprocessing and split up any parameter broadcasting.

Co-authored-by: Josh Izaac
Co-authored-by: Nathan Killoran
Co-authored-by: Matthew Silverman
Co-authored-by: Astral Cai
Co-authored-by: Mikhail Andrenkov
Co-authored-by: Korbinian Kottmann
Co-authored-by: Thomas R. Bromley
Co-authored-by: Isaac De Vlugt
Co-authored-by: Pietropaolo Frisoni
Co-authored-by: soranjh
Co-authored-by: Mudit Pandey
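The distinction the fix turns on can be sketched in plain Python, without JAX, PennyLane, or Qiskit. This is a minimal sketch under stated assumptions: `broadcasting_device`, `scalar_only_device`, and `run_batch` are hypothetical stand-ins, not PennyLane or Qiskit APIs. The `vectorized=True` path hands the whole batch to the device in one call, roughly what `jax.pure_callback` does under `jax.vmap` when `vectorized=True`. The `vectorized=False` path calls the device once per batch element. The scalar-only toy device raises an error modelled on the one reported below.

```python
import math
from typing import Callable, List, Sequence, Union

# A parameter is either a single angle or a batch (list) of angles.
Param = Union[float, List[float]]


def broadcasting_device(theta: Param) -> Union[float, List[float]]:
    """Hypothetical toy device that natively supports parameter broadcasting.

    Returns <Z> after RY(theta) on |0>, i.e. cos(theta), either for a single
    angle or for every angle in a batch.
    """
    if isinstance(theta, list):
        return [math.cos(t) for t in theta]
    return math.cos(theta)


def scalar_only_device(theta: Param) -> float:
    """Hypothetical toy device that accepts only one scalar angle per gate.

    The error text imitates the CircuitError in this issue; it is not Qiskit code.
    """
    if not isinstance(theta, (int, float)):
        raise TypeError(f"Invalid param type {type(theta)} for gate ry.")
    return math.cos(theta)


def run_batch(
    device: Callable[[Param], Union[float, List[float]]],
    batch: Sequence[float],
    vectorized: bool,
) -> List[float]:
    """Evaluate `device` on every angle in `batch`.

    vectorized=True: one call with the whole batch (device must broadcast).
    vectorized=False: one call per angle, results collected in order.
    """
    if vectorized:
        return list(device(list(batch)))
    return [device(theta) for theta in batch]


if __name__ == "__main__":
    angles = [0.0, math.pi / 2, math.pi]
    # Both of these succeed: the first device broadcasts, the second is looped over.
    print(run_batch(broadcasting_device, angles, vectorized=True))
    print(run_batch(scalar_only_device, angles, vectorized=False))
    # Assuming broadcasting support for a device that lacks it fails.
    try:
        run_batch(scalar_only_device, angles, vectorized=True)
    except TypeError as err:
        print(err)  # Invalid param type <class 'list'> for gate ry.
```

The sequential path costs one device execution per batch element, but it is the safe choice whenever the device cannot broadcast natively.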
Expected behavior
I'm trying to use the Qiskit plugin in a neural network defined in JAX. One of the network layers is a quantum circuit written in PennyLane and called with `jax.vmap`. I have posted a simplified version of my code (useful for reproducibility) in the Xanadu Discussion Forum. I expect the circuit to be called for each element of the input, and the result to contain the values obtained from all the calls.
Actual behavior
The initialization of the training state fails with the error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: CircuitError: "Invalid param type <class 'list'> for gate ry."
Additional information
As Christina wrote in the forum thread mentioned above, your handling of `vmap` currently assumes that the device natively supports parameter broadcasting, which is only true for a limited subset of devices.
A more minimal example of the problem is:
Source code
Tracebacks
System information
Existing GitHub issues