black format
paul0403 committed Sep 20, 2024
1 parent 442e9d2 commit bbd9221
Showing 1 changed file with 37 additions and 14 deletions.
51 changes: 37 additions & 14 deletions demonstrations/tutorial_eqnn_force_field_catalyst_compiled.py
@@ -31,10 +31,15 @@
sigmas = jnp.array(np.array([X, Y, Z])) # Vector of Pauli matrices
sigmas_sigmas = jnp.array(
    np.array(
-        [np.kron(X, X), np.kron(Y, Y), np.kron(Z, Z)] # Vector of tensor products of Pauli matrices
+        [
+            np.kron(X, X),
+            np.kron(Y, Y),
+            np.kron(Z, Z),
+        ]  # Vector of tensor products of Pauli matrices
    )
)


def singlet(wires):
    # Encode a 2-qubit rotation-invariant initial state, i.e., the singlet state.
    qml.Hadamard(wires=wires[0])
@@ -76,7 +81,7 @@ def noise_layer(epsilon, wires):
rep = 2 # Number of repeated vertical encoding

active_atoms = 2 # Number of active atoms
# Here we only have two active atoms since we fixed the oxygen (which becomes non-active) at the origin
num_qubits = active_atoms * rep


@@ -93,22 +98,25 @@ def noise_layer(epsilon, wires):
# If there is no such dependence, `catalyst.for_loop` can still be used.
# Here we showcase both usages.


@qjit
@qml.qnode(dev)
def vqlm_qjit(data, params):
    weights = params["params"]["weights"]
    alphas = params["params"]["alphas"]
    epsilon = params["params"]["epsilon"]

    # Initial state
    @catalyst.for_loop(0, rep, 1)
    def singlet_loop(i):
-        singlet(wires=jnp.arange(active_atoms)+active_atoms*i)
+        singlet(wires=jnp.arange(active_atoms) + active_atoms * i)

    singlet_loop()
    # Initial encoding
    for i in range(num_qubits):
        equivariant_encoding(
            alphas[i, 0], jnp.asarray(data)[i % active_atoms, ...], wires=[i]
        )
    # Reuploading model
    for d in range(D):
        qml.Barrier()
@@ -129,13 +137,20 @@ def singlet_loop(i):
        )
    return qml.expval(Observable)
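
(An aside on the `catalyst.for_loop` comment above this QNode: the decorated body always receives the loop index, but it does not have to use it. A minimal illustrative sketch, reusing the file's existing `catalyst`/`qml` imports and the `rep` constant; `fixed_block` is a hypothetical name, not part of the tutorial file:)

@catalyst.for_loop(0, rep, 1)
def fixed_block(i):  # hypothetical: the index i is required by the signature but unused here
    qml.Hadamard(wires=0)

fixed_block()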


# vectorizing for batched training with `catalyst.vmap`
-vec_vqlm = catalyst.vmap(vqlm_qjit, in_axes=(0, {'params': {'alphas': None, 'epsilon': None, 'weights': None}} ), out_axes=0)
+vec_vqlm = catalyst.vmap(
+    vqlm_qjit,
+    in_axes=(0, {"params": {"alphas": None, "epsilon": None, "weights": None}}),
+    out_axes=0,
+)
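
(An illustrative usage sketch of the vectorized model, not taken from the file: `data_batch` and `energies` are hypothetical names, the per-sample data shape (active_atoms, 3) is assumed from how `data` is built further down, and `params` is assumed to be a pytree of the form {"params": {"weights", "alphas", "epsilon"}} as expected by `vqlm_qjit`. Only the data carries a batch axis; the parameters are shared because their `in_axes` entries are None.)

data_batch = jnp.zeros((256, active_atoms, 3))  # hypothetical batch of 256 configurations
energies = vec_vqlm(data_batch, params)  # one energy expectation value per configuration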


# loss function for cost
def mse_loss(predictions, targets):
    return jnp.mean(0.5 * (predictions - targets) ** 2)


# Compile a training step
# many calls so compile = faster!
@qjit
@@ -149,7 +164,7 @@ def cost(weights, loss_data):

    net_params = get_params(opt_state)
    loss = cost(net_params, loss_data)
-    grads = catalyst.grad(cost, method = "fd", h=1e-13, argnums=0)(net_params, loss_data)
+    grads = catalyst.grad(cost, method="fd", h=1e-13, argnums=0)(net_params, loss_data)
    return loss, opt_update(step_i, grads, opt_state)
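
(A sketch of a single call to the compiled step, with the argument order (step_i, opt_state, loss_data) inferred from the body above; `loss_data` is the (data, energy, force) batch tuple assembled further down in the script:)

loss, opt_state = train_step(ibatch, opt_state, loss_data)  # thread the new optimizer state into the next iteration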


@@ -188,11 +203,15 @@ def inference(loss_data, opt_state):
data[:, 1, :] = positions[:, 2, :] - positions[:, 0, :]
positions = data.copy()

-forces = forces[:, 1:, :] # Select only the forces on the hydrogen atoms since the oxygen is fixed
+forces = forces[
+    :, 1:, :
+]  # Select only the forces on the hydrogen atoms since the oxygen is fixed


# Splitting in train-test set
-indices_train = np.random.choice(np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False)
+indices_train = np.random.choice(
+    np.arange(shape[0]), size=int(0.8 * shape[0]), replace=False
+)
indices_test = np.setdiff1d(np.arange(shape[0]), indices_train)

E_train, E_test = (energy[indices_train, 0], energy[indices_test, 0])
@@ -215,7 +234,9 @@ def inference(loss_data, opt_state):
np.random.seed(42)
epsilon = jnp.array(np.random.normal(0, 0.001, size=(D, num_qubits)))
epsilon = None # We disable SB for this specific example
-epsilon = jax.lax.stop_gradient(epsilon) # comment if we wish to train the SB weights as well.
+epsilon = jax.lax.stop_gradient(
+    epsilon
+)  # comment if we wish to train the SB weights as well.


opt_init, opt_update, get_params = optimizers.adam(1e-2)
@@ -224,8 +245,8 @@ def inference(loss_data, opt_state):
running_loss = []


-num_batches = 5000 # number of optimization steps
-batch_size = 256 # number of training data per batch
+num_batches = 5000  # number of optimization steps
+batch_size = 256  # number of training data per batch

batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -235,7 +256,9 @@ def inference(loss_data, opt_state):
# We call `train_step` and `inference` many times, so the speedup from qjit will be quite significant!
for ibatch in range(num_batches):
    # select a batch of training points
-    batch = np.random.choice(np.arange(np.shape(data_train)[0]), batch_size, replace=False)
+    batch = np.random.choice(
+        np.arange(np.shape(data_train)[0]), batch_size, replace=False
+    )

    # preparing the data
    loss_data = data_train[batch, ...], E_train[batch, ...], F_train[batch, ...]
@@ -253,7 +276,7 @@ def inference(loss_data, opt_state):

### plotting ###
fontsize = 12
-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
plt.plot(history_loss[:, 0], "r-", label="training error")
plt.plot(history_loss[:, 1], "b-", label="testing error")

@@ -265,7 +288,7 @@ def inference(loss_data, opt_state):
plt.show()


-plt.figure(figsize=(4,4))
+plt.figure(figsize=(4, 4))
plt.title("Energy predictions", fontsize=fontsize)
plt.plot(energy[indices_test], E_pred, "ro", label="Test predictions")
plt.plot(energy[indices_test], energy[indices_test], "k.-", lw=1, label="Exact")
