Support for qml.StatePrep on lightning.qubit with Catalyst #1160
Hi @joeybarreto, unfortunately this is a known bug in Catalyst at the moment; see #1065 for more details. In short, this is because without … Instead, when … We are working to fix this; for now, here is a workaround:

```python
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit, grad

Nq = 2
init_state = np.array([1, 0, 0, 0])
dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles, init_state):
    # Apply the decomposition directly instead of the state-prep operation itself
    qml.BasisState.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles, init_state)
```

```pycon
>>> angles = jnp.array([0.1, 0.2])
>>> jax.grad(qnode_test)(angles, init_state=init_state)
Array([0.09983342, 0. ], dtype=float64)
>>> jgrad_qjit(angles)
Array([ 9.98334166e-02, -5.55111512e-17], dtype=float64)
```

Note that I have swapped …
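As a rough sanity check on the printed gradients, this 2-qubit circuit is small enough to simulate by hand. The sketch below uses plain NumPy (no PennyLane or Catalyst), assumes wire 0 is the most significant qubit, and computes the gradient of ⟨Z₀⟩ with the parameter-shift rule for a chosen computational-basis input; the helper names are hypothetical. For RY(θ₀) on wire 0, ⟨Z₀⟩ is cos θ₀ if wire 0 starts in |0⟩ and −cos θ₀ if it starts in |1⟩, so the first gradient entry is ∓sin θ₀, while the second entry is 0 because the rotation on wire 1 does not affect ⟨Z₀⟩.

```python
import numpy as np

def ry(theta):
    # Matrix of the RY(theta) rotation
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def expval_z0(angles, bits):
    # Start in the basis state |bits[0] bits[1]>, wire 0 = most significant qubit
    state = np.zeros(4)
    state[2 * bits[0] + bits[1]] = 1.0
    state = np.kron(ry(angles[0]), ry(angles[1])) @ state
    z0 = np.kron(np.diag([1.0, -1.0]), np.eye(2))
    return state @ z0 @ state  # amplitudes are real here

def param_shift_grad(angles, bits):
    # Parameter-shift rule for Pauli rotations: (f(t + pi/2) - f(t - pi/2)) / 2
    grad = np.zeros(len(angles))
    for i in range(len(angles)):
        shift = np.zeros(len(angles))
        shift[i] = np.pi / 2
        grad[i] = (expval_z0(angles + shift, bits) - expval_z0(angles - shift, bits)) / 2
    return grad

angles = np.array([0.1, 0.2])
print(param_shift_grad(angles, bits=(0, 0)))  # approximately [-sin(0.1), 0]
print(param_shift_grad(angles, bits=(1, 0)))  # approximately [ sin(0.1), 0]
```

With θ₀ = 0.1, sin(0.1) ≈ 0.0998, so the positive first entry printed above is consistent with wire 0 starting in |1⟩ rather than |0⟩.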
Thanks, yeah, I had seen #1065 but wasn't 100% sure whether this was the same issue. In reality my initial states are pretty complicated, not just something like … I do have a follow-up question, though, regarding the execution time of the two compiled functions above (whether or not …):

```python
import numpy as np
import jax
import pennylane as qml
from catalyst import qjit, grad

Nq = 18

def test(angles):
    # 5 layers: RY on every wire, then CNOTs on the pairs (0, 1), (2, 3), ...
    for kk in range(5):
        for ii in range(Nq):
            qml.RY(angles[ii], wires=ii)
        for ii in range(0, Nq, 2):
            qml.CNOT(wires=[ii % Nq, (ii + 1) % Nq])
    return qml.expval(qml.PauliZ(0))

qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='best')

# Gradient compiled with Catalyst
@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test)
    return g(angles)

# Gradient compiled with JAX
jgrad = jax.jit(jax.grad(qnode_test))

angles = np.array([0.1] * Nq)
```

After compilation, one call to …
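To compare `jgrad` and `jgrad_qjit` fairly, it helps to time the first call (which includes tracing and compilation for both `jax.jit` and `qjit`) separately from later calls. A minimal sketch follows; the helper `time_calls` is hypothetical, and it assumes the functions return JAX arrays, whose `block_until_ready()` method waits for asynchronous dispatch to finish before the clock is read.

```python
import time

def time_calls(fn, *args, repeats=10):
    """Return (first_call_seconds, mean_later_call_seconds) for fn(*args).

    The first call includes tracing/compilation; later calls reuse the compiled code.
    """
    def run():
        out = fn(*args)
        # JAX arrays are computed asynchronously; wait for the result when possible
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    start = time.perf_counter()
    run()
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        run()
    later = (time.perf_counter() - start) / repeats
    return first, later
```

With the definitions above, `time_calls(jgrad, angles)` and `time_calls(jgrad_qjit, angles)` would give compile-inclusive and steady-state timings for each approach.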
Ah, I see! You could try this approach:

```python
import numpy as np
import pennylane as qml
from catalyst import qjit, grad

Nq = 2
init_state = np.array([1, 0, 0, 0])
dev = qml.device('lightning.qubit', wires=Nq)

@qml.qnode(dev)
def qnode_test(angles):
    # init_state is a NumPy constant captured from the enclosing scope, not a QNode argument
    qml.StatePrep.compute_decomposition(init_state, wires=range(Nq))
    qml.RY(angles[0], wires=0)
    qml.RY(angles[1], wires=1)
    return qml.expval(qml.PauliZ(0))

@qjit
def jgrad_qjit(angles):
    g = grad(qnode_test, argnums=0)
    return g(angles)
```

Note a couple of things here: …

Both of these are important, as they mean that the state decomposition happens in Python: JAX will not try to compile it, which we have noticed is significantly costly.
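To make "the decomposition happens in Python" concrete, here is a minimal sketch, assuming a real-valued, normalized 2-qubit state, of computing rotation angles eagerly with NumPy so that only plain floats would reach a traced circuit. The helper names are hypothetical and this is not PennyLane's implementation (`qml.StatePrep` decomposes via `qml.MottonenStatePreparation`, which also handles complex amplitudes). The circuit is an RY on wire 0, followed by an RY on wire 1 whose angle depends on the value of wire 0; `simulate` checks the angles with a small NumPy statevector simulation.

```python
import numpy as np

def ry(theta):
    # Matrix of the RY(theta) rotation
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def two_qubit_ry_angles(psi):
    # psi = [a00, a01, a10, a11], real and normalized; wire 0 is the most significant qubit
    a00, a01, a10, a11 = psi
    n0 = np.hypot(a00, a01)  # norm of the wire-0 = |0> branch
    n1 = np.hypot(a10, a11)  # norm of the wire-0 = |1> branch
    theta0 = 2 * np.arctan2(n1, n0)       # RY on wire 0
    theta_if0 = 2 * np.arctan2(a01, a00)  # RY on wire 1 when wire 0 is |0>
    theta_if1 = 2 * np.arctan2(a11, a10)  # RY on wire 1 when wire 0 is |1>
    return theta0, theta_if0, theta_if1

def simulate(theta0, theta_if0, theta_if1):
    # Apply RY(theta0) on wire 0, then an RY on wire 1 conditioned on wire 0
    p0, p1 = np.diag([1.0, 0.0]), np.diag([0.0, 1.0])
    state = np.zeros(4)
    state[0] = 1.0
    state = np.kron(ry(theta0), np.eye(2)) @ state
    controlled = np.kron(p0, ry(theta_if0)) + np.kron(p1, ry(theta_if1))
    return controlled @ state

psi = np.array([0.5, -0.5, 0.5, 0.5])
angles = two_qubit_ry_angles(psi)  # plain floats, computed before any tracing
print(np.allclose(simulate(*angles), psi))
```

In a QNode, the wire-1 rotation conditioned on wire 1 being |1⟩ maps onto `qml.CRY`, and the one conditioned on |0⟩ can be obtained by sandwiching a `qml.CRY` between `qml.PauliX` gates on wire 0.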
Thank you for the suggestion. I just tried that on my minimal 14-qubit example (which involves a lot of project code not shown above), and after ~5 minutes the compilation of the gradient function has yet to finish. (Also, I make sure to pass in initial states only as NumPy arrays and to supply them via partial application before creating my QNodes, so I don't expect that JAX is tracing them.) However, the benchmarking plot I shared above suggests a deeper issue. Note that in that example I do not use any initial state prep; I am just comparing two different ways of jitting the gradient of the QNode. I find that Catalyst (via …
Hi @joeybarreto, thanks for sharing that benchmark! I think the difference comes from an unfortunate default value for the …
Thanks for pointing that out, @dime10! Where is the default value specified? It seems like …

```python
qnode_test = qml.QNode(test,
                       qml.device('lightning.qubit', wires=Nq),
                       interface='jax',
                       diff_method='best')
```
Sorry, "default" may not have been the best term; what I meant is the value determined to be used when … (see `catalyst/frontend/catalyst/jax_primitives.py`, line 573 in 934726f).

We'll be sure to update this soon for better default performance.
In the code below, I create a `test` circuit and then compute its gradient in two ways. The first way just applies `jax.jit(jax.grad(...))` to a partial application of the QNode after an initial state has been supplied. The second uses the Catalyst `@qjit` decorator and `grad` function instead. `jgrad` succeeds, but `jgrad_qjit` fails with the error `DifferentiableCompileError: StatePrep is non-differentiable on 'lightning.qubit' device`. I'm not sure whether this is a bug or expected behavior, but if `qml.StatePrep` works on the lightning backend without Catalyst, I'm not sure why it would fail here. How hard would it be to add support for arbitrary state prep when using Catalyst?

The full error message is …

qml.about(): …
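The partial-application step described above can be sketched with `functools.partial`: binding `init_state` as a keyword argument before the function is handed to a gradient transform leaves a function of `angles` alone, so the state stays a concrete NumPy array rather than a traced input. In this sketch, a hypothetical NumPy stand-in `circuit_cost` replaces the QNode (state-vector preparation, RY on each wire, then ⟨Z₀⟩), and a central finite difference stands in for `jax.grad` / Catalyst's `grad`; neither is the issue's actual code.

```python
import functools
import numpy as np

def ry(theta):
    # Matrix of the RY(theta) rotation
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])

def circuit_cost(angles, init_state):
    # Hypothetical stand-in for the QNode: start from init_state, apply RY to each wire, return <Z0>
    state = np.kron(ry(angles[0]), ry(angles[1])) @ init_state
    z0 = np.kron(np.diag([1.0, -1.0]), np.eye(2))
    return state @ z0 @ state  # amplitudes are real here

def finite_diff_grad(f, x, eps=1e-6):
    # Central differences; stands in for a real gradient transform in this sketch
    g = np.zeros_like(x)
    for i in range(len(x)):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

# Bind the initial state up front; cost_of_angles now depends on angles only
cost_of_angles = functools.partial(circuit_cost, init_state=np.array([1.0, 0.0, 0.0, 0.0]))
print(finite_diff_grad(cost_of_angles, np.array([0.1, 0.2])))  # approximately [-sin(0.1), 0]
```

Here the state vector `[1, 0, 0, 0]` is |00⟩, so ⟨Z₀⟩ = cos θ₀ and the expected first gradient entry is −sin(0.1).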