[BUG] Adding a hf_state parameter to optimize causes JAX error #5482

Closed
1 task done
minhtriet opened this issue Apr 8, 2024 · 2 comments
Labels
bug 🐛 Something isn't working

Comments

@minhtriet
Contributor

minhtriet commented Apr 8, 2024

Expected behavior

The two snippets below should behave identically:

@qml.qnode(dev)
def circuit(param):
    qml.BasisState(hf, wires=range(qubits))
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
    return qml.expval(H)

@qml.qnode(dev)
def circuit_2(state, param):
    qml.BasisState(state, wires=range(qubits))
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
    return qml.expval(H)

Actual behavior

Passing the state as a QNode argument causes an error.

Additional information

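To reproduce, switch between the two gradient lines inside the loop at the end of the source code (enable exactly one of them). The circuit version runs; the circuit_2 version raises the error shown under Tracebacks.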

Source code

import jax
from jax import numpy as np
import pennylane as qml
import optax

jax.config.update("jax_platform_name", "cpu")
jax.config.update('jax_enable_x64', True)

h2_dataset = qml.data.load("qchem", molname="H2", bondlength=0.742, basis="STO-3G")
h2 = h2_dataset[0]
H, qubits = h2.hamiltonian, len(h2.hamiltonian.wires)


dev = qml.device("lightning.qubit", wires=qubits)

hf = h2.hf_state

@qml.qnode(dev)
def circuit(param):
    qml.BasisState(hf, wires=range(qubits))
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
    return qml.expval(H)

@qml.qnode(dev)
def circuit_2(state, param):
    qml.BasisState(state, wires=range(qubits))
    qml.DoubleExcitation(param, wires=[0, 1, 2, 3])
    return qml.expval(H)

max_iterations = 100
conv_tol = 1e-06
opt = optax.sgd(learning_rate=0.4)
theta = np.array(0.)
angle = [theta]
opt_state = opt.init(theta)

for n in range(max_iterations):
    # gradient = jax.grad(circuit)(theta)  # works; enable either this line or the one below, not both
    gradient = jax.grad(circuit_2)(h2.hf_state, theta)
    updates, opt_state = opt.update(gradient, opt_state)
    theta = optax.apply_updates(theta, updates)

Tracebacks

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "C:\code\vqeac\mve.py", line 38, in <module>
    gradient = jax.grad(circuit_2)(h2.hf_state, theta)
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

System information

Name: PennyLane
Version: 0.35.1
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: c:\code\vqeac\.venv\lib\site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Windows-10-10.0.19045-SP0
Python version:          3.10.4
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- default.clifford (PennyLane-0.35.1)      
- default.gaussian (PennyLane-0.35.1)      
- default.mixed (PennyLane-0.35.1)
- default.qubit (PennyLane-0.35.1)
- default.qubit.autograd (PennyLane-0.35.1)
- default.qubit.jax (PennyLane-0.35.1)
- default.qubit.legacy (PennyLane-0.35.1)
- default.qubit.tf (PennyLane-0.35.1)
- default.qubit.torch (PennyLane-0.35.1)
- default.qutrit (PennyLane-0.35.1)
- null.qubit (PennyLane-0.35.1)
- lightning.qubit (PennyLane_Lightning-0.35.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.
minhtriet added the bug 🐛 (Something isn't working) label on Apr 8, 2024
minhtriet changed the title from "[BUG] Adding" to "[BUG] Adding a hf_state parameter to optimize causes JAX error" on Apr 8, 2024
@albi3ro
Contributor

albi3ro commented Apr 8, 2024

Thanks for reaching out, @minhtriet.

The issue is that BasisState is fundamentally a discrete operation. Each input is either zero or one, which determines whether an X gate is applied to that wire. We fundamentally cannot differentiate a parameter that is not continuous.

If you want to differentiate with respect to only the differentiable theta, you can do:

gradient = jax.grad(circuit_2, argnums=1)(h2.hf_state, theta)
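
For context, jax.grad differentiates with respect to the first positional argument by default (argnums=0), which in circuit_2 is the integer-valued state. The sketch below reproduces the same behaviour without PennyLane. The function toy_energy and its inputs are hypothetical stand-ins for circuit_2, h2.hf_state and theta, not part of the original report.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def toy_energy(state, param):
    # Hypothetical stand-in for circuit_2: `state` is integer-valued,
    # `param` is a continuous float.
    return jnp.sum(state) * jnp.cos(param)


state = jnp.array([1, 1, 0, 0])  # integer dtype, like the Hartree-Fock state
theta = jnp.array(0.3)

# Default argnums=0 tries to differentiate the integer array and fails.
try:
    jax.grad(toy_energy)(state, theta)
    default_failed = False
except TypeError:
    default_failed = True

# argnums=1 differentiates only with respect to the float parameter.
grad_theta = jax.grad(toy_energy, argnums=1)(state, theta)
```

Here default_failed ends up True, while grad_theta is an ordinary floating-point gradient that can be passed to the optimizer.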

You could also convert the hf_state to a full state vector and use StatePrep/MottonenStatePrep instead. While not always fully differentiable, you should have better luck with it than with BasisState.
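
A minimal sketch of that conversion, using plain NumPy. The helper basis_state_to_vector is a hypothetical name, and the bit string [1, 1, 0, 0] is assumed to be the Hartree-Fock state for this H2 example. The sketch also assumes the convention in which the first wire is the most significant bit.

```python
import numpy as np


def basis_state_to_vector(bits):
    """Return the 2**n state vector of a computational basis state.

    Hypothetical helper; assumes the first wire is the most significant bit.
    """
    bits = [int(b) for b in bits]
    # Read the bit string as a binary number to find the nonzero amplitude.
    index = int("".join(str(b) for b in bits), 2)
    vector = np.zeros(2 ** len(bits), dtype=complex)
    vector[index] = 1.0
    return vector


# Assumed Hartree-Fock occupation for the 4-qubit H2 example.
hf_vector = basis_state_to_vector([1, 1, 0, 0])
```

The resulting vector could then be given to qml.StatePrep(hf_vector, wires=range(qubits)) inside the QNode in place of qml.BasisState.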

@minhtriet
Contributor Author

thank you @albi3ro :)

minhtriet closed this as not planned on Apr 9, 2024