[Do not merge] Implement non-jax versions of IQP models #6

Open

Wants to merge 100 commits into base: main. The diff below shows the changes from 2 commits.

Commits (100):
89826b6  implemented branching logic for different use cases (mariaschuld, Feb 28, 2024)
ac10674  some minor corrections (mariaschuld, Feb 28, 2024)
8ff542d  simplify if logic a little (mariaschuld, Feb 28, 2024)
c57a184  Update src/qml_benchmarks/models/iqp_kernel.py (mariaschuld, Feb 29, 2024)
bf67629  Add makefile. Run make format. Add catalyst qjit as an option for lig… (vincentmr, Feb 28, 2024)
c39a4f8  Add qjit in iqp_variational.py module. (vincentmr, Feb 29, 2024)
98dcf67  Fix iqp_var solver. (vincentmr, Feb 29, 2024)
c7af8b6  faster qjitting (josephbowles, Mar 26, 2024)
d357dff  catalyst training (josephbowles, Mar 27, 2024)
27b4d9d  catalyst port attempt (josephbowles, Mar 28, 2024)
92b921d  add catalyst support (josephbowles, Mar 28, 2024)
868a62c  attempt catalyst port (josephbowles, Mar 28, 2024)
af61f65  add code (josephbowles, Apr 9, 2024)
43ddf74  test (Apr 10, 2024)
56fc153  update (josephbowles, Apr 10, 2024)
46e6eb5  update (josephbowles, Apr 10, 2024)
81a614d  update (josephbowles, Apr 10, 2024)
66169c7  update (josephbowles, Apr 10, 2024)
5c01770  v1 (josephbowles, Apr 12, 2024)
1596669  perf ind additions (josephbowles, Apr 12, 2024)
30ffcbb  v1 (josephbowles, Apr 12, 2024)
d8bd09e  v1 (josephbowles, Apr 12, 2024)
7a65c86  new code (josephbowles, Apr 15, 2024)
eec5fa5  add installs (josephbowles, Apr 15, 2024)
c18820f  add installs (josephbowles, Apr 15, 2024)
0f32a6c  add qmetric (josephbowles, Apr 15, 2024)
ae27b0e  update (josephbowles, Apr 15, 2024)
2a254be  result (josephbowles, Apr 15, 2024)
5cf8f2b  add info (josephbowles, Apr 15, 2024)
881dc94  add info (josephbowles, Apr 15, 2024)
7dfd26e  update tests (josephbowles, Apr 15, 2024)
f452d2b  add slurm code (josephbowles, Apr 16, 2024)
d081626  add slurm code (josephbowles, Apr 16, 2024)
f513c67  add slurm code (josephbowles, Apr 16, 2024)
74967ef  add slurm code (josephbowles, Apr 16, 2024)
9462980  add slurm code (josephbowles, Apr 16, 2024)
07b49a9  update code (josephbowles, Apr 16, 2024)
ce6b0ff  delete (josephbowles, Apr 16, 2024)
fe0a9ce  update (josephbowles, Apr 16, 2024)
b2bc779  update (josephbowles, Apr 22, 2024)
595d3ef  update (josephbowles, Apr 22, 2024)
62ac41e  update (josephbowles, Apr 22, 2024)
a94578e  update (josephbowles, Apr 22, 2024)
4a78cf0  update (josephbowles, Apr 22, 2024)
2698fc5  update (josephbowles, Apr 22, 2024)
2690831  update (josephbowles, Apr 22, 2024)
8513a93  . (josephbowles, Apr 22, 2024)
e813a7f  add use jax (josephbowles, Apr 22, 2024)
282d90b  . (josephbowles, Apr 24, 2024)
2cba022  update (Apr 24, 2024)
45c4d72  update (Apr 24, 2024)
2cb5917  prototype of Slurm job w/ Podman works (balewski, Apr 25, 2024)
feaf667  double dream works (balewski, Apr 26, 2024)
8aad113  ok (balewski, Apr 26, 2024)
63ce6dd  working sbatch script (josephbowles, Apr 29, 2024)
375b141  update jax version (josephbowles, Apr 30, 2024)
7b40ba1  working sbatch (josephbowles, Apr 30, 2024)
5c3ccb2  add perf attributes (josephbowles, Apr 30, 2024)
c380d43  reads data but crash in Python scalars (balewski, Apr 30, 2024)
1de6e38  cleanup (balewski, May 1, 2024)
a76e44a  data gen file (josephbowles, May 2, 2024)
15f4998  results (josephbowles, May 2, 2024)
34ca1cf  update (josephbowles, May 2, 2024)
7f52b53  result (josephbowles, May 6, 2024)
07b91a9  working sbatch (josephbowles, May 6, 2024)
241c7ed  rename (josephbowles, May 6, 2024)
f6c4884  rename (josephbowles, May 6, 2024)
03cfc6b  cleanup (josephbowles, May 6, 2024)
e678c96  cleanup (josephbowles, May 6, 2024)
9cf5550  cleanup (josephbowles, May 6, 2024)
49b0f80  cleanup (josephbowles, May 6, 2024)
b2f9bf0  update (josephbowles, May 6, 2024)
a6b85f4  update (josephbowles, May 6, 2024)
5fe8150  cleanup (josephbowles, May 6, 2024)
ada44af  rename (josephbowles, May 6, 2024)
d0a951e  cleanup (josephbowles, May 6, 2024)
adf4e33  update (josephbowles, May 7, 2024)
92fc797  working with catalyst (josephbowles, May 29, 2024)
3108ca1  model cleanup (josephbowles, May 30, 2024)
04670ba  catalyst support (josephbowles, Jun 4, 2024)
fc72c00  qjit update (josephbowles, Jun 4, 2024)
5016ec4  fix results dir (josephbowles, Jun 5, 2024)
4b333b1  profile time (josephbowles, Jun 17, 2024)
1899a0c  . (josephbowles, Jun 17, 2024)
3a1baa8  . (josephbowles, Jun 17, 2024)
5e26427  profile (josephbowles, Jun 17, 2024)
628292a  results (josephbowles, Jun 18, 2024)
c69895b  results (josephbowles, Jun 19, 2024)
e5dbfb6  results (josephbowles, Jun 19, 2024)
0328ca8  results (josephbowles, Jun 19, 2024)
6bc2c73  no vmap train (josephbowles, Jun 19, 2024)
362a7fe  results (josephbowles, Jun 19, 2024)
bea29fa  lax batching (josephbowles, Jun 19, 2024)
10553dd  update (josephbowles, Jul 2, 2024)
fd2efba  kernel profiling (josephbowles, Jul 2, 2024)
3f01d33  results (josephbowles, Jul 3, 2024)
8231bd0  results (josephbowles, Jul 3, 2024)
670ecec  results (josephbowles, Jul 3, 2024)
9c535e5  results (josephbowles, Jul 3, 2024)
b848381  update (josephbowles, Aug 5, 2024)
128 changes: 128 additions & 0 deletions src/qml_benchmarks/model_utils.py
@@ -22,6 +22,7 @@
import optax
import jax
import jax.numpy as jnp
from pennylane import numpy as pnp
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils import gen_batches

@@ -124,6 +125,63 @@ def update(params, opt_state, x, y):
return params


def train_without_jax(
model,
loss_fn,
optimizer,
X,
y,
random_key_generator,
convergence_interval=200
):
"""Trains a model using an optimizer and a loss function, using PennyLane's autograd interface.
"""

params = list(model.params_.values())
opt = optimizer(stepsize=model.learning_rate)

loss_history = []
converged = False
start = time.time()
for step in range(model.max_steps):
key = random_key_generator()
X_batch, y_batch = get_batch_without_jax(X, y, key, batch_size=model.batch_size)
X_batch = pnp.array(X_batch, requires_grad=False)
y_batch = pnp.array(y_batch, requires_grad=False)
loss_val = loss_fn(*params, X_batch, y_batch)
params = opt.step(loss_fn, *params, X_batch, y_batch)[:len(params)]
loss_history.append(loss_val)

logging.debug(f"{step} - loss: {loss_val}")

if np.isnan(loss_val):
logging.info(f"nan encountered. Training aborted.")
break

if step > 2 * convergence_interval:
average1 = np.mean(loss_history[-convergence_interval:])
average2 = np.mean(loss_history[-2 * convergence_interval:-convergence_interval])
std1 = np.std(loss_history[-convergence_interval:])
if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2:
logging.info(f"Model {model.__class__.__name__} converged after {step} steps.")
converged = True
break

end = time.time()
loss_history = np.array(loss_history)
model.loss_history_ = loss_history / np.max(np.abs(loss_history))
model.training_time_ = end - start

if not converged:
raise ConvergenceWarning(
f"Model {model.__class__.__name__} has not converged after the maximum number of {model.max_steps} steps.")

for i, key in enumerate(model.params_.keys()):
model.params_[key] = params[i]

return model.params_


def get_batch(X, y, rnd_key, batch_size=32):
"""
A generator to get random batches of the data (X, y)
@@ -145,6 +203,25 @@ def get_batch(X, y, rnd_key, batch_size=32):
return X[rnd_indices], y[rnd_indices]


def get_batch_without_jax(X, y, rnd_key, batch_size=32):
"""
Get a random batch of the data (X, y) without using jax.

Args:
X (array[float]): Input data with shape (n_samples, n_features).
y (array[float]): Target labels with shape (n_samples,)
rnd_key (int): seed for the batch selection (currently unused; the batch is drawn with numpy's global RNG)
batch_size (int): Number of elements in batch

Returns:
array[float]: A batch of input data shape (batch_size, n_features)
array[float]: A batch of target labels shaped (batch_size,)
"""
all_indices = list(range(len(X)))
rnd_indices = np.random.choice(all_indices, size=(batch_size,), replace=True)
return X[rnd_indices], y[rnd_indices]


def get_from_dict(dict, key_list):
"""
Access a value from a nested dictionary.
@@ -292,3 +369,54 @@ def chunked_loss(params, X, y):
return jnp.mean(res)

return chunked_loss


####### LOSS UTILS WITHOUT JAX

def l2_loss(pred, y):
"""
The square loss function. 0.5 is there to match optax.l2_loss.
"""
return 0.5 * (pred - y) ** 2


def softmax(x, axis=-1):
"""
copied from JAX: https://jax.readthedocs.io/en/latest/_modules/jax/_src/nn/functions.html#softmax
"""
x_max = pnp.max(x, axis, keepdims=True)
unnormalized = pnp.exp(x - x_max)
result = unnormalized / pnp.sum(unnormalized, axis, keepdims=True)
return result


def one_hot(a, num_classes=2):
"""
Convert an array of integer class indices to a one-hot encoded array.
Taken from https://stackoverflow.com/questions/29831489/convert-array-of-indices-to-one-hot-encoded-array-in-numpy
"""
b = pnp.zeros((a.size, num_classes))
b[pnp.arange(a.size), a] = 1
return b


def log_softmax(x, axis=-1):
"""
taken from jax.nn.log_softmax:
https://jax.readthedocs.io/en/latest/_modules/jax/_src/nn/functions.html#log_softmax
"""
x_arr = pnp.asarray(x)
x_max = pnp.max(x_arr, axis, keepdims=True)
x_max = pnp.array(x_max, requires_grad=False)
shifted = x_arr - x_max
shifted_logsumexp = pnp.log(
pnp.sum(pnp.exp(shifted), axis, keepdims=True))
result = shifted - shifted_logsumexp
return result


def softmax_cross_entropy(logits, labels):
"""taken from optax source:
https://github.com/google-deepmind/optax/blob/master/optax/losses/_classification.py
"""
return -pnp.sum(labels * log_softmax(logits, axis=-1), axis=-1)
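
Not part of the diff: a minimal usage sketch of train_without_jax together with the non-jax loss utility above. Only train_without_jax and l2_loss come from this file; ToyModel, loss_fn, and the toy data are hypothetical stand-ins illustrating the attributes the trainer reads (params_, learning_rate, max_steps, batch_size) and the calling convention loss_fn(*params, X_batch, y_batch).

import numpy as np
from pennylane import numpy as pnp
from pennylane.optimize import AdamOptimizer

from qml_benchmarks.model_utils import train_without_jax, l2_loss


class ToyModel:
    """Bare-bones stand-in exposing the attributes train_without_jax expects."""

    def __init__(self, n_features, learning_rate=0.05, max_steps=2000, batch_size=32):
        self.learning_rate = learning_rate
        self.max_steps = max_steps
        self.batch_size = batch_size
        # trainable parameters are stored in a dict, as the IQP models do
        self.params_ = {"weights": pnp.array(np.random.randn(n_features), requires_grad=True)}


def loss_fn(weights, X_batch, y_batch):
    # receives the unpacked params_ values first, then the (non-trainable) batch
    preds = pnp.tanh(pnp.dot(X_batch, weights))
    return pnp.mean(l2_loss(preds, y_batch))


X = np.random.rand(100, 3)
y = np.random.choice([-1.0, 1.0], size=100)
model = ToyModel(n_features=3)
rng = np.random.default_rng(42)

# note: train_without_jax raises sklearn's ConvergenceWarning if the loss
# has not plateaued after max_steps
params = train_without_jax(
    model, loss_fn, AdamOptimizer, X, y,
    random_key_generator=lambda: rng.integers(1_000_000),
)
print(model.loss_history_[-5:], model.training_time_)
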
46 changes: 31 additions & 15 deletions src/qml_benchmarks/models/iqp_kernel.py
@@ -16,7 +16,6 @@
import pennylane as qml
import numpy as np
import jax
import jax.numpy as jnp
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC
from sklearn.preprocessing import MinMaxScaler
@@ -31,12 +30,14 @@ def __init__(
svm=SVC(kernel="precomputed", probability=True),
repeats=2,
C=1.0,
use_jax=False,
vmap=False,
jit=False,
random_state=42,
scaling=1.0,
max_vmap=250,
dev_type="default.qubit.jax",
qnode_kwargs={"interface": "jax-jit", "diff_method": None},
dev_type="default.qubit",
mariaschuld marked this conversation as resolved.
Show resolved Hide resolved
qnode_kwargs={},
):
r"""
Kernel version of the classifier from https://arxiv.org/pdf/1804.11326v2.pdf.
@@ -58,17 +59,21 @@
svm (sklearn.svm.SVC): scikit-learn SVM class object used to fit the model from the kernel matrix
repeats (int): number of times the IQP structure is repeated in the embedding circuit.
C (float): regularization parameter for SVC. Lower values imply stronger regularization.
use_jax (bool): Whether to use jax. If False, no jitting or vmapping is performed either.
jit (bool): Whether to use just-in-time compilation.
vmap (bool): Whether to use jax.vmap.
max_vmap (int or None): The maximum size of a chunk to vectorise over. Lower values use less memory.
    must divide batch_size.
dev_type (str): string specifying the pennylane device type; e.g. 'default.qubit'.
qnode_kwargs (dict): the keyword arguments passed to the circuit qnode.
scaling (float): Factor by which to scale the input data.
random_state (int): seed used for reproducibility.
"""
# attributes that do not depend on data
self.repeats = repeats
self.C = C
self.use_jax = use_jax
self.vmap = vmap
self.jit = jit
self.max_vmap = max_vmap
self.svm = svm
@@ -86,7 +91,9 @@ def __init__(
self.circuit = None

def generate_key(self):
if self.use_jax:
    return jax.random.PRNGKey(self.rng.integers(1000000))
return self.rng.integers(1000000)

def construct_circuit(self):
dev = qml.device(self.dev_type, wires=self.n_qubits_)
@@ -115,7 +122,7 @@ def circuit(x):

self.circuit = circuit

if self.use_jax and self.jit:
circuit = jax.jit(circuit)
return circuit

@@ -132,15 +139,19 @@ def precompute_kernel(self, X1, X2):
dim2 = len(X2)

# concatenate all pairs of vectors
# Author's review note: using pure numpy here because we're not differentiating
# through the construction of the kernel matrix.
Z = np.array(  # was jnp.array
[np.concatenate((X1[i], X2[j])) for i in range(dim1) for j in range(dim2)]
)

circuit = self.construct_circuit()

if self.use_jax and self.vmap:
self.batched_circuit = chunk_vmapped_fn(
jax.vmap(circuit, 0), start=0, max_vmap=self.max_vmap
)
kernel_values = self.batched_circuit(Z)[:, 0]
else:
kernel_values = np.array([circuit(z)[0] for z in Z])

# reshape the values into the kernel matrix
kernel_matrix = np.reshape(kernel_values, (dim1, dim2))
@@ -174,11 +185,14 @@ def fit(self, X, y):
y (np.ndarray): Labels of shape (n_samples,)
"""

if self.use_jax:
    self.svm.random_state = int(
        jax.random.randint(
            self.generate_key(), shape=(1,), minval=0, maxval=1000000
        )
    )
else:
    self.svm.random_state = self.generate_key()

self.initialize(X.shape[1], np.unique(y))

Expand Down Expand Up @@ -244,3 +258,5 @@ def transform(self, X, preprocess=True):
X = self.scaler.transform(X)

return X * self.scaling
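
Not part of the diff: a minimal end-to-end sketch of the new non-jax code path. It assumes the class defined in iqp_kernel.py is exposed as IQPKernelClassifier with the usual scikit-learn fit/predict interface (the class name and predict are outside the visible hunks), and uses the __init__ defaults shown above.

import numpy as np
from qml_benchmarks.models.iqp_kernel import IQPKernelClassifier

X = np.random.rand(20, 4)              # 20 samples, 4 features in [0, 1]
y = np.random.choice([-1, 1], size=20)

# use_jax=False: kernel entries are evaluated in a plain Python loop with numpy,
# jit/vmap are ignored, and generate_key() returns an integer seed for the SVC
model = IQPKernelClassifier(use_jax=False, jit=False, vmap=False, repeats=2)
model.fit(X, y)
print(model.predict(X[:5]))
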

