Add a how-to for catalyst-compiling "Symmetry-invariant quantum machine learning force fields" #1222

Open
wants to merge 15 commits into base: master
Changes from 8 commits
35 changes: 23 additions & 12 deletions demonstrations/tutorial_eqnn_force_field.py
@@ -143,6 +143,10 @@
import matplotlib.pyplot as plt
import sklearn

######################################################################
# To speed up the computation, we also import Catalyst, a just-in-time (JIT) compiler for PennyLane quantum programs.
import catalyst

######################################################################
# Let us construct Pauli matrices, which are used to build the Hamiltonian.
X = np.array([[0, 1], [1, 0]])
@@ -301,10 +305,13 @@ def noise_layer(epsilon, wires):
#################################


dev = qml.device("default.qubit", wires=num_qubits)
######################################################################
# To speed up the computation, we will use Catalyst to compile our quantum program, and we will
# run it on the ``lightning.qubit`` backend instead of the default qubit backend.
dev = qml.device("lightning.qubit", wires=num_qubits)


@qml.qnode(dev, interface="jax")
@qml.qnode(dev)
def vqlm(data, params):

weights = params["params"]["weights"]
@@ -396,25 +403,27 @@ def vqlm(data, params):
)

#################################
# We will know define the cost function and how to train the model using Jax. We will use the mean-square-error loss function.
# To speed up the computation, we use the decorator ``@jax.jit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer with the
# benefit that all following executions will be significantly faster, see the `Jax docs on jitting <https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html>`_.
# We will now define the cost function and how to train the model using Jax. We will use the mean-squared-error loss function.
# To speed up the computation, we use the decorator ``@catalyst.qjit`` to do just-in-time compilation for this execution. This means the first execution will typically take a little longer, with the
# benefit that all following executions will be significantly faster; see the `Catalyst documentation <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`_.
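
######################################################################
# As a small toy illustration of this pattern (a hypothetical snippet, not part of the demo's
# model), the first call to a ``@catalyst.qjit``-decorated QNode compiles it, and subsequent
# calls reuse the compiled program:

example_dev = qml.device("lightning.qubit", wires=1)


@catalyst.qjit
@qml.qnode(example_dev)
def example_circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))


example_circuit(0.1)  # first call: compiles, then executes
example_circuit(0.2)  # later calls reuse the compiled program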

#################################
from jax.example_libraries import optimizers

# We vectorize the model over the data points
vec_vqlm = jax.vmap(vqlm, (0, None), 0)
vec_vqlm = catalyst.vmap(
vqlm,
in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
out_axes=0,
)
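# Note: ``in_axes`` spells out the full pytree of the ``params`` dictionary with every leaf set
# to ``None``, so only ``data`` is batched along its first axis while the parameters are passed
# unchanged to every call.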


# Mean-squared-error loss function
@jax.jit
def mse_loss(predictions, targets):
return jnp.mean(0.5 * (predictions - targets) ** 2)


# Make prediction and compute the loss
@jax.jit
def cost(weights, loss_data):
data, E_target, F_target = loss_data
E_pred = vec_vqlm(data, weights)
@@ -424,17 +433,19 @@ def cost(weights, loss_data):


# Perform one training step
@jax.jit
# This function will be called repeatedly, so we qjit it: the one-time compilation cost is amortized over the many subsequent runs.
@catalyst.qjit
def train_step(step_i, opt_state, loss_data):

net_params = get_params(opt_state)
loss, grads = jax.value_and_grad(cost, argnums=0)(net_params, loss_data)

loss = cost(net_params, loss_data)
grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
Author
@paul0403 paul0403 Dec 18, 2024

When changing to catalyst, it was discovered that QubitUnitary cannot be differentiated:

catalyst.utils.exceptions.DifferentiableCompileError: QubitUnitary is non-differentiable on 'lightning.qubit' device

To make the demo work, I had to manually change gradient method to finite difference. This causes significant performance degradation.

Possible paths forward:

  1. Find another demo to convert
  2. Still convert this demo, but either (a) make lightning work with differentiating through qubit unitary, or (b) decompose qubit unitary into a "standard" gateset (see the sketch after this list), although unsure whether this decomposition itself can be jitted or not, or (c) decrease the batch size and the number of training steps
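
For reference, a hedged sketch of what option (b) could look like on a toy circuit (not the demo's `vqlm`): PennyLane's `qml.transforms.unitary_to_rot` rewrites single-qubit `QubitUnitary` gates into `Rot` gates, i.e. a "standard" gate set. Whether this restores differentiability for the parameterized unitaries used in this demo, and whether the decomposition itself can be qjit-compiled, would still need to be checked.

import numpy as np
import pennylane as qml

toy_dev = qml.device("lightning.qubit", wires=1)


# Toy stand-in, not the demo circuit: a fixed single-qubit unitary (Hadamard).
@qml.transforms.unitary_to_rot
@qml.qnode(toy_dev)
def toy_circuit():
    H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
    qml.QubitUnitary(H, wires=0)
    return qml.expval(qml.PauliZ(0))


print(qml.draw(toy_circuit)())  # the QubitUnitary should now appear as a Rot gate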

return loss, opt_update(step_i, grads, opt_state)


# Return prediction and loss at inference times, e.g. for testing
@jax.jit
# This function is also called repeatedly, so we qjit it as well.
@catalyst.qjit
def inference(loss_data, opt_state):

data, E_target, F_target = loss_data