jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "C:\code\vqeac\mve.py", line 38, in <module>
gradient = jax.grad(circuit_2)(h2.hf_state, theta)
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
The issue is that BasisState is fundamentally a discrete operation. Each input can only be zero or one, corresponding to applying an X gate on that wire or not. A parameter that only takes discrete values has no derivative, so it cannot be differentiated.
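If you want to differentiate with respect to only the differentiable theta, pass argnums=1 to jax.grad. By default jax.grad differentiates with respect to the first positional argument, which here is the integer-valued hf_state.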
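A minimal sketch of restricting jax.grad to theta. Because circuit_2 from the issue is not shown, it uses a hypothetical pure-JAX stand-in, toy_circuit, so it runs without PennyLane. It assumes a single-qubit picture: prepare the basis state |b>, apply RY(theta) on wire 0 and measure Z on wire 0, which gives <Z_0> = (1 - 2 b_0) cos(theta). The hf_state value [1, 1, 0, 0] is an assumed Hartree-Fock occupation for H2 (2 electrons, 4 qubits).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical analytic stand-in for circuit_2 (not PennyLane code):
# basis state |b>, then RY(theta) on wire 0, then <Z_0> = (1 - 2*b_0) * cos(theta).
def toy_circuit(basis_state, theta):
    sign = 1 - 2 * basis_state[0]   # +1 for |0>, -1 for |1> (integer-valued)
    return sign * jnp.cos(theta)    # promoted to a float scalar

hf_state = np.array([1, 1, 0, 0])   # assumed HF occupation for H2, integer dtype
theta = 0.5

# jax.grad differentiates w.r.t. argument 0 by default: the integer state -> TypeError.
try:
    jax.grad(toy_circuit)(hf_state, theta)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Differentiate only w.r.t. theta (positional argument 1); hf_state is passed through.
grad_theta = jax.grad(toy_circuit, argnums=1)(hf_state, theta)
print(grad_theta)  # equals sin(theta) here, since hf_state[0] == 1
```

The integer state never has to be differentiable: with argnums=1 it is simply carried along as a constant input.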
You could also convert the hf_state to a full state vector and use StatePrep or MottonenStatePrep instead. While these are not always fully differentiable, you should have better luck with them than with BasisState.
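A minimal sketch of the conversion step, using a hypothetical helper basis_to_statevector. It assumes PennyLane's ordering, where wire 0 is the most significant bit, so [1, 1, 0, 0] maps to |1100>, which is index 12. It only builds the float-valued amplitude vector and does not call PennyLane itself.

```python
import numpy as np

def basis_to_statevector(bits):
    """Return the 2**n amplitudes of the computational basis state |bits>.

    Assumes wire 0 is the most significant bit (hypothetical helper)."""
    bits = np.asarray(bits, dtype=int)
    index = int("".join(str(b) for b in bits), 2)  # bitstring -> integer index
    state = np.zeros(2 ** bits.size, dtype=float)   # float-valued, unlike hf_state
    state[index] = 1.0
    return state

hf_state = np.array([1, 1, 0, 0])  # assumed HF occupation for H2 (2 electrons, 4 qubits)
vec = basis_to_statevector(hf_state)
print(np.nonzero(vec)[0])  # → [12]
```

In recent PennyLane versions the resulting vector could then be passed to, for example, qml.StatePrep(vec, wires=range(4)) or qml.MottonenStatePrep(vec, wires=range(4)) in place of qml.BasisState; as noted above, differentiability still depends on the template.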
Expected behavior
The two snippets below should behave identically.
Actual behavior
Passing a state param causes an error.
Additional information
Commenting/uncommenting the 3rd or 4th line from the end of the source code should reproduce this issue.
Source code
Tracebacks
System information
Existing GitHub issues