Add a how-to for catalyst-compiling "Symmetry-invariant quantum machine learning force fields" #1222

Open · wants to merge 15 commits into base: master
2 changes: 1 addition & 1 deletion demonstrations/tutorial_eqnn_force_field.metadata.json
@@ -9,7 +9,7 @@
}
],
"dateOfPublication": "2024-03-12T00:00:00+00:00",
"dateOfLastModification": "2024-11-06T00:00:00+00:00",
"dateOfLastModification": "2024-11-25T00:00:00+00:00",
"categories": [
"Quantum Machine Learning",
"Quantum Chemistry"
39 changes: 25 additions & 14 deletions demonstrations/tutorial_eqnn_force_field.py
@@ -143,6 +143,10 @@
import matplotlib.pyplot as plt
import sklearn

######################################################################
# To speed up the computation, we also import `Catalyst <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`_, a just-in-time (JIT) compiler for PennyLane quantum programs.
import catalyst
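
As a rough, hypothetical illustration of what ``catalyst.qjit`` does (this snippet is not part of the diff): wrapping a QNode compiles it on its first call, and subsequent calls reuse the compiled program.

# Hedged sketch: a toy two-qubit circuit compiled with catalyst.qjit (hypothetical, not from the tutorial)
import pennylane as qml
from jax import numpy as jnp
import catalyst

toy_dev = qml.device("lightning.qubit", wires=2)

@catalyst.qjit
@qml.qnode(toy_dev)
def toy_circuit(theta):
    qml.RX(theta, wires=0)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

print(toy_circuit(jnp.pi / 4))  # first call triggers compilation; later calls run the compiled program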

######################################################################
# Let us construct Pauli matrices, which are used to build the Hamiltonian.
X = np.array([[0, 1], [1, 0]])
@@ -301,10 +305,12 @@ def noise_layer(epsilon, wires):
#################################


dev = qml.device("default.qubit", wires=num_qubits)
######################################################################
# We will run our program on `lightning.qubit`, our performant state-vector simulator.
dev = qml.device("lightning.qubit", wires=num_qubits)


@qml.qnode(dev, interface="jax")
@qml.qnode(dev)
def vqlm(data, params):

weights = params["params"]["weights"]
@@ -396,25 +402,27 @@ def vqlm(data, params):
)

#################################
# We will know define the cost function and how to train the model using Jax. We will use the mean-square-error loss function.
# To speed up the computation, we use the decorator ``@jax.jit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the
# benefit that all following executions will be significantly faster, see the `Jax docs on jitting <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_.
# We will now define the cost function and how to train the model using Jax. We will use the mean-square-error loss function.
# We use the decorator ``@catalyst.qjit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the
# benefit that all following executions will be significantly faster (see the `Catalyst documentation <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`_).
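
As a small, hedged illustration of that claim (the circuit and timing code below are illustrative only, not part of the demo), one can time the first call of a qjit-compiled QNode against a subsequent call:

# Hedged sketch: compare compilation-plus-execution with re-execution of a qjit-compiled circuit
import time
import catalyst
import pennylane as qml

dev_timing = qml.device("lightning.qubit", wires=1)

@catalyst.qjit
@qml.qnode(dev_timing)
def tiny_circuit(theta):
    qml.RY(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

t0 = time.perf_counter()
tiny_circuit(0.1)  # first call: trace, compile, then execute
t1 = time.perf_counter()
tiny_circuit(0.2)  # later calls: the compiled program is reused
t2 = time.perf_counter()
print(f"first call: {t1 - t0:.3f} s, second call: {t2 - t1:.3f} s")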

#################################
from jax.example_libraries import optimizers
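
The names ``get_params`` and ``opt_update`` used in ``train_step`` below come from this optimizer module. A minimal sketch of the pattern follows; the demo's actual optimizer and step size are not shown in this diff, so ``adam`` and ``1e-2`` here are assumptions:

# Hedged sketch of the jax.example_libraries.optimizers triple (assumed optimizer and step size)
from jax.example_libraries import optimizers
from jax import numpy as jnp

opt_init, opt_update, get_params = optimizers.adam(1e-2)

initial_params = {"w": jnp.zeros(3)}
opt_state = opt_init(initial_params)         # wrap the parameters into an optimizer state
params = get_params(opt_state)               # unwrap the parameters to evaluate the cost
grads = {"w": jnp.ones(3)}                   # placeholder gradients
opt_state = opt_update(0, grads, opt_state)  # apply one update at step index 0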

# We vectorize the model over the data points
vec_vqlm = jax.vmap(vqlm, (0, None), 0)
vec_vqlm = catalyst.vmap(
vqlm,
in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
out_axes=0,
)
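
To illustrate the ``in_axes`` pattern above on something smaller (a toy function with hypothetical names, not the demo's VQLM): axis ``0`` marks the argument that is batched, while ``None`` entries mark parameters that are shared across the whole batch.

# Hedged sketch: batching a toy model with catalyst.vmap inside qjit
import catalyst
from jax import numpy as jnp

def toy_energy(x, params):
    return jnp.dot(params["w"], x)

@catalyst.qjit
def batched_energy(xs, params):
    return catalyst.vmap(toy_energy, in_axes=(0, {"w": None}), out_axes=0)(xs, params)

xs = jnp.ones((4, 3))                       # a batch of 4 inputs, each of dimension 3
params = {"w": jnp.array([0.1, 0.2, 0.3])}  # shared parameters, broadcast across the batch
print(batched_energy(xs, params))           # four energies, one per batch element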


# Mean-squared-error loss function
@jax.jit
def mse_loss(predictions, targets):
return jnp.mean(0.5 * (predictions - targets) ** 2)


# Make prediction and compute the loss
@jax.jit
def cost(weights, loss_data):
data, E_target, F_target = loss_data
E_pred = vec_vqlm(data, weights)
Expand All @@ -424,17 +432,19 @@ def cost(weights, loss_data):


# Perform one training step
@jax.jit
# This function will be called many times, so we qjit it and amortize the one-time compilation cost over all of those calls.
@catalyst.qjit
def train_step(step_i, opt_state, loss_data):

net_params = get_params(opt_state)
loss, grads = jax.value_and_grad(cost, argnums=0)(net_params, loss_data)

loss = cost(net_params, loss_data)
grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
@paul0403 (Author) commented on Dec 18, 2024:
When changing to Catalyst, it was discovered that QubitUnitary cannot be differentiated:

catalyst.utils.exceptions.DifferentiableCompileError: QubitUnitary is non-differentiable on 'lightning.qubit' device

To make the demo work, I had to manually change the gradient method to finite differences. This causes significant performance degradation.

Possible paths forward:

  1. Find another demo to convert
  2. Still convert this demo, but either (a) make lightning work with differentiating through QubitUnitary, (b) decompose the QubitUnitary into a "standard" gate set (a rough sketch of this idea follows below), although it is unclear whether this decomposition itself can be jitted, or (c) decrease the batch size and the number of training steps
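
For option (b), here is a rough, untested sketch of the decomposition idea using PennyLane's ``qml.transforms.unitary_to_rot`` transform; it is run eagerly here, since whether the decomposition composes with qjit is exactly the open question above.

# Hedged sketch: rewriting QubitUnitary into rotation gates with unitary_to_rot (eager, not qjit-compiled)
import pennylane as qml
import numpy as np

dev_toy = qml.device("lightning.qubit", wires=1)

U = np.array([[0.0, 1.0], [1.0, 0.0]])  # toy single-qubit unitary (Pauli-X)

@qml.transforms.unitary_to_rot
@qml.qnode(dev_toy)
def decomposed_circuit():
    qml.QubitUnitary(U, wires=0)
    return qml.expval(qml.PauliZ(0))

print(decomposed_circuit())  # the QubitUnitary is rewritten into qml.Rot gates before execution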

return loss, opt_update(step_i, grads, opt_state)
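
The finite-difference workaround above estimates each derivative as (cost(x + h) - cost(x)) / h. Below is a hedged, self-contained sketch of ``catalyst.grad`` with ``method="fd"`` on a toy cost function (not the demo's cost):

# Hedged sketch: finite-difference gradients with catalyst.grad
import catalyst
from jax import numpy as jnp

def toy_cost(x):
    return jnp.sin(x) ** 2

@catalyst.qjit
def toy_gradient(x):
    return catalyst.grad(toy_cost, method="fd", h=1e-7, argnums=0)(x)

print(toy_gradient(0.3))  # the analytic derivative is sin(2 * 0.3) ≈ 0.5646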


# Return prediction and loss at inference times, e.g. for testing
@jax.jit
# This function is also called repeatedly, so we qjit it as well.
@catalyst.qjit
def inference(loss_data, opt_state):

data, E_target, F_target = loss_data
@@ -475,11 +485,12 @@ def inference(loss_data, opt_state):
# We train our VQLM using stochastic gradient descent.


num_batches = 5000 # number of optimization steps
batch_size = 256 # number of training data per batch
num_batches = 200 # 5000 # number of optimization steps
batch_size = 5 # 256 # number of training data per batch


for ibatch in range(num_batches):
#print(ibatch)
# select a batch of training points
batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
