From 98a7e57b9b33ce049e1efe9f9969422fcb77a81e Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 24 Jul 2024 18:05:40 +0200 Subject: [PATCH 01/54] add rbm --- src/qml_benchmarks/model_utils.py | 86 ++++++---- .../models/restricted_boltzmann_machine.py | 153 ++++++++++++++++++ 2 files changed, 204 insertions(+), 35 deletions(-) create mode 100644 src/qml_benchmarks/models/restricted_boltzmann_machine.py diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index 312118e2..ea60d235 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -24,14 +24,14 @@ import jax.numpy as jnp from sklearn.exceptions import ConvergenceWarning from sklearn.utils import gen_batches - +import inspect def train( model, loss_fn, optimizer, X, y, random_key_generator, convergence_interval=200 ): """ Trains a model using an optimizer and a loss function via gradient descent. We assume that the loss function - is of the form `loss(params, X, y)` and that the trainable parameters are stored in model.params_ as a dictionary + is of the form `loss(params, X, y, key)` and that the trainable parameters are stored in model.params_ as a dictionary of jnp.arrays. The optimizer should be an Optax optimizer (e.g. optax.adam). `model` must have an attribute `learning_rate` to set the initial learning rate for the gradient descent. @@ -43,10 +43,10 @@ def train( Args: model (class): Classifier class object to train. Trainable parameters must be stored in model.params_. - loss_fn (Callable): Loss function to be minimised. Must be of the form loss_fn(params, X, y). + loss_fn (Callable): Loss function to be minimised. Must be of the form loss_fn(params, X, y, key). optimizer (optax optimizer): Optax optimizer (e.g. optax.adam). X (array): Input data array of shape (n_samples, n_features) - y (array): Array of shape (n_samples) containing the labels. + y (array, optional): Array of shape (n_samples) containing the labels. random_key_generator (jax.random.PRNGKey): JAX key generator object for pseudo-randomness generation. convergence_interval (int, optional): Number of optimization steps over which to decide convergence. Larger values give a higher confidence that the model has converged but may increase training time. @@ -55,13 +55,21 @@ def train( params (dict): The new parameters after training has completed. 
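    A minimal illustrative sketch of calling this function is given below. TinyModel,
    loss_fn and key_gen are hypothetical stand-ins (not part of this package), and the
    sketch assumes `train` reads no model attributes beyond those shown.

        import jax
        import jax.numpy as jnp
        import numpy as np
        import optax

        class TinyModel:
            learning_rate = 0.1
            batch_size = 16
            max_vmap = None   # no gradient chunking
            max_steps = 500
            jit = True

        # a loss without a `key` argument; `train` wraps it automatically
        def loss_fn(params, X, y):
            return jnp.mean((X @ params["w"] - y) ** 2)

        rng = np.random.default_rng(0)
        key_gen = lambda: jax.random.PRNGKey(rng.integers(1_000_000))

        X = jnp.array(rng.normal(size=(100, 3)))
        y = X @ jnp.array([1.0, -2.0, 0.5])

        model = TinyModel()
        model.params_ = {"w": jnp.zeros(3)}
        model.params_ = train(model, loss_fn, optax.adam, X, y, key_gen,
                              convergence_interval=None)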
""" - if not model.batch_size / model.max_vmap % 1 == 0: - raise Exception("Batch size must be multiple of max_vmap.") + if model.max_vmap is not None: + if not model.batch_size / model.max_vmap % 1 == 0: + raise Exception("Batch size must be multiple of max_vmap.") + + # wrap a key around the function if it doesn't have one + if "key" not in inspect.signature(loss_fn).parameters: + def loss_fn_wrapped(params, x, y, key): + return loss_fn(params, x, y) + else: + loss_fn_wrapped = loss_fn params = model.params_ opt = optimizer(learning_rate=model.learning_rate) opt_state = opt.init(params) - grad_fn = jax.grad(loss_fn) + grad_fn = jax.grad(loss_fn_wrapped) # jitting through the chunked_grad function can take a long time, # so we jit here and chunk after @@ -70,12 +78,12 @@ def train( # note: assumes that the loss function is a sample mean of # some function over the input data set - chunked_grad_fn = chunk_grad(grad_fn, model.max_vmap) - chunked_loss_fn = chunk_loss(loss_fn, model.max_vmap) + chunked_grad_fn = chunk_grad(grad_fn, model.max_vmap) if model.max_vmap is not None else grad_fn + chunked_loss_fn = chunk_loss(loss_fn_wrapped, model.max_vmap) if model.max_vmap is not None else loss_fn_wrapped - def update(params, opt_state, x, y): - grads = chunked_grad_fn(params, x, y) - loss_val = chunked_loss_fn(params, x, y) + def update(params, opt_state, x, y, key): + grads = chunked_grad_fn(params, x, y, key) + loss_val = chunked_loss_fn(params, x, y, key) updates, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, updates) return params, opt_state, loss_val @@ -85,8 +93,9 @@ def update(params, opt_state, x, y): start = time.time() for step in range(model.max_steps): key = random_key_generator() - X_batch, y_batch = get_batch(X, y, key, batch_size=model.batch_size) - params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch) + key1, key2 = jax.random.split(key, 2) + X_batch, y_batch = get_batch(X, y, key1, batch_size=model.batch_size) + params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key2) loss_history.append(loss_val) logging.debug(f"{step} - loss: {loss_val}") @@ -95,27 +104,28 @@ def update(params, opt_state, x, y): break # decide convergence - if step > 2 * convergence_interval: - # get means of last two intervals and standard deviation of last interval - average1 = np.mean(loss_history[-convergence_interval:]) - average2 = np.mean( - loss_history[-2 * convergence_interval : -convergence_interval] - ) - std1 = np.std(loss_history[-convergence_interval:]) - # if the difference in averages is small compared to the statistical fluctuations, stop training. - if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: - logging.info( - f"Model {model.__class__.__name__} converged after {step} steps." + if convergence_interval is not None: + if step > 2 * convergence_interval: + # get means of last two intervals and standard deviation of last interval + average1 = np.mean(loss_history[-convergence_interval:]) + average2 = np.mean( + loss_history[-2 * convergence_interval : -convergence_interval] ) - converged = True - break + std1 = np.std(loss_history[-convergence_interval:]) + # if the difference in averages is small compared to the statistical fluctuations, stop training. + if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: + logging.info( + f"Model {model.__class__.__name__} converged after {step} steps." 
+ ) + converged = True + break end = time.time() loss_history = np.array(loss_history) model.loss_history_ = loss_history / np.max(np.abs(loss_history)) model.training_time_ = end - start - if not converged: + if not converged and convergence_interval is not None: print("Loss did not converge:", loss_history) raise ConvergenceWarning( f"Model {model.__class__.__name__} has not converged after the maximum number of {model.max_steps} steps." @@ -142,7 +152,13 @@ def get_batch(X, y, rnd_key, batch_size=32): rnd_indices = jax.random.choice( key=rnd_key, a=all_indices, shape=(batch_size,), replace=True ) - return X[rnd_indices], y[rnd_indices] + + if y is not None: + return X[rnd_indices], y[rnd_indices] + else: + return X[rnd_indices], None + + def get_from_dict(dict, key_list): @@ -238,7 +254,7 @@ def chunk_grad(grad_fn, max_vmap): """ Convert a `jax.grad` function to an equivalent version that evaluated in chunks of size max_vmap. - `grad_fn` should be of the form `jax.grad(fn(params, X, y), argnums=0)`, where `params` is a + `grad_fn` should be of the form `jax.grad(fn(params, X, y, key), argnums=0)`, where `params` is a dictionary of `jnp.arrays`, `X, y` are `jnp.arrays` with the same-size leading axis, and `grad_fn` is a function that is vectorised along these axes (i.e. `in_axes = (None,0,0)`). @@ -253,9 +269,9 @@ def chunk_grad(grad_fn, max_vmap): chunked version of the function """ - def chunked_grad(params, X, y): + def chunked_grad(params, X, y, key): batch_slices = list(gen_batches(len(X), max_vmap)) - grads = [grad_fn(params, X[slice], y[slice]) for slice in batch_slices] + grads = [grad_fn(params, X[slice], y[slice], key) for slice in batch_slices] grad_dict = {} for key_list in get_nested_keys(params): set_in_dict( @@ -272,7 +288,7 @@ def chunked_grad(params, X, y): def chunk_loss(loss_fn, max_vmap): """ - Converts a loss function of the form `loss_fn(params, array1, array2)` to an equivalent version that + Converts a loss function of the form `loss_fn(params, array1, array2, key)` to an equivalent version that evaluates `loss_fn` in chunks of size max_vmap. `loss_fn` should batch evaluate along the leading axis of `array1, array2` (i.e. `in_axes = (None,0,0)`). @@ -284,10 +300,10 @@ def chunk_loss(loss_fn, max_vmap): chunked version of the function """ - def chunked_loss(params, X, y): + def chunked_loss(params, X, y, key): batch_slices = list(gen_batches(len(X), max_vmap)) res = jnp.array( - [loss_fn(params, *[X[slice], y[slice]]) for slice in batch_slices] + [loss_fn(params, *[X[slice], y[slice]], key) for slice in batch_slices] ) return jnp.mean(res) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py new file mode 100644 index 00000000..c3aaacdf --- /dev/null +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -0,0 +1,153 @@ +import numpy as np +import jax +import jax.numpy as jnp +from qml_benchmarks.model_utils import train +import optax +import copy + +class RestrictedBoltzmannMachine(): + """ + A restricted Boltzmann machine generative model. The model is trained with the k-contrastive divergence (CD-k) + algorithm. 
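    A minimal usage sketch is given below; the data and hyperparameter values are
    placeholders chosen only for illustration, and the model expects binary inputs
    with values in {0, 1}.

        import numpy as np

        X = np.random.default_rng(0).integers(0, 2, size=(500, 9))

        rbm = RestrictedBoltzmannMachine(n_hidden=4, learning_rate=0.01,
                                         cdiv_steps=1, max_steps=1000)
        rbm.fit(X)

        # draw visible configurations from the trained model
        samples = rbm.sample_visible(n_samples=100)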
+ Args: + n_hidden (int): The number of hidden neurons + learning_rate (float): The learning rate for the CD-k updates + cdiv_steps (int): The number of gibbs sampling steps used in contrastive divergence + jit (bool): Whether to use just-in-time complilation + batch_size (int): Size of batches used for computing parameter updates + max_steps (int): Maximum number of training steps. + reg (float): The L2 regularisation strength (larger implies stronger) + convergence_interval (int or None): The number of loss values to consider to decide convergence. + If None, training runs until the maximum number of steps. Recommoneded to set to None since + CD-k does not follow the gradient of a fucntion. + random_state (int): Seed used for pseudorandom number generation. + + """ + + def __init__(self, n_hidden, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, + max_steps=200, reg=0.0, convergence_interval=None, random_state=42): + + self.n_hidden = n_hidden + self.learning_rate = learning_rate + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + self.jit = jit + self.batch_size = batch_size + self.max_steps = max_steps + self.reg = reg + self.convergence_interval = convergence_interval + self.cdiv_steps = cdiv_steps + self.vmap = True + self.max_vmap = None + + # data depended attributes + self.params_ = None + self.n_visible_ = None + + self.gibbs_step = jax.jit(self.gibbs_step) + + def generate_key(self): + return jax.random.PRNGKey(self.rng.integers(1000000)) + + def energy(self, params, x, h): + """ + The RBM energy function + Args: + params: parameter dictionay of weights and biases + x: visible configuration + h: hidden configuration + Returns: + energy (float): The energy + """ + return -x.T @ params['W'] @ h - params['a'].T @ x - params['b'].T @ h + + def initialize(self, n_features): + self.n_visible_ = n_features + W = jax.random.normal(self.generate_key(), shape=(self.n_visible_, self.n_hidden)) / jnp.sqrt(self.n_visible_) + a = jax.random.normal(self.generate_key(), shape=(self.n_visible_,)) / jnp.sqrt(self.n_visible_) + b = jax.random.normal(self.generate_key(), shape=(self.n_hidden,)) / jnp.sqrt(self.n_visible_) + self.params_ = {'a': a, 'b': b, 'W': W} + + def gibbs_step(self, args, i): + """ + Perform one Gibbs steps. The format is such that it can be used with jax.lax.scan for fast compilation. + """ + params = args[0] + key = args[1] + x = args[2] + key1, key2, key3 = jax.random.split(key, 3) + # get hidden units probs + prob_h = jax.nn.sigmoid(x.T @ params['W'] + params['b']) + h = jnp.array(jax.random.bernoulli(key1, p=prob_h), dtype=int) + # get visible units probs + prob_x = jax.nn.sigmoid(params['W'] @ h + params['a']) + x_new = jnp.array(jax.random.bernoulli(key2, p=prob_x), dtype=int) + return [params, key3, x_new], [x, h] + + def gibbs_sample(self, params, x_init, n_samples, key): + """ + Sample a chain of visible and hidden configurations from a starting visible configuration x_init + """ + carry = [params, key, x_init] + carry, configs = jax.lax.scan(self.gibbs_step, carry, jnp.arange(n_samples)) + return configs + + def sample_visible(self, n_samples): + """ + sample only the visible units starting from a random configuration. 
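        For reference, the Gibbs updates above alternate between the conditional
        distributions implied by the energy function defined in `energy` (the
        standard RBM identities, with \sigma the logistic sigmoid):

            E(x, h) = -x^\top W h - a^\top x - b^\top h
            p(h_j = 1 \mid x) = \sigma\big((x^\top W + b)_j\big)
            p(x_i = 1 \mid h) = \sigma\big((W h + a)_i\big)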
+ """ + key = self.generate_key() + x_init = jnp.array(jax.random.bernoulli(key, p=0.5, shape=(self.n_visible_,)), dtype=int) + samples = self.gibbs_sample(self.params_, x_init, n_samples, self.generate_key()) + return samples[0] + + def fit(self, X): + """ + Fit the parameters using contrastive divergence + """ + self.initialize(X_train.shape[-1]) + + # batch the relevant functions + batched_gibbs_sample = jax.vmap(self.gibbs_sample, in_axes=(None, 0, None, 0)) + batched_energy = jax.vmap(self.energy, in_axes=(None, 0, 0)) + + def c_div_loss(params, X, y, key): + """ + contrastive divergence loss + Args: + params (dict): parameter dictionary + X (array): batch of training examples + y (array): not used; should be set to None when training + key: jax PRNG key + """ + keys = jax.random.split(key, X.shape[0]) + + # we do not take the gradient wrt the sampling, so decouple the param dict here + params_copy = copy.deepcopy(params) + for key in params_copy.keys(): + params_copy[key] = jax.lax.stop_gradient(params_copy[key]) + + configs = batched_gibbs_sample(params_copy, X, self.cdiv_steps + 1, keys) + x0 = configs[0][:, 0, :] + h0 = configs[1][:, 0, :] + x1 = configs[0][:, -1, :] + h1 = configs[1][:, -1, :] + + # taking the gradient of this loss is equivalent to the CD-k update + loss = batched_energy(params, x0, h0) - batched_energy(params, x1, h1) + + return jnp.mean(loss) + self.reg * jnp.sqrt(jnp.sum(params['W'] ** 2)) + + c_div_loss = jax.jit(c_div_loss) if self.jit else c_div_loss + + self.params_ = train(self, c_div_loss, optax.sgd, X, None, self.generate_key, + convergence_interval=self.convergence_interval) + + + + + + + + + From 2f63bb6b6f93d4d97c1fda935d5342aab063cbd1 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 24 Jul 2024 18:25:58 +0200 Subject: [PATCH 02/54] add rbm --- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index c3aaacdf..2a0de21d 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -92,7 +92,7 @@ def gibbs_sample(self, params, x_init, n_samples, key): carry, configs = jax.lax.scan(self.gibbs_step, carry, jnp.arange(n_samples)) return configs - def sample_visible(self, n_samples): + def sample(self, n_samples): """ sample only the visible units starting from a random configuration. 
""" From 056a9d4357e6412e74613a54dfb606e93efa51f4 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 24 Jul 2024 18:29:19 +0200 Subject: [PATCH 03/54] add rbm --- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index 2a0de21d..64240de7 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -105,7 +105,7 @@ def fit(self, X): """ Fit the parameters using contrastive divergence """ - self.initialize(X_train.shape[-1]) + self.initialize(X.shape[-1]) # batch the relevant functions batched_gibbs_sample = jax.vmap(self.gibbs_sample, in_axes=(None, 0, None, 0)) From 24aa47e001aa258e53dec502e7f7a716f168d238 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 24 Jul 2024 18:30:47 +0200 Subject: [PATCH 04/54] add rbm --- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index 64240de7..568183b3 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -106,6 +106,7 @@ def fit(self, X): Fit the parameters using contrastive divergence """ self.initialize(X.shape[-1]) + X = jnp.array(X, dtype=int) # batch the relevant functions batched_gibbs_sample = jax.vmap(self.gibbs_sample, in_axes=(None, 0, None, 0)) From 134c2ad3cf869d532fb0e322c6f46a78ac2fb580 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 25 Jul 2024 15:49:41 +0200 Subject: [PATCH 05/54] ebm model --- .../models/energy_based_model.py | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 src/qml_benchmarks/models/energy_based_model.py diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py new file mode 100644 index 00000000..493b2ac5 --- /dev/null +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -0,0 +1,175 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Energy-based models for generative modeling.""" + +import numpy as np +import jax +import jax.numpy as jnp +from qml_benchmarks.model_utils import train +import optax +import copy +import flax.linen as nn + +class MLP(nn.Module): + """ + Simple multilayer perceptron neural network used for the energy model. + """ + @nn.compact + def __call__(self, x): + x = nn.Dense(16)(x) + x = nn.tanh(x) + x = nn.Dense(1)(x) + x = nn.tanh(x) + return x + +class EnergyBasedModel(): + """ + Energy-based model for generative learning. + The model takes as input energy model written as a flax neural network and uses k contrastive divergence + to fit the parameters. 
+ + Args: + learning_rate (float): The learning rate for the CD-k updates + cdiv_steps (int): The number of sampling steps used in contrastive divergence + jit (bool): Whether to use just-in-time complilation + batch_size (int): Size of batches used for computing parameter updates + max_steps (int): Maximum number of training steps. + convergence_interval (int or None): The number of loss values to consider to decide convergence. + If None, training runs until the maximum number of steps. Recommoneded to set to None since + CD-k does not follow the gradient of a fucntion. + random_state (int): Seed used for pseudorandom number generation. + """ + + def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, + max_steps=200, convergence_interval=None, random_state=42): + self.learning_rate = learning_rate + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + self.jit = jit + self.batch_size = batch_size + self.max_steps = max_steps + self.convergence_interval = convergence_interval + self.cdiv_steps = cdiv_steps + self.vmap = True + self.max_vmap = None + + # data depended attributes + self.params_ = None + self.n_visible_ = None + + self.mcmc_step = jax.jit(self.mcmc_step) + self.energy_model = energy_model() + + def generate_key(self): + return jax.random.PRNGKey(self.rng.integers(1000000)) + + def energy(self, params, x): + """ + The energy function for the model for a given configuration x. + + Args: + x: The configuration to calculate the energy for. + Returns: + energy (float): The energy. + """ + return self.energy_model.apply(params, x) + + def initialize(self, n_features): + self.n_visible_ = n_features + x = jax.random.normal(self.generate_key(), shape=(1, n_features)) + self.params_ = self.energy_model.init(self.generate_key(), x) + + def mcmc_step(self, args, i): + """ + Perform one metropolis hastings steps. + The format is such that it can be used with jax.lax.scan for fast compilation. + """ + params = args[0] + key = args[1] + x = args[2] + key1, key2 = jax.random.split(key, 2) + flip_idx = jax.random.choice(key1, jnp.arange(self.n_visible_)) + flip_config = jnp.zeros(self.n_visible_, dtype=int) + flip_config = flip_config.at[flip_idx].set(1) + x_flip = jnp.array((x + flip_config) % 2) + en = self.energy(params, x) + en_flip = self.energy(params, x_flip) + accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) + accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] + x_new = accept * x_flip + (1 - accept) * x + return [params, key2, x_new], x + + def mcmc_sample(self, params, x_init, n_samples, key): + """ + Sample a chain of configurations from a starting configuration x_init + """ + carry = [params, key, x_init] + carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(n_samples)) + return configs + + def langevin_sample(self, params, x_init, n_samples, key): + pass + + def sample(self, n_samples): + """ + sample configurations starting from a random configuration. 
+ """ + key = self.generate_key() + x_init = jnp.array(jax.random.bernoulli(key, p=0.5, shape=(self.n_visible_,)), dtype=int) + samples = self.mcmc_sample(self.params_, x_init, n_samples, self.generate_key()) + return jnp.array(samples) + + def fit(self, X): + """ + Fit the parameters using contrastive divergence + """ + self.initialize(X.shape[-1]) + X = jnp.array(X, dtype=int) + + # batch the relevant functions + batched_mcmc_sample = jax.vmap(self.mcmc_sample, in_axes=(None, 0, None, 0)) + batched_energy = jax.vmap(self.energy, in_axes=(None, 0)) + + def c_div_loss(params, X, y, key): + """ + contrastive divergence loss + Args: + params (dict): parameter dictionary + X (array): batch of training examples + y (array): not used; should be set to None when training + key: jax PRNG key + """ + keys = jax.random.split(key, X.shape[0]) + + # we do not take the gradient wrt the sampling, so decouple the param dict here + params_copy = copy.deepcopy(params) + for key in params_copy.keys(): + params_copy[key] = jax.lax.stop_gradient(params_copy[key]) + + configs = batched_mcmc_sample(params_copy, X, self.cdiv_steps + 1, keys) + x0 = configs[:, 0] + x1 = configs[:, -1] + + # taking the gradient of this loss is equivalent to the CD-k update + loss = batched_energy(params, x0) - batched_energy(params, x1) + + return jnp.mean(loss) + + c_div_loss = jax.jit(c_div_loss) if self.jit else c_div_loss + + self.params_ = train(self, c_div_loss, optax.adam, X, None, self.generate_key, + convergence_interval=self.convergence_interval) + + From 59c4147d704bd20c8aebdb74361ed03505095feb Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 25 Jul 2024 17:16:35 +0200 Subject: [PATCH 06/54] ebm model --- .../models/energy_based_model.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 493b2ac5..ad068625 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Energy-based models for generative modeling.""" - import numpy as np import jax import jax.numpy as jnp @@ -23,15 +21,14 @@ import flax.linen as nn class MLP(nn.Module): - """ - Simple multilayer perceptron neural network used for the energy model. 
- """ + "multilayer perceptron in flax" @nn.compact def __call__(self, x): - x = nn.Dense(16)(x) + x = nn.Dense(8)(x) x = nn.tanh(x) - x = nn.Dense(1)(x) + x = nn.Dense(4)(x) x = nn.tanh(x) + x = nn.Dense(1)(x) return x class EnergyBasedModel(): @@ -54,6 +51,7 @@ class EnergyBasedModel(): def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, max_steps=200, convergence_interval=None, random_state=42): + self.energy_model = energy_model() self.learning_rate = learning_rate self.random_state = random_state self.rng = np.random.default_rng(random_state) @@ -70,7 +68,6 @@ def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True self.n_visible_ = None self.mcmc_step = jax.jit(self.mcmc_step) - self.energy_model = energy_model() def generate_key(self): return jax.random.PRNGKey(self.rng.integers(1000000)) @@ -104,8 +101,8 @@ def mcmc_step(self, args, i): flip_config = jnp.zeros(self.n_visible_, dtype=int) flip_config = flip_config.at[flip_idx].set(1) x_flip = jnp.array((x + flip_config) % 2) - en = self.energy(params, x) - en_flip = self.energy(params, x_flip) + en = self.energy(params, jnp.expand_dims(x, 0))[0] + en_flip = self.energy(params, jnp.expand_dims(x_flip, 0))[0] accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] x_new = accept * x_flip + (1 - accept) * x @@ -135,12 +132,13 @@ def fit(self, X): """ Fit the parameters using contrastive divergence """ - self.initialize(X.shape[-1]) + self.initialize(X.shape[1]) X = jnp.array(X, dtype=int) # batch the relevant functions batched_mcmc_sample = jax.vmap(self.mcmc_sample, in_axes=(None, 0, None, 0)) - batched_energy = jax.vmap(self.energy, in_axes=(None, 0)) + + # batched_energy = jax.vmap(self.energy, in_axes=(None, 0)) def c_div_loss(params, X, y, key): """ @@ -163,7 +161,7 @@ def c_div_loss(params, X, y, key): x1 = configs[:, -1] # taking the gradient of this loss is equivalent to the CD-k update - loss = batched_energy(params, x0) - batched_energy(params, x1) + loss = self.energy(params, x0) - self.energy(params, x1) return jnp.mean(loss) From 366a65e610b1115f855ee23e718e222889a4cef2 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 25 Jul 2024 17:27:24 +0200 Subject: [PATCH 07/54] ebm model --- src/qml_benchmarks/models/energy_based_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index ad068625..ef066c1b 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -138,8 +138,6 @@ def fit(self, X): # batch the relevant functions batched_mcmc_sample = jax.vmap(self.mcmc_sample, in_axes=(None, 0, None, 0)) - # batched_energy = jax.vmap(self.energy, in_axes=(None, 0)) - def c_div_loss(params, X, y, key): """ contrastive divergence loss From 717d4472d069037576df00522a561fe3964556f3 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 25 Jul 2024 17:33:04 +0200 Subject: [PATCH 08/54] update --- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index 568183b3..9ba1d546 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -99,7 +99,7 @@ def sample(self, 
n_samples): key = self.generate_key() x_init = jnp.array(jax.random.bernoulli(key, p=0.5, shape=(self.n_visible_,)), dtype=int) samples = self.gibbs_sample(self.params_, x_init, n_samples, self.generate_key()) - return samples[0] + return jnp.array(samples[0]) def fit(self, X): """ From 8bfbbb920d17c1c916d35a69d63d0771673c01a7 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 25 Jul 2024 17:43:42 +0200 Subject: [PATCH 09/54] update --- src/qml_benchmarks/models/energy_based_model.py | 2 +- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index ef066c1b..2a1f3f79 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -67,7 +67,7 @@ def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True self.params_ = None self.n_visible_ = None - self.mcmc_step = jax.jit(self.mcmc_step) + self.mcmc_step = jax.jit(self.mcmc_step) if jit else self.mcmc_step def generate_key(self): return jax.random.PRNGKey(self.rng.integers(1000000)) diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index 9ba1d546..5ac9d7ec 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -44,7 +44,7 @@ def __init__(self, n_hidden, learning_rate=0.001, cdiv_steps=1, jit=True, batch_ self.params_ = None self.n_visible_ = None - self.gibbs_step = jax.jit(self.gibbs_step) + self.gibbs_step = jax.jit(self.gibbs_step) if jit else self.gibbs_step def generate_key(self): return jax.random.PRNGKey(self.rng.integers(1000000)) From 2c1057a4f661d40d5d762d572f32f57c8f4466ad Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 26 Jul 2024 09:52:21 +0200 Subject: [PATCH 10/54] add tqdm --- src/qml_benchmarks/model_utils.py | 59 ++++++++++++++++--------------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index ea60d235..f46bec96 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -25,6 +25,7 @@ from sklearn.exceptions import ConvergenceWarning from sklearn.utils import gen_batches import inspect +from tqdm import tqdm def train( model, loss_fn, optimizer, X, y, random_key_generator, convergence_interval=200 @@ -91,38 +92,40 @@ def update(params, opt_state, x, y, key): loss_history = [] converged = False start = time.time() - for step in range(model.max_steps): - key = random_key_generator() - key1, key2 = jax.random.split(key, 2) - X_batch, y_batch = get_batch(X, y, key1, batch_size=model.batch_size) - params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key2) - loss_history.append(loss_val) - logging.debug(f"{step} - loss: {loss_val}") - - if np.isnan(loss_val): - logging.info(f"nan encountered. 
Training aborted.") - break - - # decide convergence - if convergence_interval is not None: - if step > 2 * convergence_interval: - # get means of last two intervals and standard deviation of last interval - average1 = np.mean(loss_history[-convergence_interval:]) - average2 = np.mean( - loss_history[-2 * convergence_interval : -convergence_interval] - ) - std1 = np.std(loss_history[-convergence_interval:]) - # if the difference in averages is small compared to the statistical fluctuations, stop training. - if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: - logging.info( - f"Model {model.__class__.__name__} converged after {step} steps." + with tqdm(total=model.max_steps, desc="Training Progress") as pbar: + for step in range(model.max_steps): + key = random_key_generator() + key1, key2 = jax.random.split(key, 2) + X_batch, y_batch = get_batch(X, y, key1, batch_size=model.batch_size) + params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key2) + loss_history.append(loss_val) + logging.debug(f"{step} - loss: {loss_val}") + pbar.update(1) + + if np.isnan(loss_val): + logging.info(f"nan encountered. Training aborted.") + break + + # decide convergence + if convergence_interval is not None: + if step > 2 * convergence_interval: + # get means of last two intervals and standard deviation of last interval + average1 = np.mean(loss_history[-convergence_interval:]) + average2 = np.mean( + loss_history[-2 * convergence_interval : -convergence_interval] ) - converged = True - break + std1 = np.std(loss_history[-convergence_interval:]) + # if the difference in averages is small compared to the statistical fluctuations, stop training. + if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: + logging.info( + f"Model {model.__class__.__name__} converged after {step} steps." + ) + converged = True + break end = time.time() loss_history = np.array(loss_history) - model.loss_history_ = loss_history / np.max(np.abs(loss_history)) + model.loss_history_ = loss_history / np.max(np.abs(loss_history)) model.training_time_ = end - start if not converged and convergence_interval is not None: From 9ae22f4bb858ba9cb681c9b211e73c3c721847a1 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 26 Jul 2024 16:05:18 +0200 Subject: [PATCH 11/54] convergence --- src/qml_benchmarks/models/energy_based_model.py | 5 ++--- src/qml_benchmarks/models/restricted_boltzmann_machine.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 2a1f3f79..318f82b5 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -44,13 +44,12 @@ class EnergyBasedModel(): batch_size (int): Size of batches used for computing parameter updates max_steps (int): Maximum number of training steps. convergence_interval (int or None): The number of loss values to consider to decide convergence. - If None, training runs until the maximum number of steps. Recommoneded to set to None since - CD-k does not follow the gradient of a fucntion. + If None, training runs until the maximum number of steps. random_state (int): Seed used for pseudorandom number generation. 
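        A minimal usage sketch is given below; CustomEnergy and the data are
        hypothetical placeholders, and any flax module that maps a batch of
        configurations to scalar energies can be used as the energy model.

            import flax.linen as nn
            import numpy as np

            class CustomEnergy(nn.Module):
                @nn.compact
                def __call__(self, x):
                    x = nn.Dense(32)(x)
                    x = nn.tanh(x)
                    return nn.Dense(1)(x)

            X = np.random.default_rng(0).integers(0, 2, size=(500, 16))

            ebm = EnergyBasedModel(energy_model=CustomEnergy, learning_rate=0.001,
                                   cdiv_steps=1, max_steps=1000)
            ebm.fit(X)
            samples = ebm.sample(n_samples=100)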
""" def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, - max_steps=200, convergence_interval=None, random_state=42): + max_steps=200, convergence_interval=200, random_state=42): self.energy_model = energy_model() self.learning_rate = learning_rate self.random_state = random_state diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index 5ac9d7ec..bbd5398a 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -18,14 +18,13 @@ class RestrictedBoltzmannMachine(): max_steps (int): Maximum number of training steps. reg (float): The L2 regularisation strength (larger implies stronger) convergence_interval (int or None): The number of loss values to consider to decide convergence. - If None, training runs until the maximum number of steps. Recommoneded to set to None since - CD-k does not follow the gradient of a fucntion. + If None, training runs until the maximum number of steps. random_state (int): Seed used for pseudorandom number generation. """ def __init__(self, n_hidden, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, - max_steps=200, reg=0.0, convergence_interval=None, random_state=42): + max_steps=200, reg=0.0, convergence_interval=200, random_state=42): self.n_hidden = n_hidden self.learning_rate = learning_rate From 0c89b3dff3183c5c2a0e4e3b156948f0697c127f Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 30 Jul 2024 12:39:40 +0200 Subject: [PATCH 12/54] identical batching --- src/qml_benchmarks/model_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index f46bec96..3598a222 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -94,10 +94,10 @@ def update(params, opt_state, x, y, key): start = time.time() with tqdm(total=model.max_steps, desc="Training Progress") as pbar: for step in range(model.max_steps): - key = random_key_generator() - key1, key2 = jax.random.split(key, 2) - X_batch, y_batch = get_batch(X, y, key1, batch_size=model.batch_size) - params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key2) + key_batch = random_key_generator() + key_loss = jax.random.split(key_batch, 1) + X_batch, y_batch = get_batch(X, y, key_batch, batch_size=model.batch_size) + params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key_loss) loss_history.append(loss_val) logging.debug(f"{step} - loss: {loss_val}") pbar.update(1) From 1d8c9a35d669109f59ae6a5961c5a70b884bebcd Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 31 Jul 2024 15:33:57 +0200 Subject: [PATCH 13/54] fix key split --- src/qml_benchmarks/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index 3598a222..29e9b4a4 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -95,7 +95,7 @@ def update(params, opt_state, x, y, key): with tqdm(total=model.max_steps, desc="Training Progress") as pbar: for step in range(model.max_steps): key_batch = random_key_generator() - key_loss = jax.random.split(key_batch, 1) + key_loss = jax.random.split(key_batch, 1)[0] X_batch, y_batch = get_batch(X, y, key_batch, batch_size=model.batch_size) params, opt_state, loss_val = update(params, opt_state, X_batch, 
y_batch, key_loss) loss_history.append(loss_val) From 24d3ed2b73ff432dcd8cff66d45e5e11392bdb7f Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 29 Jul 2024 15:55:13 +0000 Subject: [PATCH 14/54] Canot cast array to int with jax >= 0.4.30. --- src/qml_benchmarks/models/iqp_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/iqp_kernel.py b/src/qml_benchmarks/models/iqp_kernel.py index e8f5853d..d2ceae00 100644 --- a/src/qml_benchmarks/models/iqp_kernel.py +++ b/src/qml_benchmarks/models/iqp_kernel.py @@ -177,7 +177,7 @@ def fit(self, X, y): self.svm.random_state = int( jax.random.randint( self.generate_key(), shape=(1,), minval=0, maxval=1000000 - ) + )[0] ) self.initialize(X.shape[1], np.unique(y)) From 37f43c604db6631bf083d00746f6342850c42403 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Mon, 29 Jul 2024 16:16:43 +0000 Subject: [PATCH 15/54] self.svm.random_state = self.rng.integers --- src/qml_benchmarks/models/iqp_kernel.py | 6 +----- src/qml_benchmarks/models/separable.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/qml_benchmarks/models/iqp_kernel.py b/src/qml_benchmarks/models/iqp_kernel.py index d2ceae00..c87eee62 100644 --- a/src/qml_benchmarks/models/iqp_kernel.py +++ b/src/qml_benchmarks/models/iqp_kernel.py @@ -174,11 +174,7 @@ def fit(self, X, y): y (np.ndarray): Labels of shape (n_samples,) """ - self.svm.random_state = int( - jax.random.randint( - self.generate_key(), shape=(1,), minval=0, maxval=1000000 - )[0] - ) + self.svm.random_state = self.rng.integers(100000) self.initialize(X.shape[1], np.unique(y)) diff --git a/src/qml_benchmarks/models/separable.py b/src/qml_benchmarks/models/separable.py index 56e852f0..1f61f91f 100644 --- a/src/qml_benchmarks/models/separable.py +++ b/src/qml_benchmarks/models/separable.py @@ -378,11 +378,7 @@ def fit(self, X, y): y (np.ndarray): Labels of shape (n_samples,) """ - self.svm.random_state = int( - jax.random.randint( - self.generate_key(), shape=(1,), minval=0, maxval=1000000 - ) - ) + self.svm.random_state = self.rng.integers(100000) self.initialize(X.shape[1], np.unique(y)) From 0073ebb7dc57e86f3a757339f408b80f24cab096 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 30 Jul 2024 10:23:54 +0200 Subject: [PATCH 16/54] correct conv interval --- src/qml_benchmarks/models/quanvolutional_neural_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/quanvolutional_neural_network.py b/src/qml_benchmarks/models/quanvolutional_neural_network.py index d08d4f07..e0740ca8 100644 --- a/src/qml_benchmarks/models/quanvolutional_neural_network.py +++ b/src/qml_benchmarks/models/quanvolutional_neural_network.py @@ -77,7 +77,7 @@ def __init__( jit=True, learning_rate=0.001, max_steps=10000, - convergence_interval=10e-4, + convergence_interval=200, batch_size=32, random_state=42, scaling=1.0, From 2ab65d137722ca97d26d71ca20a681597da8af86 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 30 Jul 2024 10:38:38 +0200 Subject: [PATCH 17/54] fix height/width --- src/qml_benchmarks/data/bars_and_stripes.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/qml_benchmarks/data/bars_and_stripes.py b/src/qml_benchmarks/data/bars_and_stripes.py index 1aaf109a..d9e2ae50 100644 --- a/src/qml_benchmarks/data/bars_and_stripes.py +++ b/src/qml_benchmarks/data/bars_and_stripes.py @@ -23,19 +23,21 @@ def generate_bars_and_stripes(n_samples, height, width, noise_std): 
height (int): number of pixels for image height width (int): number of pixels for image width noise_std (float): standard deviation of Gaussian noise added to the pixels + Returns: + (array): data labels. -1 corresponds to a bar, +1 to a stripe. """ X = np.ones([n_samples, 1, height, width]) * -1 y = np.zeros([n_samples]) for i in range(len(X)): if np.random.rand() > 0.5: - rows = np.where(np.random.rand(width) > 0.5)[0] + rows = np.where(np.random.rand(height) > 0.5)[0] X[i, 0, rows, :] = 1.0 - y[i] = -1 + y[i] = +1 else: - columns = np.where(np.random.rand(height) > 0.5)[0] + columns = np.where(np.random.rand(width) > 0.5)[0] X[i, 0, :, columns] = 1.0 - y[i] = +1 + y[i] = -1 X[i, 0] = X[i, 0] + np.random.normal(0, noise_std, size=X[i, 0].shape) return X, y From 1a3d81d9c9989ebdb8457336932809d66ff78232 Mon Sep 17 00:00:00 2001 From: Shahnawaz Ahmed Date: Tue, 13 Aug 2024 17:00:22 +0200 Subject: [PATCH 18/54] Added spin blobs dataset --- src/qml_benchmarks/data/spin_blobs.py | 112 ++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 src/qml_benchmarks/data/spin_blobs.py diff --git a/src/qml_benchmarks/data/spin_blobs.py b/src/qml_benchmarks/data/spin_blobs.py new file mode 100644 index 00000000..028b5f25 --- /dev/null +++ b/src/qml_benchmarks/data/spin_blobs.py @@ -0,0 +1,112 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate a dataset of spin configurations as blobs.""" + +import numpy as np + + +class RandomSpinBlobs: + """Generate spin configurations with high probabilites for certain spins. + + The dataset is generated by creating random spin samples close to a few + chosen peaks forming a set of blobs. One of the peaks is chosen randomly + and then by flipping some of the spins from the chosen configuration, new + spins are added so that they are a certain hamming distance away from the + chosen peak chosen from a binomial distribution with a certain noise. + + Args: + N (int): The number of spins. + num_blobs (int): The number of blobs or peak probabilities. + peak_probabilities (list[float], optional): + The probability of each spin to be selected. If not specified, + the probabilities are distributed uniformly. + peak_spins (list[np.array], optional): + The peak spin configurations, selected randomly by default. + noise_probability (float, optional): + The probability with which a binomial distribution is sampled to + add noise to the spin configurations. + """ + + def __init__( + self, + N: int, + num_blobs: int, + peak_probabilities: list[float] = None, + peak_spins: list[np.array] = None, + noise_probability: float = 0.001, + ) -> None: + self.N = N + self.num_blobs = num_blobs + + if peak_probabilities is None: + peak_probabilities = [1 / num_blobs for i in range(num_blobs)] + + if len(peak_probabilities) != num_blobs: + msg = f"Specify probabilities for all {num_blobs} blobs." 
+ raise ValueError(msg) + + if peak_spins is not None and len(peak_spins) != num_blobs: + msg = f"The number of peak spins should be the same as blobs." + raise ValueError(msg) + + if peak_spins is None: + # generate some random peak spin configs + # we flip each bit with 50% prob + spin_configs = [] + while len(spin_configs) < num_blobs: + config = list((-1) ** np.random.binomial([1] * N, [0.5] * N)) + if config not in spin_configs: + spin_configs.append(config) + peak_spins = spin_configs + + self.peak_spins = peak_spins + self.peak_probabilities = peak_probabilities + self.noise_probability = noise_probability + + def sample(self, num_samples: int, return_labels=False) -> np.array: + """Generate a given number of samples. + + Args: + num_samples (int): Number of samples to generate. + return_labels (bool, optional): + Whether to return labels for each sample. Defaults to False. + + Returns: + np.array: A (num_samples, N) array of spin configurations. + """ + samples = [] + labels = [] + + for _ in range(num_samples): + # Choose a random peak + label = np.random.choice(self.num_blobs, p=self.peak_probabilities) + labels.append(label) + peak_spin = self.peak_spins[label] + + # Randomly choose a Hamming distance from the sampler + dist = np.random.binomial(self.N, self.noise_probability) + + # Flip 'dist' number of bits randomly in the peak spin configuration + indices_to_flip = np.random.choice(self.N, dist, replace=False) + sampled_spin = np.copy(peak_spin) + for index in indices_to_flip: + sampled_spin[index] *= -1 + + samples.append(sampled_spin) + + if return_labels: + return np.array(samples), np.array(labels) + else: + return np.array(samples) \ No newline at end of file From 2c55999c9cc269d2c70eb8dc9d46a574b4f39056 Mon Sep 17 00:00:00 2001 From: Shahnawaz Ahmed Date: Tue, 13 Aug 2024 19:11:57 +0200 Subject: [PATCH 19/54] Updated spin blobs dataset --- src/qml_benchmarks/data/spin_blobs.py | 126 ++++++++++++++++++++++---- 1 file changed, 110 insertions(+), 16 deletions(-) diff --git a/src/qml_benchmarks/data/spin_blobs.py b/src/qml_benchmarks/data/spin_blobs.py index 028b5f25..d5a6842a 100644 --- a/src/qml_benchmarks/data/spin_blobs.py +++ b/src/qml_benchmarks/data/spin_blobs.py @@ -21,22 +21,31 @@ class RandomSpinBlobs: """Generate spin configurations with high probabilites for certain spins. The dataset is generated by creating random spin samples close to a few - chosen peaks forming a set of blobs. One of the peaks is chosen randomly - and then by flipping some of the spins from the chosen configuration, new - spins are added so that they are a certain hamming distance away from the - chosen peak chosen from a binomial distribution with a certain noise. + chosen `peak_spin` configurations of dimension `N` with each spin having + the possible values 0 or 1. We can vary the `peak_probabilities` parameter + to create data with different modes, where some samples will have higher + probabilities allowing us to study the effects of imbalance in the data. + + Samples are generated by selecting one of the peak spin configurations + distributed according `peak_probabilities`, and then by flipping some of the + spins. The number of spins that are flipped each time, is drawn from a + Binomial distribution bin(`N`, `p`) where `p=1` will flip all the spins + and `p=0` will not flip any spins therefore creating very narrow distributions + around the peak spins. Args: N (int): The number of spins. - num_blobs (int): The number of blobs or peak probabilities. 
+ num_blobs (int): + The number of blobs or peak probabilities. peak_probabilities (list[float], optional): The probability of each spin to be selected. If not specified, the probabilities are distributed uniformly. peak_spins (list[np.array], optional): - The peak spin configurations, selected randomly by default. - noise_probability (float, optional): - The probability with which a binomial distribution is sampled to - add noise to the spin configurations. + The peak spin configurations. Selected randomly by default. + p (float, optional): + The value of the parameter `p` in a Binomial distribution specifying + the number of spins that are flipped each time during sampling. + Defaults to 0.01. """ def __init__( @@ -45,7 +54,7 @@ def __init__( num_blobs: int, peak_probabilities: list[float] = None, peak_spins: list[np.array] = None, - noise_probability: float = 0.001, + p: float = 0.01, ) -> None: self.N = N self.num_blobs = num_blobs @@ -73,7 +82,7 @@ def __init__( self.peak_spins = peak_spins self.peak_probabilities = peak_probabilities - self.noise_probability = noise_probability + self.p = p def sample(self, num_samples: int, return_labels=False) -> np.array: """Generate a given number of samples. @@ -96,17 +105,102 @@ def sample(self, num_samples: int, return_labels=False) -> np.array: peak_spin = self.peak_spins[label] # Randomly choose a Hamming distance from the sampler - dist = np.random.binomial(self.N, self.noise_probability) + num_bits_to_flip = np.random.binomial(self.N, self.p) - # Flip 'dist' number of bits randomly in the peak spin configuration - indices_to_flip = np.random.choice(self.N, dist, replace=False) + # Flip bits randomly in the peak spin configuration + indices_to_flip = np.random.choice(self.N, num_bits_to_flip, replace=False) sampled_spin = np.copy(peak_spin) for index in indices_to_flip: sampled_spin[index] *= -1 samples.append(sampled_spin) + samples = (np.array(samples) + 1) / 2 + if return_labels: - return np.array(samples), np.array(labels) + return samples, np.array(labels) else: - return np.array(samples) \ No newline at end of file + return samples + + +def generate_8blobs( + num_samples: int, + p: float = 0.01, +): + """Generate 4x4 spin samples with 8 selected high-probability configurations + + Example + ------- + import matplotlib.pyplot as plt + from qml_benchmarks.data.spin_blobs import generate_8blobs + X, y = generate_8blobs(100) + num_samples = 20 + interval = len(X) // num_samples + + fig, axes = plt.subplots(1, num_samples, figsize=(20, 4)) + for i in range(num_samples): + axes[i].imshow(X[i*interval].reshape((4, 4))) + axes[i].axis('off') + plt.show() + + Args: + num_samples (int): The number of samples to generate. + p (float, optional): + The value of the parameter p in a Binomial distribution bin(N, p) + that determines how many spins are flipped during each sampling step + after choosing one of the peak configurations. Defaults to 0.01. + + Returns: + np.ndarray: A (num_samples, 16) array of spin configurations. 
+ """ + np.random.seed(66) + N: int = 16 + num_blobs: int = 8 + + # generate a specific set + config1 = np.array( + [[1, 1, -1, -1], [1, 1, -1, -1], [-1, -1, -1, -1], [-1, -1, -1, -1]] + ) + config2 = np.array( + [[-1, -1, 1, 1], [-1, -1, 1, 1], [-1, -1, -1, -1], [-1, -1, -1, -1]] + ) + config3 = np.array( + [[-1, -1, -1, -1], [-1, -1, -1, -1], [1, 1, -1, -1], [1, 1, -1, -1]] + ) + config4 = np.array( + [[-1, -1, -1, -1], [-1, -1, -1, -1], [-1, -1, 1, 1], [-1, -1, 1, 1]] + ) + config5 = np.array( + [[-1, -1, -1, -1], [-1, 1, 1, -1], [-1, 1, 1, -1], [-1, -1, -1, -1]] + ) + config6 = np.array( + [[1, -1, -1, -1], [-1, 1, -1, -1], [-1, -1, 1, -1], [-1, -1, -1, 1]] + ) + config7 = np.array( + [[-1, -1, -1, 1], [-1, -1, 1, -1], [-1, 1, -1, -1], [1, -1, -1, -1]] + ) + config8 = np.array( + [[1, -1, -1, 1], [-1, -1, -1, -1], [-1, -1, -1, -1], [1, -1, -1, 1]] + ) + + peak_spins = [ + np.reshape(config1, -1), + np.reshape(config2, -1), + np.reshape(config3, -1), + np.reshape(config4, -1), + np.reshape(config5, -1), + np.reshape(config6, -1), + np.reshape(config7, -1), + np.reshape(config8, -1), + ] + sampler = RandomSpinBlobs( + N=N, + num_blobs=num_blobs, + peak_spins=peak_spins, + p=p, + ) + + X, y = sampler.sample(num_samples=num_samples, return_labels=True) + X = X.reshape(-1, N) + + return X, y From f238768a25e82d6e66e63da976736234ea965342 Mon Sep 17 00:00:00 2001 From: Shahnawaz Ahmed Date: Wed, 14 Aug 2024 21:18:13 +0200 Subject: [PATCH 20/54] Added generative models and hyperparameter search Added models for generative learning and scripts to perform hyperparameter search on them. --- generative_models/benchmarks/README.md | 65 ++++ generative_models/benchmarks/generate_data.py | 31 ++ .../benchmarks/run_hyperparameter_search.py | 214 +++++++++++++ src/qml_benchmarks/data/__init__.py | 2 +- src/qml_benchmarks/hyperparam_search_utils.py | 15 +- src/qml_benchmarks/hyperparameter_settings.py | 5 + src/qml_benchmarks/models/__init__.py | 84 ++--- src/qml_benchmarks/models/base.py | 293 ++++++++++++++++++ .../models/energy_based_model.py | 221 +++++-------- 9 files changed, 748 insertions(+), 182 deletions(-) create mode 100644 generative_models/benchmarks/README.md create mode 100644 generative_models/benchmarks/generate_data.py create mode 100644 generative_models/benchmarks/run_hyperparameter_search.py create mode 100644 src/qml_benchmarks/models/base.py diff --git a/generative_models/benchmarks/README.md b/generative_models/benchmarks/README.md new file mode 100644 index 00000000..040529af --- /dev/null +++ b/generative_models/benchmarks/README.md @@ -0,0 +1,65 @@ +# Benchmarking for generative model + +The scripts in this package can help set up experiments to evaluate generative +models on custom datasets. The models and datasets are defined in the main +package. + +## Datasets + +The `qml_benchmarks.data` module provides generating functions to create datasets +A generating function can be used like this: + +```python +from qml_benchmarks.data import generate_8blobs +X, y = generate_8blobs(n_samples=200) +``` + +The scipt in this folder will generate a simple spin blob dataset. + +## Running hyperparameter optimization + +A hyperparameter search for any model and dataset can be run with the script +in this folder as: + +``` +python run_hyperparameter_search.py --model-name "RBM" --dataset-path "spin_blobs/8blobs_train.csv" +``` + +where `spin_blobs/8blobs_train.csv` is a CSV file containing the training data +such that each column is a feature. 
+ +Unless otherwise specified, the hyperparameter grid is loaded from +`qml_benchmarks/hyperparameter_settings.py`. One can override the default +grid of hyperparameters by specifying the hyperparameter list, +where the datatype is inferred from the default values. +For example, for the `RBM` we can run: + +``` +python run_hyperparameter_search.py \ + --model-name RBM \ + --dataset-path "spin_blobs/8blobs_train.csv" \ + --learning_rate 0.1 0.01 \ + --clean True +``` + +which runs a search for the grid: + +``` +{'learning_rate': [0.1, 0.01], } +``` + +The script creates two CSV files that contains the detailed results of hyperparameter search and the best +hyperparameters obtained in the search. These files are similar to the ones stored in the `paper/results` +folder. + +The best hyperparameters can be loaded into a model and used to score the classifier. + +You can check the various options for the script using: + +``` +python run_hyperparameter_search --help +``` + +## Feedback + +Please help us improve this repository and report problems by opening an issue or pull request. diff --git a/generative_models/benchmarks/generate_data.py b/generative_models/benchmarks/generate_data.py new file mode 100644 index 00000000..9f34465f --- /dev/null +++ b/generative_models/benchmarks/generate_data.py @@ -0,0 +1,31 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate 8blobs dataset.""" + +import os +import numpy as np +from qml_benchmarks.data import generate_8blobs + + +if __name__ == "__main__": + os.makedirs("spin_blobs", exist_ok=True) + path_train = "spin_blobs/8blobs_train.csv" + path_test = "spin_blobs/8blobs_test.csv" + + X, y = generate_8blobs(num_samples=5000) + np.savetxt(path_train, X, delimiter=",") + + X, y = generate_8blobs(num_samples=1000) + np.savetxt(path_test, X, delimiter=",") diff --git a/generative_models/benchmarks/run_hyperparameter_search.py b/generative_models/benchmarks/run_hyperparameter_search.py new file mode 100644 index 00000000..c01f096a --- /dev/null +++ b/generative_models/benchmarks/run_hyperparameter_search.py @@ -0,0 +1,214 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Run hyperparameter search and store results with a command-line script.""" + +import numpy as np +import sys +import os +import time +import argparse +import logging +logging.getLogger().setLevel(logging.INFO) +from importlib import import_module +import pandas as pd +from pathlib import Path +import matplotlib.pyplot as plt +from sklearn.model_selection import GridSearchCV +from qml_benchmarks.hyperparam_search_utils import read_data, construct_hyperparameter_grid +from qml_benchmarks.hyperparameter_settings import hyper_parameter_settings + +np.random.seed(42) + +logging.info('cpu count:' + str(os.cpu_count())) + + +if __name__ == "__main__": + # Create an argument parser + parser = argparse.ArgumentParser(description="Run experiments with hyperparameter search.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--model-name", + help="Model to run", + ) + + parser.add_argument( + "--dataset-path", + help="Path to the dataset", + ) + + parser.add_argument( + "--results-path", default=".", help="Path to store the experiment results" + ) + + parser.add_argument( + "--clean", + help="True or False. Remove previous results if it exists", + dest="clean", + default=False, + type=bool, + ) + + parser.add_argument( + "--hyperparameter-scoring", + type=list, + nargs="+", + default=["accuracy", "roc_auc"], + help="Scoring for hyperparameter search.", + ) + + parser.add_argument( + "--hyperparameter-refit", + type=str, + default="accuracy", + help="Refit scoring for hyperparameter search.", + ) + + parser.add_argument( + "--plot-loss", + help="True or False. Plot loss history for single fit", + dest="plot_loss", + default=False, + type=bool, + ) + + parser.add_argument( + "--n-jobs", type=int, default=-1, help="Number of parallel threads to run" + ) + + # Parse the arguments along with any extra arguments that might be model specific + args, unknown_args = parser.parse_known_args() + + if any(arg is None for arg in [args.model_name, + args.dataset_path]): + msg = "\n================================================================================" + msg += "\nA model from qml.benchmarks.model and dataset path are required. 
E.g., \n \n" + msg += "python run_hyperparameter_search \ \n--model RBM \ \n--dataset-path train.csv\n" + msg += "\nCheck all arguments for the script with \n" + msg += "python run_hyperparameter_search --help\n" + msg += "================================================================================" + raise ValueError(msg) + + # Add model specific arguments to override the default hyperparameter grid + hyperparam_grid = construct_hyperparameter_grid( + hyper_parameter_settings, args.model_name + ) + for hyperparam in hyperparam_grid: + hp_type = type(hyperparam_grid[hyperparam][0]) + parser.add_argument(f'--{hyperparam}', + type=hp_type, + nargs="+", + default=hyperparam_grid[hyperparam], + help=f'{hyperparam} grid values for {args.model_name}') + + args = parser.parse_args(unknown_args, namespace=args) + + for hyperparam in hyperparam_grid: + override = getattr(args, hyperparam) + if override is not None: + hyperparam_grid[hyperparam] = override + logging.info( + "Running hyperparameter search experiment with the following settings\n" + ) + logging.info(args.model_name) + logging.info(args.dataset_path) + logging.info(" ".join(args.hyperparameter_scoring)) + logging.info(args.hyperparameter_refit) + logging.info("Hyperparam grid:"+" ".join([(str(key)+str(":")+str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) + + experiment_path = args.results_path + results_path = os.path.join(experiment_path, "results") + + if not os.path.exists(results_path): + os.makedirs(results_path) + + ################################################################### + # Get the model, dataset and search methods from the arguments + ################################################################### + model = getattr( + import_module("qml_benchmarks.models"), + args.model_name + ) + model_name = model.__name__ + + # Run the experiments save the results + train_dataset_filename = os.path.join(args.dataset_path) + X, y = read_data(train_dataset_filename) + + dataset_path_obj = Path(args.dataset_path) + results_filename_stem = " ".join( + [model.__name__ + "_" + dataset_path_obj.stem + + "_GridSearchCV"]) + + # If we have already run this experiment then continue + if os.path.isfile(os.path.join(results_path, results_filename_stem + ".csv")): + if args.clean is False: + msg = "\n=================================================================================" + msg += "\nResults exist in " + os.path.join(results_path, results_filename_stem + ".csv") + msg += "\nSpecify --clean True to override results or new --results-path" + msg += "\n=================================================================================" + logging.warning(msg) + sys.exit(msg) + else: + logging.warning("Cleaning existing results for ", os.path.join(results_path, results_filename_stem + ".csv")) + + + ########################################################################### + # Single fit to check everything works + ########################################################################### + model = model() + a = time.time() + model.fit(X, y) + b = time.time() + acc_train = model.score(X, y) + logging.info(" ".join( + [model_name, + "Dataset path", + args.dataset_path, + "Train acc:", + str(acc_train), + "Time single run", + str(b - a)]) + ) + if hasattr(model, "loss_history_"): + if args.plot_loss: + plt.plot(model.loss_history_) + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.show() + + if hasattr(model, "n_qubits_"): + logging.info(" ".join(["Num qubits", f"{model.n_qubits_}"])) + + 
########################################################################### + # Hyperparameter search + ########################################################################### + gs = GridSearchCV(estimator=model, param_grid=hyperparam_grid, + refit=args.hyperparameter_refit, + verbose=3, + n_jobs=-1).fit( + X, y + ) + logging.info("Best hyperparams") + logging.info(gs.best_params_) + + df = pd.DataFrame.from_dict(gs.cv_results_) + df.to_csv(os.path.join(results_path, results_filename_stem + ".csv")) + + best_df = pd.DataFrame(list(gs.best_params_.items()), columns=['hyperparameter', 'best_value']) + + # Save best hyperparameters to a CSV file + best_df.to_csv(os.path.join(results_path, + results_filename_stem + '-best-hyperparameters.csv'), index=False) \ No newline at end of file diff --git a/src/qml_benchmarks/data/__init__.py b/src/qml_benchmarks/data/__init__.py index 654ac570..db748a56 100644 --- a/src/qml_benchmarks/data/__init__.py +++ b/src/qml_benchmarks/data/__init__.py @@ -19,4 +19,4 @@ from qml_benchmarks.data.hyperplanes import generate_hyperplanes_parity from qml_benchmarks.data.linearly_separable import generate_linearly_separable from qml_benchmarks.data.two_curves import generate_two_curves - \ No newline at end of file +from qml_benchmarks.data.spin_blobs import generate_8blobs diff --git a/src/qml_benchmarks/hyperparam_search_utils.py b/src/qml_benchmarks/hyperparam_search_utils.py index 27757d80..0d9adbaf 100644 --- a/src/qml_benchmarks/hyperparam_search_utils.py +++ b/src/qml_benchmarks/hyperparam_search_utils.py @@ -19,19 +19,24 @@ import pandas as pd -def read_data(path): +def read_data(path, labels=True): """Read data from a csv file where each row is a data sample. - The columns are the input features and the last column specifies a label. + The columns are the input features. If labels=True, the last feature is understood to be the label corresponding + to that sample. - Return a 2-d array of inputs and an array of labels, X,y. 
+ Return a 2-d array of inputs and an array of labels (if labels=True) Args: path (str): path to data """ # The data is stored on a CSV file with the last column being the label data = pd.read_csv(path, header=None) - X = data.iloc[:, :-1].values - y = data.iloc[:, -1].values + if labels: + X = data.iloc[:, :-1].values + y = data.iloc[:, -1].values + else: + X = data.iloc[:, :].values + y = None return X, y diff --git a/src/qml_benchmarks/hyperparameter_settings.py b/src/qml_benchmarks/hyperparameter_settings.py index 58a6923b..9fb8dce0 100644 --- a/src/qml_benchmarks/hyperparameter_settings.py +++ b/src/qml_benchmarks/hyperparameter_settings.py @@ -191,4 +191,9 @@ "alpha": {"type": "list", "dtype": "float", "val": [0.01, 0.001, 0.0001]}, }, "Perceptron": {"eta0": {"type": "list", "dtype": "float", "val": [0.1, 1, 10]}}, + "RBM": {"learning_rate": {"type": "list", "dtype": "float", "val": [0.01, 0.1, 1.]}}, + "DeepEBM": {"hidden_layers": {"type": "list", + "dtype": "tuple", + "val": ["(100,)", "(10, 10, 10, 10)", "(50, 10, 5)"], + }} } diff --git a/src/qml_benchmarks/models/__init__.py b/src/qml_benchmarks/models/__init__.py index 2cabcec6..4e225db8 100644 --- a/src/qml_benchmarks/models/__init__.py +++ b/src/qml_benchmarks/models/__init__.py @@ -20,8 +20,8 @@ ) from qml_benchmarks.models.data_reuploading import ( DataReuploadingClassifier, - DataReuploadingClassifierNoScaling, DataReuploadingClassifierNoCost, + DataReuploadingClassifierNoScaling, DataReuploadingClassifierNoTrainableEmbedding, DataReuploadingClassifierSeparable, ) @@ -30,27 +30,29 @@ DressedQuantumCircuitClassifierOnlyNN, DressedQuantumCircuitClassifierSeparable, ) - from qml_benchmarks.models.iqp_kernel import IQPKernelClassifier from qml_benchmarks.models.iqp_variational import IQPVariationalClassifier from qml_benchmarks.models.projected_quantum_kernel import ProjectedQuantumKernel from qml_benchmarks.models.quantum_boltzmann_machine import ( QuantumBoltzmannMachine, - QuantumBoltzmannMachineSeparable + QuantumBoltzmannMachineSeparable, ) from qml_benchmarks.models.quantum_kitchen_sinks import QuantumKitchenSinks from qml_benchmarks.models.quantum_metric_learning import QuantumMetricLearner -from qml_benchmarks.models.quanvolutional_neural_network import QuanvolutionalNeuralNetwork +from qml_benchmarks.models.quanvolutional_neural_network import ( + QuanvolutionalNeuralNetwork, +) from qml_benchmarks.models.separable import ( - SeparableVariationalClassifier, SeparableKernelClassifier, + SeparableVariationalClassifier, ) +from qml_benchmarks.models.energy_based_model import DeepEBM, RBM from qml_benchmarks.models.tree_tensor import TreeTensorClassifier from qml_benchmarks.models.vanilla_qnn import VanillaQNN from qml_benchmarks.models.weinet import WeiNet -from sklearn.svm import SVC as SVC_base from sklearn.neural_network import MLPClassifier as MLP +from sklearn.svm import SVC as SVC_base __all__ = [ "CircuitCentricClassifier", @@ -78,35 +80,37 @@ "WeiNet", "MLPClassifier", "SVC", + "DeepEBM", + "RBM", ] class MLPClassifier(MLP): def __init__( - self, - hidden_layer_sizes=(100, 100), - activation="relu", - solver="adam", - alpha=0.0001, - batch_size="auto", - learning_rate="constant", - learning_rate_init=0.001, - power_t=0.5, - max_iter=3000, - shuffle=True, - random_state=None, - tol=1e-4, - verbose=False, - warm_start=False, - momentum=0.9, - nesterovs_momentum=True, - early_stopping=False, - validation_fraction=0.1, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - n_iter_no_change=10, - max_fun=15000, + self, + 
hidden_layer_sizes=(100, 100), + activation="relu", + solver="adam", + alpha=0.0001, + batch_size="auto", + learning_rate="constant", + learning_rate_init=0.001, + power_t=0.5, + max_iter=3000, + shuffle=True, + random_state=None, + tol=1e-4, + verbose=False, + warm_start=False, + momentum=0.9, + nesterovs_momentum=True, + early_stopping=False, + validation_fraction=0.1, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + n_iter_no_change=10, + max_fun=15000, ): super().__init__( hidden_layer_sizes=hidden_layer_sizes, @@ -137,16 +141,16 @@ def __init__( class SVC(SVC_base): def __init__( - self, - C=1.0, - degree=3, - gamma="scale", - coef0=0.0, - shrinking=True, - probability=False, - tol=0.001, - max_iter=-1, - random_state=None, + self, + C=1.0, + degree=3, + gamma="scale", + coef0=0.0, + shrinking=True, + probability=False, + tol=0.001, + max_iter=-1, + random_state=None, ): super().__init__( C=C, diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py new file mode 100644 index 00000000..8fd3f33c --- /dev/null +++ b/src/qml_benchmarks/models/base.py @@ -0,0 +1,293 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes for models.""" + +import copy +from abc import abstractmethod + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from qml_benchmarks.model_utils import train +from sklearn.base import BaseEstimator + + +class BaseGenerator(BaseEstimator): + """ + A base class for generative models. + + We use Scikit-learn's `BaseEstimator` so that we can take advantage of + Scikit-learn's hyperparameter search algorithms like `GridSearchCV` or + `RandomizedSearchCV` for hyperparameter tuning. + + Args: + dim: + The dimensionality of the samples, e.g., for spins it could be an + integer or a tuple specifying a grid. + """ + + def __init__(self, dim: int or tuple[int]) -> None: + self.dim = dim + + @abstractmethod + def initialize(self, x: any = None): + """ + Initialize the model and create the model parameters. + + Args: + x: An example data or dimensionality of the model parameters. + """ + # self.dim = x.shape[1:] + pass + + @abstractmethod + def sample(self, num_samples: int) -> any: + """ + Sample from the model. + + Args: + num_samples: The number of samples to generate. + """ + pass + + +class EnergyBasedModel(BaseGenerator): + """ + A base class for energy-based generative models with common functionalities. + + We use Scikit-learn's `BaseEstimator` so that we can take advantage of + Scikit-learn's hyperparameter search algorithms like `GridSearchCV` or + `RandomizedSearchCV` for hyperparameter tuning. + + The parameters of the model are stored in the `params_` attribute. The + model hyperparameters are explicitly passed to the constructor. See the + `BaseEstimator` documentation for more details. + + References: + Teh, Yee Whye, Max Welling, Simon Osindero, and Geoffrey E. Hinton. + "Energy-Based Models for Sparse Overcomplete Representations." 
+ Journal of Machine Learning Research, vol. 4, 2003, pp. 1235-1260. + """ + + def __init__( + self, + dim: int = None, + learning_rate=0.001, + batch_size=32, + max_steps=10000, + cdiv_steps=100, + convergence_interval=None, + random_state=42, + jit=True, + ) -> None: + self.learning_rate = learning_rate + self.jit = jit + self.batch_size = batch_size + self.max_steps = max_steps + self.cdiv_steps = cdiv_steps + self.convergence_interval = convergence_interval + self.vmap = True + self.max_vmap = None + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + + # data depended attributes + self.params_: dict[str : jnp.array] = None + self.dim = dim # initialized to None + + # Train depended attributes that the function train in self.fit() sets. + # It is not the best practice to set attributes hidden in that function + # Since it is not clear unless someone knows what the train function does. + # Therefore we add it here for clarity. + self.history_: list[float] = None + self.training_time_: float = None + + self.mcmc_step = jax.jit(self.mcmc_step) + self.batched_mcmc_sample = jax.vmap( + self.mcmc_sample, in_axes=(None, 0, None, 0) + ) + + def generate_key(self): + return jax.random.PRNGKey(self.rng.integers(1000000)) + + @abstractmethod + def energy(self, params: dict, x: any) -> float: + """ + The energy function for the model for a given configuration x. + + This function should be implemented by the subclass and is also + responsible for initializing the parameters of the model, if necessary. + + Args: + x: The configuration to calculate the energy for. + Returns: + energy (float): The energy. + """ + pass + + # TODO: this can be made even more efficient with Numpyro MCMC + # see qgml.mcmc for a simple example + def mcmc_step(self, args, i): + """ + Perform one metropolis hastings steps. + The format is such that it can be used with jax.lax.scan for fast compilation. + """ + params, key, x = args + key1, key2 = jax.random.split(key, 2) + flip_idx = jax.random.choice(key1, jnp.arange(self.dim)) + flip_config = jnp.zeros(self.dim, dtype=int) + flip_config = flip_config.at[flip_idx].set(1) + x_flip = jnp.array((x + flip_config) % 2) + + en = self.energy(params, jnp.expand_dims(x, 0))[0] + + en_flip = self.energy(params, jnp.expand_dims(x_flip, 0))[0] + accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) + accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] + x_new = accept * x_flip + (1 - accept) * x + return [params, key2, x_new], x + + def mcmc_sample(self, params, x_init, num_mcmc_steps, key): + """ + Sample a chain of configurations from a starting configuration x_init + """ + carry = [params, key, x_init] + carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(num_mcmc_steps)) + return configs + + def langevin_sample(self, params, x_init, n_samples, key): + pass + + def sample(self, num_samples, num_mcmc_steps=1000): + """ + sample configurations starting from a random configuration. + """ + if self.params_ is None: + raise ValueError( + "Model not initialized. Call model.initialize first with" + "example data sample." 
+ ) + keys = jax.random.split(self.generate_key(), num_samples) + + x_init = jnp.array( + jax.random.bernoulli( + self.generate_key(), p=0.5, shape=(num_samples, self.dim) + ), + dtype=int, + ) + configs = self.batched_mcmc_sample(self.params_, x_init, num_mcmc_steps, keys) + x1 = configs[:, -1] + return x1 + + def contrastive_divergence_loss(self, params, X, y, key): + """ + Contrastive divergence loss function. + Args: + X (array): batch of training examples + y (array): not used; should be set to None when training + key: jax PRNG key + """ + keys = jax.random.split(key, X.shape[0]) + + # we do not take the gradient wrt the sampling, so decouple the param dict here + params_copy = copy.deepcopy(params) + for key in params_copy.keys(): + params_copy[key] = jax.lax.stop_gradient(params_copy[key]) + + configs = self.batched_mcmc_sample(params_copy, X, self.cdiv_steps, keys) + x0 = configs[:, 0] + x1 = configs[:, -1] + + # taking the gradient of this loss is equivalent to the CD-k update + loss = self.energy(params, x0) - self.energy(params, x1) + + return jnp.mean(loss) + + def fit(self, X: jnp.array, y: any = None) -> None: + """ + Fit the parameters and update self.params_. + """ + self.initialize(X) + c_div_loss = ( + jax.jit(self.contrastive_divergence_loss) + if self.jit + else self.contrastive_divergence_loss + ) + + self.params_ = train( + self, + c_div_loss, + optax.adam, + X, + None, + self.generate_key, + convergence_interval=self.convergence_interval, + ) + + def score(self, X, y=None): + """Score the model on the given data. + + Higher is better. + """ + if self.params_ is None: + self.initialize(X.shape[1]) + + c_div_loss = ( + jax.jit(self.contrastive_divergence_loss) + if self.jit + else self.contrastive_divergence_loss + ) + + return 1 - c_div_loss(self.params_, X, y, self.generate_key()) + + +class SimpleEnergyModel(EnergyBasedModel): + """A simple energy-based generative model. + + Example: + -------- + model = SimpleEnergyModel() + + # Generate random 2D data of 0, 1 + X = np.random.randint(0, 2, size=(100, 2)) + + # Initialize and calculate the energy of the model with the given data + model.initialize(X) + print(model.energy(model.params_, X)) + + # Fit the model to the data + model.fit(X) + + # Generate 100 samples from the model + samples = model.sample(100) + + # Score the model on the generated data + print(model.score(X)) + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.params_ = None # Data-dependent attributes + + def initialize(self, x: any = None): + key = self.generate_key() + self.dim = x.shape[1] + initializer = jax.nn.initializers.he_uniform() + self.params_ = {"weights": initializer(key, (x.shape[1], 1), jnp.float32)} + + def energy(self, params, x): + # Define the energy function here as the dot product of the parameters. + return jnp.dot(x, params["weights"]) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 318f82b5..058ab35c 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -12,159 +12,108 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np -import jax -import jax.numpy as jnp -from qml_benchmarks.model_utils import train -import optax -import copy import flax.linen as nn +from qml_benchmarks.models.base import EnergyBasedModel, BaseGenerator +from sklearn.neural_network import BernoulliRBM +from joblib import Parallel, delayed +import numpy as np + class MLP(nn.Module): - "multilayer perceptron in flax" + "Multilayer perceptron." + # Create a MLP with hidden layers and neurons specfied as a list of integers. + hidden_layers: list[int] + @nn.compact def __call__(self, x): - x = nn.Dense(8)(x) - x = nn.tanh(x) - x = nn.Dense(4)(x) - x = nn.tanh(x) + for dim in self.hidden_layers: + x = nn.Dense(dim)(x) + x = nn.tanh(x) x = nn.Dense(1)(x) return x -class EnergyBasedModel(): + +class DeepEBM(EnergyBasedModel): """ - Energy-based model for generative learning. - The model takes as input energy model written as a flax neural network and uses k contrastive divergence - to fit the parameters. + Energy-based model with the energy function is a neural network. Args: - learning_rate (float): The learning rate for the CD-k updates - cdiv_steps (int): The number of sampling steps used in contrastive divergence - jit (bool): Whether to use just-in-time complilation - batch_size (int): Size of batches used for computing parameter updates - max_steps (int): Maximum number of training steps. - convergence_interval (int or None): The number of loss values to consider to decide convergence. - If None, training runs until the maximum number of steps. - random_state (int): Seed used for pseudorandom number generation. + hidden_layers (list[int]): + The number of hidden layers and neurons in the MLP layers. """ - def __init__(self, energy_model=MLP, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, - max_steps=200, convergence_interval=200, random_state=42): - self.energy_model = energy_model() - self.learning_rate = learning_rate - self.random_state = random_state - self.rng = np.random.default_rng(random_state) - self.jit = jit - self.batch_size = batch_size - self.max_steps = max_steps - self.convergence_interval = convergence_interval - self.cdiv_steps = cdiv_steps - self.vmap = True - self.max_vmap = None - - # data depended attributes - self.params_ = None - self.n_visible_ = None - - self.mcmc_step = jax.jit(self.mcmc_step) if jit else self.mcmc_step - - def generate_key(self): - return jax.random.PRNGKey(self.rng.integers(1000000)) - - def energy(self, params, x): - """ - The energy function for the model for a given configuration x. + def __init__(self, hidden_layers=[8, 4], **base_kwargs): + super().__init__(**base_kwargs) + self.hidden_layers = hidden_layers + self.model = MLP(hidden_layers=hidden_layers) - Args: - x: The configuration to calculate the energy for. - Returns: - energy (float): The energy. - """ - return self.energy_model.apply(params, x) + def initialize(self, x): + dim = x.shape[1] + if not isinstance(dim, int): + raise NotImplementedError( + "The model is not yet implemented for data" + "with arbitrary dimensions. `dim` must be an integer." + ) - def initialize(self, n_features): - self.n_visible_ = n_features - x = jax.random.normal(self.generate_key(), shape=(1, n_features)) - self.params_ = self.energy_model.init(self.generate_key(), x) + self.dim = dim + self.params_ = self.model.init(self.generate_key(), x) - def mcmc_step(self, args, i): - """ - Perform one metropolis hastings steps. - The format is such that it can be used with jax.lax.scan for fast compilation. 
- """ - params = args[0] - key = args[1] - x = args[2] - key1, key2 = jax.random.split(key, 2) - flip_idx = jax.random.choice(key1, jnp.arange(self.n_visible_)) - flip_config = jnp.zeros(self.n_visible_, dtype=int) - flip_config = flip_config.at[flip_idx].set(1) - x_flip = jnp.array((x + flip_config) % 2) - en = self.energy(params, jnp.expand_dims(x, 0))[0] - en_flip = self.energy(params, jnp.expand_dims(x_flip, 0))[0] - accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) - accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] - x_new = accept * x_flip + (1 - accept) * x - return [params, key2, x_new], x - - def mcmc_sample(self, params, x_init, n_samples, key): - """ - Sample a chain of configurations from a starting configuration x_init + def energy(self, params, x): + return self.model.apply(params, x) + + +class RBM(BernoulliRBM, BaseGenerator): + def __init__( + self, + n_components=256, + learning_rate=0.1, + batch_size=10, + n_iter=10, + verbose=0, + random_state=None, + ): + super().__init__( + n_components=n_components, + learning_rate=learning_rate, + batch_size=batch_size, + n_iter=n_iter, + verbose=verbose, + random_state=random_state, + ) + + def initialize(self, x: any = None): + self.fit(x[:1, ...]) + if len(x.shape) > 2: + raise ValueError("Input data must be 2D") + self.dim = x.shape[1] + + # Gibbs sampling: + def _sample(self, num_steps=1000): """ - carry = [params, key, x_init] - carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(n_samples)) - return configs - - def langevin_sample(self, params, x_init, n_samples, key): - pass + Sample the model for given number of steps. - def sample(self, n_samples): - """ - sample configurations starting from a random configuration. - """ - key = self.generate_key() - x_init = jnp.array(jax.random.bernoulli(key, p=0.5, shape=(self.n_visible_,)), dtype=int) - samples = self.mcmc_sample(self.params_, x_init, n_samples, self.generate_key()) - return jnp.array(samples) + Args: + num_steps (int): Number of Gibbs sample steps - def fit(self, X): - """ - Fit the parameters using contrastive divergence + Returns: + np.array: The samples at the given temperature. 
""" - self.initialize(X.shape[1]) - X = jnp.array(X, dtype=int) - - # batch the relevant functions - batched_mcmc_sample = jax.vmap(self.mcmc_sample, in_axes=(None, 0, None, 0)) - - def c_div_loss(params, X, y, key): - """ - contrastive divergence loss - Args: - params (dict): parameter dictionary - X (array): batch of training examples - y (array): not used; should be set to None when training - key: jax PRNG key - """ - keys = jax.random.split(key, X.shape[0]) - - # we do not take the gradient wrt the sampling, so decouple the param dict here - params_copy = copy.deepcopy(params) - for key in params_copy.keys(): - params_copy[key] = jax.lax.stop_gradient(params_copy[key]) - - configs = batched_mcmc_sample(params_copy, X, self.cdiv_steps + 1, keys) - x0 = configs[:, 0] - x1 = configs[:, -1] - - # taking the gradient of this loss is equivalent to the CD-k update - loss = self.energy(params, x0) - self.energy(params, x1) - - return jnp.mean(loss) - - c_div_loss = jax.jit(c_div_loss) if self.jit else c_div_loss - - self.params_ = train(self, c_div_loss, optax.adam, X, None, self.generate_key, - convergence_interval=self.convergence_interval) - - + if self.dim is None: + raise ValueError("Model must be initialized before sampling") + v = np.random.choice( + [0, 1], size=(self.dim,) + ) # Assuming `N` is `self.n_components` + for _ in range(num_steps): + v = self.gibbs(v) # Assuming `gibbs` is an instance method + return v + + def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray: + # Parallelize the sampling process + samples_t = Parallel(n_jobs=-1)( + delayed(self._sample)(num_steps=num_steps) for _ in range(num_samples) + ) + samples_t = np.array(samples_t) + return samples_t + + def score(self, X: np.ndarray, y: np.ndarray) -> float: + return np.mean(super().score_samples(X)) From 10e158617fa93f911f7f7f7e2c2e812ba938d021 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 20 Aug 2024 14:12:09 +0100 Subject: [PATCH 21/54] gen hp search --- README.md | 28 ++++-- scripts/run_hyperparameter_search.py | 95 +++++++++++-------- src/qml_benchmarks/hyperparameter_settings.py | 2 +- src/qml_benchmarks/models/__init__.py | 2 +- .../models/energy_based_model.py | 2 +- .../models/restricted_boltzmann_machine.py | 2 +- 6 files changed, 79 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index a66f780e..2da78ae6 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,12 @@ Dependencies of this package can be installed in your environment by running pip install -r requirements.txt ``` -## Adding a custom model +## Adding a custom classifier We use the [Scikit-learn API](https://scikit-learn.org/stable/developers/develop.html) to create models and perform hyperparameter search. -A minimal template for a new quantum model is as follows, and can be stored +A minimal template for a new quantum classifier is as follows, and can be stored in `qml_benchmarks/models/my_model.py`: ```python @@ -146,9 +146,20 @@ model.fit(X_train, y_train) print(model.score(X_test, y_test)) ``` +## Adding a custom generative model + +TO DO + +mention: +- need score such that greater is better +- data should be 0/1 valued +- inheritance etc + ## Datasets -The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification. +The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification and +generative learning. 
+ A generating function can be used like this: ```python @@ -158,7 +169,7 @@ X, y = generate_two_curves(n_samples=200, n_features=4, degree=3, noise=0.1, off ``` Note that some datasets might have different return data structures, for example if the train/test split -is performed by the generating function. +is performed by the generating function. If the dataset does not include labels, `y = None` is returned. The original datasets used in the paper can be generated by running the scripts in the `paper/benchmarks` folder, such as: @@ -176,11 +187,12 @@ generate results for a hyperparameter search for any model and dataset. The scri can be run as ``` -python run_hyperparameter_search.py --classifier-name "DataReuploadingClassifier" --dataset-path "my_dataset.csv" +python run_hyperparameter_search.py --model "DataReuploadingClassifier" --dataset-path "my_dataset.csv" ``` -where `my_dataset.csv` is a CSV file containing the training data such that each column is a feature -and the last column is the target. +where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should +correspond to an feature and the last column to the target. For generative learning, each row +should correspond to a binary string that specifies a unique datapoint. Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`. One can override the default grid of hyperparameters by specifying the hyperparameter list, @@ -189,7 +201,7 @@ For example, for the `DataReuploadingClassifier` we can run: ``` python run_hyperparameter_search.py \ - --classifier-name DataReuploadingClassifier \ + --model DataReuploadingClassifier \ --dataset-path "my_dataset.csv" \ --n_layers 1 2 \ --observable_type "single" "full"\ diff --git a/scripts/run_hyperparameter_search.py b/scripts/run_hyperparameter_search.py index fdd64cc6..e356df9c 100644 --- a/scripts/run_hyperparameter_search.py +++ b/scripts/run_hyperparameter_search.py @@ -20,28 +20,33 @@ import time import argparse import logging + logging.getLogger().setLevel(logging.INFO) from importlib import import_module import pandas as pd from pathlib import Path import matplotlib.pyplot as plt from sklearn.model_selection import GridSearchCV +from sklearn.metrics import make_scorer +from qml_benchmarks.models.base import BaseGenerator from qml_benchmarks.hyperparam_search_utils import read_data, construct_hyperparameter_grid from qml_benchmarks.hyperparameter_settings import hyper_parameter_settings np.random.seed(42) -logging.info('cpu count:' + str(os.cpu_count())) +def custom_scorer(estimator, X, y=None): + return estimator.score(X, y) +logging.info('cpu count:' + str(os.cpu_count())) if __name__ == "__main__": # Create an argument parser parser = argparse.ArgumentParser(description="Run experiments with hyperparameter search.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( - "--classifier-name", - help="Classifier to run", + "--model", + help="Model to run", ) parser.add_argument( @@ -91,27 +96,28 @@ # Parse the arguments along with any extra arguments that might be model specific args, unknown_args = parser.parse_known_args() - if any(arg is None for arg in [args.classifier_name, + if any(arg is None for arg in [args.model, args.dataset_path]): msg = "\n================================================================================" - msg += "\nA classifier from qml.benchmarks.model and dataset 
path are required. E.g., \n \n" - msg += "python run_hyperparameter_search \ \n--classifier DataReuploadingClassifier \ \n--dataset-path train.csv\n" + msg += "\nA model from qml.benchmarks.models and dataset path are required. E.g., \n \n" + msg += "python run_hyperparameter_search \ \n--model DataReuploadingClassifier \ \n--dataset-path train.csv\n" msg += "\nCheck all arguments for the script with \n" msg += "python run_hyperparameter_search --help\n" msg += "================================================================================" raise ValueError(msg) - + # Add model specific arguments to override the default hyperparameter grid hyperparam_grid = construct_hyperparameter_grid( - hyper_parameter_settings, args.classifier_name + hyper_parameter_settings, args.model ) + for hyperparam in hyperparam_grid: hp_type = type(hyperparam_grid[hyperparam][0]) parser.add_argument(f'--{hyperparam}', type=hp_type, nargs="+", default=hyperparam_grid[hyperparam], - help=f'{hyperparam} grid values for {args.classifier_name}') + help=f'{hyperparam} grid values for {args.model}') args = parser.parse_args(unknown_args, namespace=args) @@ -122,11 +128,12 @@ logging.info( "Running hyperparameter search experiment with the following settings\n" ) - logging.info(args.classifier_name) + logging.info(args.model) logging.info(args.dataset_path) logging.info(" ".join(args.hyperparameter_scoring)) logging.info(args.hyperparameter_refit) - logging.info("Hyperparam grid:"+" ".join([(str(key)+str(":")+str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) + logging.info("Hyperparam grid:" + " ".join( + [(str(key) + str(":") + str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) experiment_path = args.results_path results_path = os.path.join(experiment_path, "results") @@ -135,22 +142,26 @@ os.makedirs(results_path) ################################################################### - # Get the classifier, dataset and search methods from the arguments + # Get the model, dataset and search methods from the arguments ################################################################### - Classifier = getattr( + Model = getattr( import_module("qml_benchmarks.models"), - args.classifier_name + args.model ) - classifier_name = Classifier.__name__ + model_name = Model.__name__ + + is_generative = isinstance(Model(), BaseGenerator) # Run the experiments save the results train_dataset_filename = os.path.join(args.dataset_path) - X, y = read_data(train_dataset_filename) + X, y = read_data(train_dataset_filename, labels=not is_generative) + + X = (X+1)//2 dataset_path_obj = Path(args.dataset_path) results_filename_stem = " ".join( - [Classifier.__name__ + "_" + dataset_path_obj.stem - + "_GridSearchCV"]) + [Model.__name__ + "_" + dataset_path_obj.stem + + "_GridSearchCV"]) # If we have already run this experiment then continue if os.path.isfile(os.path.join(results_path, results_filename_stem + ".csv")): @@ -162,44 +173,48 @@ logging.warning(msg) sys.exit(msg) else: - logging.warning("Cleaning existing results for ", os.path.join(results_path, results_filename_stem + ".csv")) - + logging.warning("Cleaning existing results for ", + os.path.join(results_path, results_filename_stem + ".csv")) ########################################################################### # Single fit to check everything works ########################################################################### - classifier = Classifier() + model = Model() a = time.time() - classifier.fit(X, y) + model.fit(X, y) b = time.time() - 
acc_train = classifier.score(X, y) + default_score = model.score(X, y) logging.info(" ".join( - [classifier_name, - "Dataset path", - args.dataset_path, - "Train acc:", - str(acc_train), - "Time single run", - str(b - a)]) + [model_name, + "Dataset path", + args.dataset_path, + "Train score:", + str(default_score), + "Time single run", + str(b - a)]) ) - if hasattr(classifier, "loss_history_"): + if hasattr(model, "loss_history_"): if args.plot_loss: - plt.plot(classifier.loss_history_) + plt.plot(model.loss_history_) plt.xlabel("Iterations") plt.ylabel("Loss") plt.show() - if hasattr(classifier, "n_qubits_"): - logging.info(" ".join(["Num qubits", f"{classifier.n_qubits_}"])) + if hasattr(model, "n_qubits_"): + logging.info(" ".join(["Num qubits", f"{model.n_qubits_}"])) ########################################################################### # Hyperparameter search ########################################################################### - gs = GridSearchCV(estimator=classifier, param_grid=hyperparam_grid, - scoring=args.hyperparameter_scoring, - refit=args.hyperparameter_refit, - verbose=3, - n_jobs=-1).fit( + + scorer = args.hyperparameter_scoring if not is_generative else custom_scorer + refit = args.hyperparameter_refit if not is_generative else False + + gs = GridSearchCV(estimator=model, param_grid=hyperparam_grid, + scoring=scorer, + refit=refit, + verbose=3, + n_jobs=args.n_jobs).fit( X, y ) logging.info("Best hyperparams") diff --git a/src/qml_benchmarks/hyperparameter_settings.py b/src/qml_benchmarks/hyperparameter_settings.py index 9fb8dce0..10cd26b4 100644 --- a/src/qml_benchmarks/hyperparameter_settings.py +++ b/src/qml_benchmarks/hyperparameter_settings.py @@ -191,7 +191,7 @@ "alpha": {"type": "list", "dtype": "float", "val": [0.01, 0.001, 0.0001]}, }, "Perceptron": {"eta0": {"type": "list", "dtype": "float", "val": [0.1, 1, 10]}}, - "RBM": {"learning_rate": {"type": "list", "dtype": "float", "val": [0.01, 0.1, 1.]}}, + "RestrictedBoltzmannMachine": {"learning_rate": {"type": "list", "dtype": "float", "val": [0.01, 0.1, 1.]}}, "DeepEBM": {"hidden_layers": {"type": "list", "dtype": "tuple", "val": ["(100,)", "(10, 10, 10, 10)", "(50, 10, 5)"], diff --git a/src/qml_benchmarks/models/__init__.py b/src/qml_benchmarks/models/__init__.py index 4e225db8..f6374d2e 100644 --- a/src/qml_benchmarks/models/__init__.py +++ b/src/qml_benchmarks/models/__init__.py @@ -46,7 +46,7 @@ SeparableKernelClassifier, SeparableVariationalClassifier, ) -from qml_benchmarks.models.energy_based_model import DeepEBM, RBM +from qml_benchmarks.models.energy_based_model import DeepEBM, RestrictedBoltzmannMachine from qml_benchmarks.models.tree_tensor import TreeTensorClassifier from qml_benchmarks.models.vanilla_qnn import VanillaQNN from qml_benchmarks.models.weinet import WeiNet diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 058ab35c..a7e63604 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -62,7 +62,7 @@ def energy(self, params, x): return self.model.apply(params, x) -class RBM(BernoulliRBM, BaseGenerator): +class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator): def __init__( self, n_components=256, diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py index bbd5398a..6e12c08a 100644 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ 
b/src/qml_benchmarks/models/restricted_boltzmann_machine.py @@ -5,7 +5,7 @@ import optax import copy -class RestrictedBoltzmannMachine(): +class RestrictedBoltzmannMachineOld(): """ A restricted Boltzmann machine generative model. The model is trained with the k-contrastive divergence (CD-k) algorithm. From 63d00cd52cbffaf74f140bc9d663123684ca63ac Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 20 Aug 2024 14:19:53 +0100 Subject: [PATCH 22/54] remove typo --- scripts/run_hyperparameter_search.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/run_hyperparameter_search.py b/scripts/run_hyperparameter_search.py index e356df9c..1278fd0d 100644 --- a/scripts/run_hyperparameter_search.py +++ b/scripts/run_hyperparameter_search.py @@ -156,8 +156,6 @@ def custom_scorer(estimator, X, y=None): train_dataset_filename = os.path.join(args.dataset_path) X, y = read_data(train_dataset_filename, labels=not is_generative) - X = (X+1)//2 - dataset_path_obj = Path(args.dataset_path) results_filename_stem = " ".join( [Model.__name__ + "_" + dataset_path_obj.stem From f467478a82d6b3cfb1a8a2428e602c89f397eee7 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 20 Aug 2024 15:48:02 +0100 Subject: [PATCH 23/54] update scores --- src/qml_benchmarks/model_utils.py | 72 +++++++++++++++++++ .../models/energy_based_model.py | 18 ++++- 2 files changed, 87 insertions(+), 3 deletions(-) diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index 29e9b4a4..322d642f 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -311,3 +311,75 @@ def chunked_loss(params, X, y, key): return jnp.mean(res) return chunked_loss + +def mmd_loss(ground_truth: np.ndarray, model_samples: np.ndarray, sigma: float) -> float: + """Calculates an unbiased estimate of the Maximum Mean Discrepancy (MMD) loss from samples + see https://jmlr.org/papers/volume13/gretton12a/gretton12a.pdf for more info + + Args: + ground_truth (np.ndarray): Samples from the ground truth distribution. + model_samples (np.ndarray): Samples from the test model. + sigma (float): Sigma parameter, the width of the kernel. + + Returns: + float: The value of the MMD loss. 
+ """ + + n = len(ground_truth) + m = len(model_samples) + ground_truth = jnp.array(ground_truth) + model_samples = jnp.array(model_samples) + + # K_pp + K_pp = jnp.zeros((ground_truth.shape[0], ground_truth.shape[0])) + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set(gaussian_kernel(sigma, ground_truth[i], ground_truth[j])) + return jax.lax.fori_loop(0, ground_truth.shape[0], inner_body_fun, val) + K_pp = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pp) + sum_pp = jnp.sum(K_pp) - n + + # K_pq + K_pq = jnp.zeros((ground_truth.shape[0], model_samples.shape[0])) + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set(gaussian_kernel(sigma, ground_truth[i], model_samples[j])) + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + K_pq = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pq) + sum_pq = jnp.sum(K_pq) + + # K_qq + K_qq = jnp.zeros((model_samples.shape[0], model_samples.shape[0])) + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set(gaussian_kernel(sigma, model_samples[i], model_samples[j])) + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + K_qq = jax.lax.fori_loop(0, model_samples.shape[0], body_fun, K_qq) + sum_qq = jnp.sum(K_qq) - m + + return 1/n/(n-1) * sum_pp - 2/n/m * sum_pq + 1/m/(m-1) * sum_qq + +def gaussian_kernel(sigma: float, x: np.ndarray, y: np.ndarray) -> float: + """Calculates the value for the gaussian kernel between two vectors x, y + + Args: + sigma (float): sigma parameter, the width of the kernel + x (np.ndarray): one of the vectors + y (np.ndarray): the other vector + + Returns: + float: Result value of the gaussian kernel + """ + return jnp.exp(-((x-y)**2).sum()/2/sigma) + +def median_heuristic(X): + """ + Computes an estimate of the median heuristic used to decide the bandwidth of the RBF kernels; see + https://arxiv.org/abs/1707.07269 + :param X (array): Dataset of interest + :return (float): median heuristic estimate + """ + m = len(X) + X = np.array(X) + med = np.median([np.sqrt(np.sum((X[i] - X[j]) ** 2)) for i in range(m) for j in range(m)]) + return med \ No newline at end of file diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index a7e63604..bba36e11 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -16,6 +16,7 @@ from qml_benchmarks.models.base import EnergyBasedModel, BaseGenerator from sklearn.neural_network import BernoulliRBM from joblib import Parallel, delayed +from qml_benchmarks.model_utils import mmd_loss, median_heuristic import numpy as np @@ -42,10 +43,11 @@ class DeepEBM(EnergyBasedModel): The number of hidden layers and neurons in the MLP layers. 
""" - def __init__(self, hidden_layers=[8, 4], **base_kwargs): + def __init__(self, hidden_layers=[8, 4], mmd_kwargs = {'n_samples': 1000, 'sigma': 1.0}, **base_kwargs): super().__init__(**base_kwargs) self.hidden_layers = hidden_layers self.model = MLP(hidden_layers=hidden_layers) + self.mmd_kwargs = mmd_kwargs def initialize(self, x): dim = x.shape[1] @@ -61,6 +63,9 @@ def initialize(self, x): def energy(self, params, x): return self.model.apply(params, x) + def score(self, X: np.ndarray, y: np.ndarray) -> float: + return float(-mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), self.mmd_kwargs['sigma'])) + class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator): def __init__( @@ -71,6 +76,8 @@ def __init__( n_iter=10, verbose=0, random_state=None, + score_fn='pseudolikelihood', + mmd_kwargs ={'n_samples': 1000, 'sigma': 1.0} ): super().__init__( n_components=n_components, @@ -80,6 +87,8 @@ def __init__( verbose=verbose, random_state=random_state, ) + self.score_fn = score_fn + self.mmd_kwargs = mmd_kwargs def initialize(self, x: any = None): self.fit(x[:1, ...]) @@ -112,8 +121,11 @@ def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray: samples_t = Parallel(n_jobs=-1)( delayed(self._sample)(num_steps=num_steps) for _ in range(num_samples) ) - samples_t = np.array(samples_t) + samples_t = np.array(samples_t, dtype=int) return samples_t def score(self, X: np.ndarray, y: np.ndarray) -> float: - return np.mean(super().score_samples(X)) + if self.score_fn == 'pseudolikelihood': + return float(np.mean(super().score_samples(X))) + elif self.score_fn == 'mmd': + return float(-mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), self.mmd_kwargs['sigma'])) \ No newline at end of file From 461f9ce374ffce5760a97998b151a9e5a508c6ec Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 11:47:41 +0200 Subject: [PATCH 24/54] updates --- README.md | 2 +- scripts/run_hyperparameter_search.py | 3 ++- scripts/score_with_best_hyperparameters.py | 6 +++++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2da78ae6..b18e4ddf 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,7 @@ python run_hyperparameter_search.py --model "DataReuploadingClassifier" --datase where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should correspond to an feature and the last column to the target. For generative learning, each row -should correspond to a binary string that specifies a unique datapoint. +should correspond to a binary string that specifies a unique data sample. Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`. 
One can override the default grid of hyperparameters by specifying the hyperparameter list, diff --git a/scripts/run_hyperparameter_search.py b/scripts/run_hyperparameter_search.py index 1278fd0d..035913cd 100644 --- a/scripts/run_hyperparameter_search.py +++ b/scripts/run_hyperparameter_search.py @@ -151,10 +151,11 @@ def custom_scorer(estimator, X, y=None): model_name = Model.__name__ is_generative = isinstance(Model(), BaseGenerator) + use_labels = False if is_generative else True # Run the experiments save the results train_dataset_filename = os.path.join(args.dataset_path) - X, y = read_data(train_dataset_filename, labels=not is_generative) + X, y = read_data(train_dataset_filename, labels=use_labels) dataset_path_obj = Path(args.dataset_path) results_filename_stem = " ".join( diff --git a/scripts/score_with_best_hyperparameters.py b/scripts/score_with_best_hyperparameters.py index 47cc8e08..281424bf 100644 --- a/scripts/score_with_best_hyperparameters.py +++ b/scripts/score_with_best_hyperparameters.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Score a model using the best hyperparameters, using a command-line script.""" +""" +Score a model using the best hyperparameters, using a command-line script. +Note this is only compatible with classifier models. +""" + import numpy as np import sys From ca9071bcc1de3e133ce60c42d07fa9c9b8ce607e Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 17:40:17 +0200 Subject: [PATCH 25/54] fix init --- .../models/energy_based_model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index bba36e11..f68fe33d 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -76,7 +76,7 @@ def __init__( n_iter=10, verbose=0, random_state=None, - score_fn='pseudolikelihood', + score_fn='mmd', mmd_kwargs ={'n_samples': 1000, 'sigma': 1.0} ): super().__init__( @@ -90,11 +90,14 @@ def __init__( self.score_fn = score_fn self.mmd_kwargs = mmd_kwargs - def initialize(self, x: any = None): - self.fit(x[:1, ...]) - if len(x.shape) > 2: + def initialize(self, X: any = None): + if len(X.shape) > 2: raise ValueError("Input data must be 2D") - self.dim = x.shape[1] + self.dim = X.shape[1] + + def fit(self, X, y=None): + self.initialize(X) + super().fit(X, y) # Gibbs sampling: def _sample(self, num_steps=1000): @@ -128,4 +131,7 @@ def score(self, X: np.ndarray, y: np.ndarray) -> float: if self.score_fn == 'pseudolikelihood': return float(np.mean(super().score_samples(X))) elif self.score_fn == 'mmd': - return float(-mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), self.mmd_kwargs['sigma'])) \ No newline at end of file + sigma = self.mmd_kwargs['sigma'] + sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma + score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), sigma) for sigma in sigmas]) + return float(-score) \ No newline at end of file From 16615973ec587e54f71778e64331b88e3081fb0d Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 17:40:37 +0200 Subject: [PATCH 26/54] rename RBM --- src/qml_benchmarks/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/__init__.py b/src/qml_benchmarks/models/__init__.py index f6374d2e..28f57d05 100644 --- 
a/src/qml_benchmarks/models/__init__.py +++ b/src/qml_benchmarks/models/__init__.py @@ -81,7 +81,7 @@ "MLPClassifier", "SVC", "DeepEBM", - "RBM", + "RestrictedBoltzmannMachine", ] From b14f7daf2e6c4e6708b03f1407487e112bc96596 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 17:46:21 +0200 Subject: [PATCH 27/54] fix init --- src/qml_benchmarks/models/energy_based_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index f68fe33d..ec8c6ff4 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -76,7 +76,7 @@ def __init__( n_iter=10, verbose=0, random_state=None, - score_fn='mmd', + score_fn='pseudolikelihood', mmd_kwargs ={'n_samples': 1000, 'sigma': 1.0} ): super().__init__( From 79b5cdbce2e230540eb85704434789957760f004 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 17:52:26 +0200 Subject: [PATCH 28/54] fix init --- src/qml_benchmarks/models/energy_based_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index ec8c6ff4..b52edc8a 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -77,7 +77,7 @@ def __init__( verbose=0, random_state=None, score_fn='pseudolikelihood', - mmd_kwargs ={'n_samples': 1000, 'sigma': 1.0} + mmd_kwargs ={'n_samples': 1000, 'n_steps': 1000, 'sigma': 1.0} ): super().__init__( n_components=n_components, @@ -133,5 +133,6 @@ def score(self, X: np.ndarray, y: np.ndarray) -> float: elif self.score_fn == 'mmd': sigma = self.mmd_kwargs['sigma'] sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma - score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), sigma) for sigma in sigmas]) + score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'], + self.mmd_kwargs['n_steps']), sigma) for sigma in sigmas]) return float(-score) \ No newline at end of file From 635f04b4f5d3183269ef46fa5feb79a2796e2f90 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Fri, 23 Aug 2024 18:10:07 +0200 Subject: [PATCH 29/54] update mmd score --- src/qml_benchmarks/models/energy_based_model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index b52edc8a..5e94e627 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -43,7 +43,9 @@ class DeepEBM(EnergyBasedModel): The number of hidden layers and neurons in the MLP layers. 
""" - def __init__(self, hidden_layers=[8, 4], mmd_kwargs = {'n_samples': 1000, 'sigma': 1.0}, **base_kwargs): + def __init__(self, hidden_layers=[8, 4], + mmd_kwargs = {'n_samples': 1000, 'n_steps':1000, 'sigma': 1.0}, + **base_kwargs): super().__init__(**base_kwargs) self.hidden_layers = hidden_layers self.model = MLP(hidden_layers=hidden_layers) @@ -64,7 +66,11 @@ def energy(self, params, x): return self.model.apply(params, x) def score(self, X: np.ndarray, y: np.ndarray) -> float: - return float(-mmd_loss(X, self.sample(self.mmd_kwargs['n_samples']), self.mmd_kwargs['sigma'])) + sigma = self.mmd_kwargs['sigma'] + sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma + score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'], + self.mmd_kwargs['n_steps']), sigma) for sigma in sigmas]) + return float(-score) class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator): From 21f866af30d9f2db157ac9fbbab0a943445b648f Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 27 Aug 2024 11:54:40 +0200 Subject: [PATCH 30/54] explicit kwargs --- .../models/energy_based_model.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 5e94e627..93f204b7 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -43,10 +43,26 @@ class DeepEBM(EnergyBasedModel): The number of hidden layers and neurons in the MLP layers. """ - def __init__(self, hidden_layers=[8, 4], - mmd_kwargs = {'n_samples': 1000, 'n_steps':1000, 'sigma': 1.0}, - **base_kwargs): - super().__init__(**base_kwargs) + def __init__(self, + learning_rate=0.001, + batch_size=32, + max_steps=10000, + cdiv_steps=100, + convergence_interval=None, + random_state=42, + jit=True, + hidden_layers=[8, 4], + mmd_kwargs = {'n_samples': 1000, 'n_steps':1000, 'sigma': 1.0}): + super().__init__( + dim=None, + learning_rate=learning_rate, + batch_size=batch_size, + max_steps=max_steps, + cdiv_steps=cdiv_steps, + convergence_interval=convergence_interval, + random_state=random_state, + jit=jit + ) self.hidden_layers = hidden_layers self.model = MLP(hidden_layers=hidden_layers) self.mmd_kwargs = mmd_kwargs From 045382f745846676f7273b0af3dfc2c0266a0c5a Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 28 Aug 2024 11:58:14 +0200 Subject: [PATCH 31/54] fix int error --- src/qml_benchmarks/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 8fd3f33c..2440c9b7 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -148,7 +148,7 @@ def mcmc_step(self, args, i): params, key, x = args key1, key2 = jax.random.split(key, 2) flip_idx = jax.random.choice(key1, jnp.arange(self.dim)) - flip_config = jnp.zeros(self.dim, dtype=int) + flip_config = jnp.zeros(self.dim, dtype='int32') flip_config = flip_config.at[flip_idx].set(1) x_flip = jnp.array((x + flip_config) % 2) From cc88641405a824b5cf7c144dd9d6f545e025de8f Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 28 Aug 2024 14:28:49 +0200 Subject: [PATCH 32/54] num_mcmc_steps -> num_steps --- src/qml_benchmarks/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 2440c9b7..3330ca9e 100644 --- a/src/qml_benchmarks/models/base.py +++ 
b/src/qml_benchmarks/models/base.py @@ -171,7 +171,7 @@ def mcmc_sample(self, params, x_init, num_mcmc_steps, key): def langevin_sample(self, params, x_init, n_samples, key): pass - def sample(self, num_samples, num_mcmc_steps=1000): + def sample(self, num_samples, num_steps=1000): """ sample configurations starting from a random configuration. """ @@ -188,7 +188,7 @@ def sample(self, num_samples, num_mcmc_steps=1000): ), dtype=int, ) - configs = self.batched_mcmc_sample(self.params_, x_init, num_mcmc_steps, keys) + configs = self.batched_mcmc_sample(self.params_, x_init, num_steps, keys) x1 = configs[:, -1] return x1 From 9e074c571ad753688b1d38fad9934be430be6079 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Mon, 30 Sep 2024 16:14:11 +0200 Subject: [PATCH 33/54] fix cv step --- src/qml_benchmarks/models/energy_based_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 93f204b7..610879ad 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -64,7 +64,6 @@ def __init__(self, jit=jit ) self.hidden_layers = hidden_layers - self.model = MLP(hidden_layers=hidden_layers) self.mmd_kwargs = mmd_kwargs def initialize(self, x): @@ -76,6 +75,7 @@ def initialize(self, x): ) self.dim = dim + self.model = MLP(hidden_layers=self.hidden_layers) self.params_ = self.model.init(self.generate_key(), x) def energy(self, params, x): From 36013ee3d0293b3b1db27f175ea8e2bbecdb900e Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Mon, 30 Sep 2024 16:15:34 +0200 Subject: [PATCH 34/54] fix cv steps --- src/qml_benchmarks/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 3330ca9e..c3f1dc37 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -158,7 +158,7 @@ def mcmc_step(self, args, i): accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] x_new = accept * x_flip + (1 - accept) * x - return [params, key2, x_new], x + return [params, key2, x_new], x_new def mcmc_sample(self, params, x_init, num_mcmc_steps, key): """ @@ -208,7 +208,7 @@ def contrastive_divergence_loss(self, params, X, y, key): params_copy[key] = jax.lax.stop_gradient(params_copy[key]) configs = self.batched_mcmc_sample(params_copy, X, self.cdiv_steps, keys) - x0 = configs[:, 0] + x0 = X x1 = configs[:, -1] # taking the gradient of this loss is equivalent to the CD-k update From c4f6d766d71a4e72470db2a681687754b51b630b Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 3 Oct 2024 17:49:35 +0200 Subject: [PATCH 35/54] chunk sampling --- src/qml_benchmarks/models/base.py | 17 +++++++++++++---- src/qml_benchmarks/models/energy_based_model.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index c3f1dc37..597d5aa6 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -171,7 +171,7 @@ def mcmc_sample(self, params, x_init, num_mcmc_steps, key): def langevin_sample(self, params, x_init, n_samples, key): pass - def sample(self, num_samples, num_steps=1000): + def sample(self, num_samples, num_steps=1000, max_chunk_size=100): """ sample configurations starting from a random configuration. 
""" @@ -188,9 +188,18 @@ def sample(self, num_samples, num_steps=1000): ), dtype=int, ) - configs = self.batched_mcmc_sample(self.params_, x_init, num_steps, keys) - x1 = configs[:, -1] - return x1 + + # chunk the sampling, otherwise the vmap can blow the memory + num_chunks = num_steps//max_chunk_size + 1 + x_init = jnp.array_split(x_init, num_chunks) + keys = jnp.array_split(keys, num_chunks) + configs = [] + for elem in zip(x_init, keys): + new_configs = self.batched_mcmc_sample(self.params_, elem[0], num_steps, elem[1]) + configs.append(new_configs[:,-1]) + + configs = jnp.concatenate(configs) + return configs def contrastive_divergence_loss(self, params, X, y, key): """ diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 610879ad..6a09650d 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -81,7 +81,7 @@ def initialize(self, x): def energy(self, params, x): return self.model.apply(params, x) - def score(self, X: np.ndarray, y: np.ndarray) -> float: + def score(self, X: np.ndarray, y=None) -> float: sigma = self.mmd_kwargs['sigma'] sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'], From 6a6ca4dbd660fae43fc32c288a945c2f1b157bbd Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 9 Oct 2024 17:45:24 +0200 Subject: [PATCH 36/54] docstring update --- src/qml_benchmarks/models/base.py | 145 ++++++++---------- .../models/energy_based_model.py | 87 +++++++++-- 2 files changed, 139 insertions(+), 93 deletions(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 597d5aa6..0a9cc24f 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -34,12 +34,10 @@ class BaseGenerator(BaseEstimator): `RandomizedSearchCV` for hyperparameter tuning. Args: - dim: - The dimensionality of the samples, e.g., for spins it could be an - integer or a tuple specifying a grid. + dim (int): dimension of the data (i.e. the number of features) """ - def __init__(self, dim: int or tuple[int]) -> None: + def __init__(self, dim: int) -> None: self.dim = dim @abstractmethod @@ -48,9 +46,8 @@ def initialize(self, x: any = None): Initialize the model and create the model parameters. Args: - x: An example data or dimensionality of the model parameters. + x: batch of data to use to initialize the model """ - # self.dim = x.shape[1:] pass @abstractmethod @@ -65,9 +62,12 @@ def sample(self, num_samples: int) -> any: class EnergyBasedModel(BaseGenerator): - """ + r""" A base class for energy-based generative models with common functionalities. + The class implements MCMC sampling via the energy function. This is used to sample from the model and to train + the model via k-contrastive divergence (see eqn (3) of arXiv:2101.03288). + We use Scikit-learn's `BaseEstimator` so that we can take advantage of Scikit-learn's hyperparameter search algorithms like `GridSearchCV` or `RandomizedSearchCV` for hyperparameter tuning. @@ -77,9 +77,19 @@ class EnergyBasedModel(BaseGenerator): `BaseEstimator` documentation for more details. References: - Teh, Yee Whye, Max Welling, Simon Osindero, and Geoffrey E. Hinton. - "Energy-Based Models for Sparse Overcomplete Representations." - Journal of Machine Learning Research, vol. 4, 2003, pp. 1235-1260. + Yang Song, Diederik P. 
Kingma + "How to Train Your Energy-Based Models" + arXiv:2101.03288 + + Args: + dim (int): dimension of the data (i.e. number of features) + cdiv_steps (int): number of mcmc steps to perform to estimate the constrastive divergence loss (default 1) + convergence_interval (int): The number of loss values to consider to decide convergence. + max_steps (int): Maximum number of training steps. A warning will be raised if training did not converge. + learning_rate (float): Initial learning rate for training. + batch_size (int): Size of batches used for computing parameter updates. + jit (bool): Whether to use just in time compilation. + random_state (int): Seed used for pseudorandom number generation. """ def __init__( @@ -88,7 +98,7 @@ def __init__( learning_rate=0.001, batch_size=32, max_steps=10000, - cdiv_steps=100, + cdiv_steps=1, convergence_interval=None, random_state=42, jit=True, @@ -108,14 +118,12 @@ def __init__( self.params_: dict[str : jnp.array] = None self.dim = dim # initialized to None - # Train depended attributes that the function train in self.fit() sets. - # It is not the best practice to set attributes hidden in that function - # Since it is not clear unless someone knows what the train function does. - # Therefore we add it here for clarity. + # Train dependent attributes that the function train in self.fit() sets. self.history_: list[float] = None self.training_time_: float = None - self.mcmc_step = jax.jit(self.mcmc_step) + # jax transformations of class functions + self.mcmc_step = jax.jit(self.mcmc_step) if self.jit else self.mcmc_step self.batched_mcmc_sample = jax.vmap( self.mcmc_sample, in_axes=(None, 0, None, 0) ) @@ -126,20 +134,17 @@ def generate_key(self): @abstractmethod def energy(self, params: dict, x: any) -> float: """ - The energy function for the model for a given configuration x. - - This function should be implemented by the subclass and is also - responsible for initializing the parameters of the model, if necessary. + The energy function for the model for a batch of configurations x. + This function should be implemented by the subclass. Args: - x: The configuration to calculate the energy for. + params: model parameters that determine the energy function. + x: batch of configurations of shape (n_batch, dim) for which to calculate the energy Returns: - energy (float): The energy. + energy (Array): Array of energies of shape (n_batch,) """ pass - # TODO: this can be made even more efficient with Numpyro MCMC - # see qgml.mcmc for a simple example def mcmc_step(self, args, i): """ Perform one metropolis hastings steps. 
@@ -147,14 +152,16 @@ def mcmc_step(self, args, i): """ params, key, x = args key1, key2 = jax.random.split(key, 2) + + # flip a random bit flip_idx = jax.random.choice(key1, jnp.arange(self.dim)) flip_config = jnp.zeros(self.dim, dtype='int32') flip_config = flip_config.at[flip_idx].set(1) x_flip = jnp.array((x + flip_config) % 2) en = self.energy(params, jnp.expand_dims(x, 0))[0] - en_flip = self.energy(params, jnp.expand_dims(x_flip, 0))[0] + accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] x_new = accept * x_flip + (1 - accept) * x @@ -168,12 +175,15 @@ def mcmc_sample(self, params, x_init, num_mcmc_steps, key): carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(num_mcmc_steps)) return configs - def langevin_sample(self, params, x_init, n_samples, key): - pass - def sample(self, num_samples, num_steps=1000, max_chunk_size=100): """ - sample configurations starting from a random configuration. + Sample configurations starting from a random configuration. + Each sample is generated by sampling a random configuration and perforning a number of mcmc updates. + Args: + num_samples (int): number of samples to draw + num_steps (int): number of mcmc steps before drawing a sample + max_chunk_size (int): maximum number of samples to vmap the sampling for at a time (large values + use significant memory) """ if self.params_ is None: raise ValueError( @@ -203,11 +213,11 @@ def sample(self, num_samples, num_steps=1000, max_chunk_size=100): def contrastive_divergence_loss(self, params, X, y, key): """ - Contrastive divergence loss function. + Implementation of the standard contrastive divergence loss function (see eqn 3 of arXiv:2101.03288). Args: X (array): batch of training examples y (array): not used; should be set to None when training - key: jax PRNG key + key: JAX PRNG key used for MCMC sampling """ keys = jax.random.split(key, X.shape[0]) @@ -246,57 +256,30 @@ def fit(self, X: jnp.array, y: any = None) -> None: convergence_interval=self.convergence_interval, ) - def score(self, X, y=None): - """Score the model on the given data. - - Higher is better. + @abstractmethod + def score(self, X, y=None) -> any: """ - if self.params_ is None: - self.initialize(X.shape[1]) + Score function to be used with hyperparameter optimization (larger score => better) - c_div_loss = ( - jax.jit(self.contrastive_divergence_loss) - if self.jit - else self.contrastive_divergence_loss - ) - - return 1 - c_div_loss(self.params_, X, y, self.generate_key()) - - -class SimpleEnergyModel(EnergyBasedModel): - """A simple energy-based generative model. - - Example: - -------- - model = SimpleEnergyModel() - - # Generate random 2D data of 0, 1 - X = np.random.randint(0, 2, size=(100, 2)) - - # Initialize and calculate the energy of the model with the given data - model.initialize(X) - print(model.energy(model.params_, X)) - - # Fit the model to the data - model.fit(X) - - # Generate 100 samples from the model - samples = model.sample(100) - - # Score the model on the generated data - print(model.score(X)) - """ + Args: + X: Dataset to calculate score for + y: labels (set to None for generative models to interface with sklearn functionality) + """ + pass - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.params_ = None # Data-dependent attributes + # def score(self, X, y=None): + # """Score the model on the given data. + # + # Higher is better. 
+ # """ + # if self.params_ is None: + # self.initialize(X.shape[1]) + # + # c_div_loss = ( + # jax.jit(self.contrastive_divergence_loss) + # if self.jit + # else self.contrastive_divergence_loss + # ) + # + # return 1 - c_div_loss(self.params_, X, y, self.generate_key()) - def initialize(self, x: any = None): - key = self.generate_key() - self.dim = x.shape[1] - initializer = jax.nn.initializers.he_uniform() - self.params_ = {"weights": initializer(key, (x.shape[1], 1), jnp.float32)} - - def energy(self, params, x): - # Define the energy function here as the dot product of the parameters. - return jnp.dot(x, params["weights"]) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 6a09650d..8aa93526 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -21,7 +21,7 @@ class MLP(nn.Module): - "Multilayer perceptron." + "Multilayer perceptron implemented in flax" # Create a MLP with hidden layers and neurons specfied as a list of integers. hidden_layers: list[int] @@ -36,18 +36,31 @@ def __call__(self, x): class DeepEBM(EnergyBasedModel): """ - Energy-based model with the energy function is a neural network. + Energy-based model which uses a fully connected multi-layer perceptron neural network as its energy function. + The model is trained via k-contrastive divergence. + The score function corresponds to the (negative of the) maximum mean discrepancy distance. Args: + learning_rate (float): Initial learning rate for training. + batch_size (int): Size of batches used for computing parameter updates. + max_steps (int): Maximum number of training steps. A warning will be raised if training did not converge. + cdiv_steps (int): number of mcmc steps to perform to estimate the constrastive divergence loss (default 1) + convergence_interval (int): The number of loss values to consider to decide convergence. + jit (bool): Whether to use just in time compilation. + random_state (int): Seed used for pseudorandom number generation. hidden_layers (list[int]): - The number of hidden layers and neurons in the MLP layers. + The number of hidden layers and neurons in the MLP layers. e.g. [8,4] uses a three layer network where the + first layers maps to 8 neurons, the second to 4, and the last layer to 1 neuron. + mmd_kwargs (dict): arguments used for the maximum mean discrepancy score. n_samples and n_steps are the args + sent to self.sample when sampling configurations used for evaluation. sigma is the bandwidth of the + maximum mean discrepancy. """ def __init__(self, learning_rate=0.001, batch_size=32, max_steps=10000, - cdiv_steps=100, + cdiv_steps=1, convergence_interval=None, random_state=42, jit=True, @@ -79,9 +92,22 @@ def initialize(self, x): self.params_ = self.model.init(self.generate_key(), x) def energy(self, params, x): + """ + energy function + Args: + params: parameters of the neural network to be passed to flax + x: batch of configurations of shape (n_batch, dim) + returns: + batch of energy values + """ return self.model.apply(params, x) def score(self, X: np.ndarray, y=None) -> float: + """ + Maximum mean discrepancy score function + Args: + X (Array): batch of test samples to evalute the model against. 
+        """
         sigma = self.mmd_kwargs['sigma']
         sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma
         score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'],
                                                   self.mmd_kwargs['n_steps']), sigma) for sigma in sigmas])
         return float(-score)


 class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator):
+    """
+    Implementation of a restricted Boltzmann machine. The model wraps the scikit-learn BernoulliRBM class and is
+    trained via contrastive divergence.
+    Args:
+        n_components (int): Number of hidden units in the RBM
+        learning_rate (float): learning rate for training
+        batch_size (int): batch size for training
+        n_iter (int): number of epochs of training
+        verbose (int): verbosity level
+        random_state (int): random seed used for reproducibility
+        score_fn (str): determines the score function used in hyperparameter optimization. If 'pseudolikelihood',
+            sklearn's pseudolikelihood function is used; if 'mmd', the negative of the maximum mean discrepancy is used.
+        mmd_kwargs (dict): arguments used for the maximum mean discrepancy score. n_samples and n_steps are the args
+            sent to self.sample when sampling configurations used for evaluation. sigma is the bandwidth of the
+            maximum mean discrepancy.
+
+    """
     def __init__(
         self,
         n_components=256,
-        learning_rate=0.1,
+        learning_rate=0.0001,
         batch_size=10,
         n_iter=10,
         verbose=0,
-        random_state=None,
+        random_state=42,
         score_fn='pseudolikelihood',
         mmd_kwargs ={'n_samples': 1000, 'n_steps': 1000, 'sigma': 1.0}
     ):
         super().__init__(
@@ -111,6 +154,7 @@ def __init__(
         )
         self.score_fn = score_fn
         self.mmd_kwargs = mmd_kwargs
+        self.rng = np.random.default_rng(random_state)

     def initialize(self, X: any = None):
         if len(X.shape) > 2:
@@ -118,38 +162,57 @@ def initialize(self, X: any = None):
         self.dim = X.shape[1]

     def fit(self, X, y=None):
+        """
+        Fit the model using k-contrastive divergence. This simply wraps the sklearn fit function.
+        Args:
+            X (np.array): training data
+            y: not used; set to None to interface with sklearn correctly.
+        """
         self.initialize(X)
         super().fit(X, y)

     # Gibbs sampling:
     def _sample(self, num_steps=1000):
         """
-        Sample the model for given number of steps.
+        Sample the model for a given number of steps via the .gibbs method of sklearn's RBM. The initial configuration
+        is sampled randomly.
         Args:
             num_steps (int): Number of Gibbs sample steps

         Returns:
-            np.array: The samples at the given temperature.
+            np.array: The sampled configurations
         """
         if self.dim is None:
             raise ValueError("Model must be initialized before sampling")
-        v = np.random.choice(
+        v = self.rng.choice(
             [0, 1], size=(self.dim,)
         )  # Assuming `N` is `self.n_components`
         for _ in range(num_steps):
             v = self.gibbs(v)  # Assuming `gibbs` is an instance method
         return v

-    def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray:
-        # Parallelize the sampling process
+    def sample(self, num_samples: int, num_steps: int = 1000, n_jobs=-1) -> np.ndarray:
+        """
+        Sample the model. Each sample is generated by sampling a random configuration and performing a number of
+        Gibbs sampling steps. We use joblib to parallelize the sampling.
+        Args:
+            num_samples (int): number of samples to return
+            num_steps (int): number of Gibbs sampling steps for each sample
+            n_jobs (int): number of parallel jobs to be sent via joblib. By default, uses all available cores.
+ """ samples_t = Parallel(n_jobs=-1)( delayed(self._sample)(num_steps=num_steps) for _ in range(num_samples) ) samples_t = np.array(samples_t, dtype=int) return samples_t - def score(self, X: np.ndarray, y: np.ndarray) -> float: + def score(self, X: np.ndarray, y: np.ndarray=None) -> float: + """ + Score function for hyperparameter optimization. + Args: + X (Array): batch of test samples to evalute the model against. + """ if self.score_fn == 'pseudolikelihood': return float(np.mean(super().score_samples(X))) elif self.score_fn == 'mmd': From c81b0b1f20a9a4abd026db999b3b8b6b8df569ac Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 9 Oct 2024 18:45:33 +0200 Subject: [PATCH 37/54] add ising data --- src/qml_benchmarks/data/ising.py | 202 +++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 src/qml_benchmarks/data/ising.py diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py new file mode 100644 index 00000000..f6d191dc --- /dev/null +++ b/src/qml_benchmarks/data/ising.py @@ -0,0 +1,202 @@ +"""Ising spin simulation for a classical 2D Ising model +""" + +import jax +from jax import numpy as jnp +from numpy import ndarray +from numpyro.infer import MCMC +from jax import random +from collections import namedtuple +from numpyro.infer.mcmc import MCMCKernel +from qgml.data import SpinConfigurationGeneratorBase +from tqdm.auto import tqdm + +def create_isotropic_interaction_matrix(grid_size: int): + """Create an interaction matrix for a 2D isotropic square lattice.""" + J = jnp.zeros((grid_size * grid_size, grid_size * grid_size)) + + for i in range(grid_size): + for j in range(grid_size): + # Spin index in the grid + idx = i * grid_size + j + + # Calculate the indices of the neighbors + right_idx = i * grid_size + (j + 1) % grid_size + left_idx = i * grid_size + (j - 1) % grid_size + bottom_idx = ((i + 1) % grid_size) * grid_size + j + top_idx = ((i - 1) % grid_size) * grid_size + j + + # Set the interactions, ensuring each pair is only added once + J = J.at[idx, right_idx].set(1) + J = J.at[idx, left_idx].set(1) + J = J.at[idx, bottom_idx].set(1) + J = J.at[idx, top_idx].set(1) + return J + + +@jax.jit +def energy(s, J, b, J_sparse=None): + """Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of + J to speed up the calculation. + Args: + s: spin configuration + J: interaction matrix + b: bias term + J_sparse: list of nonzero indices of J. 
+ """ + if J_sparse is not None: + return -jnp.einsum( + "i,i,i->", s[J_sparse[0]], s[J_sparse[1]], J[J_sparse] + ) / 2.0 - jnp.dot(s, b) + else: + return -jnp.einsum("i,j,ij->", s, s, J) / 2.0 - jnp.dot(s, b) + + +def initialize_spins(rng_key, num_spins, num_chains): + if num_chains == 1: + spins = random.bernoulli(rng_key, 0.5, (num_spins,)) + else: + spins = random.bernoulli( + rng_key, + 0.5, + ( + num_chains, + num_spins, + ), + ) + return spins * 2 - 1 + + +MHState = namedtuple("MHState", ["spins", "rng_key"]) + +class MetropolisHastings(MCMCKernel): + """An implementation of MCMC using Numpyro, see example in + https://num.pyro.ai/en/stable/mcmc.html + """ + + sample_field = "spins" + + def init( + self, + rng_key: random.PRNGKey, + num_warmup: int, + init_params: jnp.array, + *args, + **kwargs + ): + """Initialize the state of the model.""" + return MHState(init_params, rng_key) + + def sample(self, state: jnp.array, model_args, model_kwargs): + """Sample from the model via Metropolis Hastings MCMC""" + spins, rng_key = state + num_spins = spins.size + + def mh_step(i, val): + spins, rng_key = val + rng_key, subkey = random.split(rng_key) + flip_index = random.randint(subkey, (), 0, num_spins) + spins_proposal = spins.at[flip_index].set(-spins[flip_index]) + + current_energy = energy( + spins, model_kwargs["J"], model_kwargs["b"], model_kwargs["J_sparse"] + ) + proposed_energy = energy( + spins_proposal, + model_kwargs["J"], + model_kwargs["b"], + model_kwargs["J_sparse"], + ) + delta_energy = proposed_energy - current_energy + accept_prob = jnp.exp(-delta_energy / model_kwargs["T"]) + + rng_key, subkey = random.split(rng_key) + accept = random.uniform(subkey) < accept_prob + spins = jnp.where(accept, spins_proposal, spins) + return spins, rng_key + + spins, rng_key = jax.lax.fori_loop(0, num_spins, mh_step, (spins, rng_key)) + return MHState(spins, rng_key) + + +# Define the Ising model class +class IsingSpins(SpinConfigurationGeneratorBase): + """ + class object used to generate datasets + ArgsL + N (int): Number of spins + J (np.array): interaction matrix + b (np.array): bias terms + T (float): temperature + sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians) + compute_partition_fn: Whether to compute the partition function + """ + def __init__( + self, N: int, J: jnp.array, b: jnp.array, T: float, sparse=False, compute_partition_fn=False + ) -> None: + super().__init__(N) + self.kernel = MetropolisHastings() + self.J = J + self.T = T + self.b = b + self.J_sparse = jnp.nonzero(J) if sparse else None + + if compute_partition_fn: + Z = 0 + for i in tqdm(range(2**self.N), desc="Computing partition function"): + lattice = (-1) ** jnp.array(jnp.unravel_index(i, [2] * self.N)) + en = energy(lattice, self.J, self.b, self.J_sparse) + Z += jnp.exp(-en / T) + self.Z = Z + + def sample( + self, num_samples: int, num_chains=1, thinning=1, num_warmup=1000, key=42 + ) -> jnp.array: + + """ + Generate samples. + Args: + num_samples (int): total number of samples to generate per chain + num_chains (int): number of chains + thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn ater each + 10 steps of mcmc sampling. Larger numbers result in more unbiased samples. + num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples. + key (int): random seed used to initialize sampling. 
+ """ + rng_key = random.PRNGKey(key) + init_spins = initialize_spins(rng_key, self.N, num_chains) + mcmc = MCMC( + self.kernel, + num_warmup=num_warmup, + thinning=thinning, + num_samples=num_samples, + num_chains=num_chains, + ) + mcmc.run( + rng_key, + init_params=init_spins, + J=self.J, + b=self.b, + T=self.T, + J_sparse=self.J_sparse, + ) + samples = mcmc.get_samples() + return samples.reshape((-1, self.N)) + + def probability(self, spin_configuration: ndarray) -> float: + return ( + jnp.exp(-energy(spin_configuration, self.J, self.b, self.J_sparse) / self.T) + / self.Z + ) + +def generate_isometric_ising( + num_samples: int = 100, T: float = 2.5, grid_size: int = 4 +) -> (ndarray, None): + num_spins = grid_size * grid_size + num_chains = 2 + num_steps = 1000 + J = create_isotropic_interaction_matrix(grid_size) + model = IsingSpins(num_spins, J, b=1.0, T=T) + # Plot the magnetization and energy trajectories for a single T + samples = model.sample(num_samples*num_steps, num_chains=num_chains, num_warmup=10000, key=0) + return samples[-num_samples:], None From a43c901b03ee4520763e24094430bb3fe62ef2bc Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 10 Oct 2024 16:17:28 +0200 Subject: [PATCH 38/54] clean up --- requirements.txt | 1 + src/qml_benchmarks/data/ising.py | 124 ++++++++++++++++---------- src/qml_benchmarks/data/spin_blobs.py | 75 ++++++++++++---- 3 files changed, 139 insertions(+), 61 deletions(-) diff --git a/requirements.txt b/requirements.txt index 9cc6c785..cdc7791a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyyaml~=6.0 pennyLane~=0.34 scipy~=1.11 pandas~=2.2 +numpyro~=0.14.0 \ No newline at end of file diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py index f6d191dc..e35818ca 100644 --- a/src/qml_benchmarks/data/ising.py +++ b/src/qml_benchmarks/data/ising.py @@ -8,32 +8,8 @@ from jax import random from collections import namedtuple from numpyro.infer.mcmc import MCMCKernel -from qgml.data import SpinConfigurationGeneratorBase from tqdm.auto import tqdm -def create_isotropic_interaction_matrix(grid_size: int): - """Create an interaction matrix for a 2D isotropic square lattice.""" - J = jnp.zeros((grid_size * grid_size, grid_size * grid_size)) - - for i in range(grid_size): - for j in range(grid_size): - # Spin index in the grid - idx = i * grid_size + j - - # Calculate the indices of the neighbors - right_idx = i * grid_size + (j + 1) % grid_size - left_idx = i * grid_size + (j - 1) % grid_size - bottom_idx = ((i + 1) % grid_size) * grid_size + j - top_idx = ((i - 1) % grid_size) * grid_size + j - - # Set the interactions, ensuring each pair is only added once - J = J.at[idx, right_idx].set(1) - J = J.at[idx, left_idx].set(1) - J = J.at[idx, bottom_idx].set(1) - J = J.at[idx, top_idx].set(1) - return J - - @jax.jit def energy(s, J, b, J_sparse=None): """Calculate the Ising energy. 
For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of @@ -51,7 +27,6 @@ def energy(s, J, b, J_sparse=None): else: return -jnp.einsum("i,j,ij->", s, s, J) / 2.0 - jnp.dot(s, b) - def initialize_spins(rng_key, num_spins, num_chains): if num_chains == 1: spins = random.bernoulli(rng_key, 0.5, (num_spins,)) @@ -119,11 +94,19 @@ def mh_step(i, val): return MHState(spins, rng_key) -# Define the Ising model class -class IsingSpins(SpinConfigurationGeneratorBase): - """ - class object used to generate datasets - ArgsL +class IsingSpins: + r""" + class object used to generate datasets by sampling an ising distrbution of a specified interaction + matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings + algorithm. + + In the case of perfect sampling, a spin configuration s is sampled with probabability + :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i` + corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued. + + The final sampled configurations are converted from a :math:`\pm1` representation to to a binary + representation via x = (s+1)//2. + N (int): Number of spins J (np.array): interaction matrix b (np.array): bias terms @@ -134,14 +117,15 @@ class object used to generate datasets def __init__( self, N: int, J: jnp.array, b: jnp.array, T: float, sparse=False, compute_partition_fn=False ) -> None: - super().__init__(N) + + self.N = N self.kernel = MetropolisHastings() self.J = J self.T = T self.b = b self.J_sparse = jnp.nonzero(J) if sparse else None - if compute_partition_fn: + if compute_partition_fn: Z = 0 for i in tqdm(range(2**self.N), desc="Computing partition function"): lattice = (-1) ** jnp.array(jnp.unravel_index(i, [2] * self.N)) @@ -181,22 +165,70 @@ def sample( J_sparse=self.J_sparse, ) samples = mcmc.get_samples() - return samples.reshape((-1, self.N)) + samples.reshape((-1, self.N)) + return (samples+1)//2 + + def probability(self, x: ndarray) -> float: + """ + compute the probability of a binary configuration x + Args: + x: binary configuration array + Returns: + (float): the probability of sampling x according to the ising distribution + """ + + if not(hasattr(self, 'Z')): + raise Exception('probability requires partition fuction to have been computed') - def probability(self, spin_configuration: ndarray) -> float: return ( - jnp.exp(-energy(spin_configuration, self.J, self.b, self.J_sparse) / self.T) + jnp.exp(-energy(x, self.J, self.b, self.J_sparse) / self.T) / self.Z ) -def generate_isometric_ising( - num_samples: int = 100, T: float = 2.5, grid_size: int = 4 -) -> (ndarray, None): - num_spins = grid_size * grid_size - num_chains = 2 - num_steps = 1000 - J = create_isotropic_interaction_matrix(grid_size) - model = IsingSpins(num_spins, J, b=1.0, T=T) - # Plot the magnetization and energy trajectories for a single T - samples = model.sample(num_samples*num_steps, num_chains=num_chains, num_warmup=10000, key=0) - return samples[-num_samples:], None +def generate_ising(N: int, + num_samples: int, + J: jnp.array, + b: jnp.array, + T: float, + sparse=False, + num_chains=1, + thinning=1, + num_warmup=1000, + key=42): + r""" + Generating function for ising datasets. + + The dataset is generated by sampling an ising distrbution of a specified interaction + matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings + algorithm. 
+
+    In the case of perfect sampling, a spin configuration s is sampled with probability
+    :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
+    corresponds to an Ising Hamiltonian and configurations s are :math:`\pm1` valued.
+
+    The final sampled configurations are converted from a :math:`\pm1` representation to a binary
+    representation via x = (s+1)//2.
+
+    Note that in order to use parallelization, the number of available cores has to be specified explicitly
+    to numpyro, i.e. the line `numpyro.set_host_device_count(num_cores)` should appear before running the
+    generator, where num_cores is the number of available CPU cores you want to use.
+
+    N (int): Number of spins
+    num_samples (int): total number of samples to generate per chain
+    J (np.array): interaction matrix of shape (N,N)
+    b (np.array): bias array of shape (N,)
+    T (float): temperature
+    num_chains (int): number of chains, defaults to 1.
+    thinning (int): how much to thin the sampling. e.g. if thinning = 10, a sample will be drawn after every
+        10 steps of mcmc sampling. Larger numbers result in more unbiased samples.
+    num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples.
+    key (int): random seed used to initialize sampling.
+    sparse (bool): If True, J is converted to a sparse representation (faster for sparse Hamiltonians)
+
+    Returns:
+        Array of data samples, and a NoneType object (since there are no labels)
+    """
+
+    sampler = IsingSpins(N, J, b, T, sparse=sparse, compute_partition_fn=False)
+    samples = sampler.sample(num_samples, num_chains=num_chains, thinning=thinning, num_warmup=num_warmup, key=key)
+    return samples, None
diff --git a/src/qml_benchmarks/data/spin_blobs.py b/src/qml_benchmarks/data/spin_blobs.py
index d5a6842a..b6c36faa 100644
--- a/src/qml_benchmarks/data/spin_blobs.py
+++ b/src/qml_benchmarks/data/spin_blobs.py
@@ -16,27 +16,23 @@
 import numpy as np

-
 class RandomSpinBlobs:
-    """Generate spin configurations with high probabilites for certain spins.
-
-    The dataset is generated by creating random spin samples close to a few
-    chosen `peak_spin` configurations of dimension `N` with each spin having
-    the possible values 0 or 1. We can vary the `peak_probabilities` parameter
-    to create data with different modes, where some samples will have higher
-    probabilities allowing us to study the effects of imbalance in the data.
+    """
+    Class object used to generate spin blob datasets: a binary analog of the
+    'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming
+    distance to a set of specified configurations.

-    Samples are generated by selecting one of the peak spin configurations
-    distributed according `peak_probabilities`, and then by flipping some of the
-    spins. The number of spins that are flipped each time, is drawn from a
-    Binomial distribution bin(`N`, `p`) where `p=1` will flip all the spins
-    and `p=0` will not flip any spins therefore creating very narrow distributions
-    around the peak spins.
+    The dataset is generated by specifying a list of configurations (peak spins)
+    that mark the centre of the 'blobs'. Data points are sampled by choosing one of
+    the peak spins (with probabilities specified by peak_probabilities), and then
+    flipping some of the bits. Each bit is flipped with probability specified by
+    p, so that (for small p) datapoints are close in Hamming distance to one of
+    the peak spins.

     Args:
        N (int): The number of spins.
num_blobs (int): - The number of blobs or peak probabilities. + The number of blobs. peak_probabilities (list[float], optional): The probability of each spin to be selected. If not specified, the probabilities are distributed uniformly. @@ -56,6 +52,7 @@ def __init__( peak_spins: list[np.array] = None, p: float = 0.01, ) -> None: + self.N = N self.num_blobs = num_blobs @@ -122,6 +119,54 @@ def sample(self, num_samples: int, return_labels=False) -> np.array: else: return samples +def generate_spin_blobs(N: int, num_blobs: int, num_samples:int, peak_probabilities: list[float] = None, peak_spins: list[np.array] = None, + p: float = 0.01): + + """ + Generator function for spin blob datasets: a binary analog of the + 'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming + distance to a set of specified configurations. + + The dataset is generated by specifying a list of configurations (peak spins) + that mark the centre of the 'blobs'. Data points are sampled by chosing one of + the peak spins (with probabilities specified by peak probabilities), and then + flipping some of the bits. Each bit is flipped with probability specified by + p, so that (for small p) datapoints are close in Hamming distance to one of + the peak probabilities. + + Args: + N (int): The number of spins. + num_blobs (int): + The number of blobs. + num_samples (int): The number of samples to generate. + peak_probabilities (list[float], optional): + The probability of each spin to be selected. If not specified, + the probabilities are distributed uniformly. + peak_spins (list[np.array], optional): + The peak spin configurations. Selected randomly by default. + p (float, optional): + The value of the parameter `p` in a Binomial distribution specifying + the number of spins that are flipped each time during sampling. + Defaults to 0.01. + + Returns: + tuple(np.ndarray): Dataset array and label array specifying the peak spin + that was used to sample each datapoint. + """ + + sampler = RandomSpinBlobs( + N=N, + num_blobs=num_blobs, + peak_probabilities=peak_probabilities, + peak_spins=peak_spins, + p=p, + ) + + X, y = sampler.sample(num_samples=num_samples, return_labels=True) + X = X.reshape(-1, N) + + return X, y + def generate_8blobs( num_samples: int, From fe511fb0e0b9ba6b8e3bdf014b61616600ffd161 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 10 Oct 2024 17:06:42 +0200 Subject: [PATCH 39/54] update --- README.md | 97 ++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 85 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index b18e4ddf..2793ff41 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Benchmarking for quantum machine learning models This repository contains tools to compare the performance of near-term quantum machine learning (QML) -as well as standard classical machine learning models on supervised learning tasks. +as well as standard classical machine learning models on supervised and generative learning tasks. It is based on pipelines using [Pennylane](https://pennylane.ai/) for the simulation of quantum circuits, [JAX](https://jax.readthedocs.io/en/latest/index.html) for training, @@ -61,18 +61,23 @@ class MyModel(BaseEstimator, ClassifierMixin): # reproducibility is ensured by creating a numpy PRNG and using it for all # subsequent random functions. 
- self._random_state = random_state - self._rng = np.random.default_rng(random_state) + self.random_state = random_state + self.rng = np.random.default_rng(random_state) # define data-dependent attributes self.params_ = None self.n_qubits_ = None + + def initialize(self, args): + """ + initialize the model if necessary + """ + # ... your code here ... def fit(self, X, y): """Fit the model to data X and labels y. Add your custom training loop here and store the trained model parameters in `self.params_`. - Set the data-dependent attributes, such as `self.n_qubits_`. Args: X (array_like): Data of shape (n_samples, n_features) @@ -146,19 +151,86 @@ model.fit(X_train, y_train) print(model.score(X_test, y_test)) ``` + ## Adding a custom generative model -TO DO +The minimal template for a new generative model closely follows that of the classifier models. +Labels are set to `None` throughout to maintain sci-kit learn functionality. + +```python +import numpy as np + +from sklearn.base import BaseEstimator + + +class MyModel(BaseEstimator): + def __init__(self, hyperparam1="some_value", random_state=42): + + # store hyperparameters as attributes + self.hyperparam1 = hyperparam1 + + # reproducibility is ensured by creating a numpy PRNG and using it for all + # subsequent random functions. + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + + # define data-dependent attributes + self.params_ = None + self.n_qubits_ = None + + def initialize(self, args): + """ + initialize the model if necessary + """ + # ... your code here ... + + def fit(self, X, y=None): + """Fit the model to data X. + + Add your custom training loop here and store the trained model parameters in `self.params_`. + + Args: + X (array_like): Data of shape (n_samples, n_features) + y (array_like): not used (no labels) + """ + # ... your code here ... + + def sample(self, num_samples): + """sample from the generative model + + Args: + num_samples (int): number of points to sample + + Returns: + array_like: sampled points + """ + # ... your code here ... + + return samples + + def score(self, X, y=None): + """A optional custom score function to be used with hyperparameter optimization + Args: + X (array_like): Data of shape (n_samples, n_features) + y: unused (no labels for generative models) + + Returns: + (float): score for the dataset X + """ + # ... your code here ... + return score +``` -mention: -- need score such that greater is better -- data should be 0/1 valued -- inheritance etc +If the model samples binary data, it is recommended to construct models that sample binary strings (rather than $\pm1$ valued strings) +to align with the datasets designed for generative models. The repository currently contains two classical generative models: +a restricted Boltzmann machine and a simple energy based model (called DeepEBM) that uses a multi-layer perceptron as its energy function. +Energy based models with more structure can easily be constructed by replacing the multilayer perception neural network by +any other differentiable network written in flax. ## Datasets The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification and -generative learning. +generative learning. A generating function can be used like this: @@ -183,7 +255,8 @@ This will create a new folder in `paper/benchmarks` containing the datasets. 
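A minimal sketch of the corresponding workflow for the generative datasets added in this series (assuming `generate_spin_blobs` is imported directly from `qml_benchmarks.data.spin_blobs`, and that the expected CSV layout is one binary sample per row; the file name below is illustrative only):

```python
# Minimal sketch: generate a small spin-blob dataset and save it as CSV,
# one binary sample per row.
import numpy as np
from qml_benchmarks.data.spin_blobs import generate_spin_blobs

# 16 spins, 4 blobs, 1000 samples; p controls how tightly the samples
# cluster around the peak spins.
X, y = generate_spin_blobs(N=16, num_blobs=4, num_samples=1000, p=0.05)
np.savetxt("spin_blobs_train.csv", X, delimiter=",")
```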
## Running hyperparameter optimization In the folder `scripts` we provide an example that can be used to -generate results for a hyperparameter search for any model and dataset. The script +generate results for a hyperparameter search for any model and dataset. The script functions +for both classifier and generative models. The script can be run as ``` @@ -191,7 +264,7 @@ python run_hyperparameter_search.py --model "DataReuploadingClassifier" --datase ``` where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should -correspond to an feature and the last column to the target. For generative learning, each row +correspond to a feature and the last column to the target. For generative learning, each row should correspond to a binary string that specifies a unique data sample. Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`. From 5ebdee596b76d58821060efb4c3aa2fbc66f4193 Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:10:39 +0200 Subject: [PATCH 40/54] Delete generative_models/benchmarks directory --- generative_models/benchmarks/README.md | 65 ------ generative_models/benchmarks/generate_data.py | 31 --- .../benchmarks/run_hyperparameter_search.py | 214 ------------------ 3 files changed, 310 deletions(-) delete mode 100644 generative_models/benchmarks/README.md delete mode 100644 generative_models/benchmarks/generate_data.py delete mode 100644 generative_models/benchmarks/run_hyperparameter_search.py diff --git a/generative_models/benchmarks/README.md b/generative_models/benchmarks/README.md deleted file mode 100644 index 040529af..00000000 --- a/generative_models/benchmarks/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# Benchmarking for generative model - -The scripts in this package can help set up experiments to evaluate generative -models on custom datasets. The models and datasets are defined in the main -package. - -## Datasets - -The `qml_benchmarks.data` module provides generating functions to create datasets -A generating function can be used like this: - -```python -from qml_benchmarks.data import generate_8blobs -X, y = generate_8blobs(n_samples=200) -``` - -The scipt in this folder will generate a simple spin blob dataset. - -## Running hyperparameter optimization - -A hyperparameter search for any model and dataset can be run with the script -in this folder as: - -``` -python run_hyperparameter_search.py --model-name "RBM" --dataset-path "spin_blobs/8blobs_train.csv" -``` - -where `spin_blobs/8blobs_train.csv` is a CSV file containing the training data -such that each column is a feature. - -Unless otherwise specified, the hyperparameter grid is loaded from -`qml_benchmarks/hyperparameter_settings.py`. One can override the default -grid of hyperparameters by specifying the hyperparameter list, -where the datatype is inferred from the default values. -For example, for the `RBM` we can run: - -``` -python run_hyperparameter_search.py \ - --model-name RBM \ - --dataset-path "spin_blobs/8blobs_train.csv" \ - --learning_rate 0.1 0.01 \ - --clean True -``` - -which runs a search for the grid: - -``` -{'learning_rate': [0.1, 0.01], } -``` - -The script creates two CSV files that contains the detailed results of hyperparameter search and the best -hyperparameters obtained in the search. These files are similar to the ones stored in the `paper/results` -folder. 
- -The best hyperparameters can be loaded into a model and used to score the classifier. - -You can check the various options for the script using: - -``` -python run_hyperparameter_search --help -``` - -## Feedback - -Please help us improve this repository and report problems by opening an issue or pull request. diff --git a/generative_models/benchmarks/generate_data.py b/generative_models/benchmarks/generate_data.py deleted file mode 100644 index 9f34465f..00000000 --- a/generative_models/benchmarks/generate_data.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generate 8blobs dataset.""" - -import os -import numpy as np -from qml_benchmarks.data import generate_8blobs - - -if __name__ == "__main__": - os.makedirs("spin_blobs", exist_ok=True) - path_train = "spin_blobs/8blobs_train.csv" - path_test = "spin_blobs/8blobs_test.csv" - - X, y = generate_8blobs(num_samples=5000) - np.savetxt(path_train, X, delimiter=",") - - X, y = generate_8blobs(num_samples=1000) - np.savetxt(path_test, X, delimiter=",") diff --git a/generative_models/benchmarks/run_hyperparameter_search.py b/generative_models/benchmarks/run_hyperparameter_search.py deleted file mode 100644 index c01f096a..00000000 --- a/generative_models/benchmarks/run_hyperparameter_search.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Run hyperparameter search and store results with a command-line script.""" - -import numpy as np -import sys -import os -import time -import argparse -import logging -logging.getLogger().setLevel(logging.INFO) -from importlib import import_module -import pandas as pd -from pathlib import Path -import matplotlib.pyplot as plt -from sklearn.model_selection import GridSearchCV -from qml_benchmarks.hyperparam_search_utils import read_data, construct_hyperparameter_grid -from qml_benchmarks.hyperparameter_settings import hyper_parameter_settings - -np.random.seed(42) - -logging.info('cpu count:' + str(os.cpu_count())) - - -if __name__ == "__main__": - # Create an argument parser - parser = argparse.ArgumentParser(description="Run experiments with hyperparameter search.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument( - "--model-name", - help="Model to run", - ) - - parser.add_argument( - "--dataset-path", - help="Path to the dataset", - ) - - parser.add_argument( - "--results-path", default=".", help="Path to store the experiment results" - ) - - parser.add_argument( - "--clean", - help="True or False. Remove previous results if it exists", - dest="clean", - default=False, - type=bool, - ) - - parser.add_argument( - "--hyperparameter-scoring", - type=list, - nargs="+", - default=["accuracy", "roc_auc"], - help="Scoring for hyperparameter search.", - ) - - parser.add_argument( - "--hyperparameter-refit", - type=str, - default="accuracy", - help="Refit scoring for hyperparameter search.", - ) - - parser.add_argument( - "--plot-loss", - help="True or False. Plot loss history for single fit", - dest="plot_loss", - default=False, - type=bool, - ) - - parser.add_argument( - "--n-jobs", type=int, default=-1, help="Number of parallel threads to run" - ) - - # Parse the arguments along with any extra arguments that might be model specific - args, unknown_args = parser.parse_known_args() - - if any(arg is None for arg in [args.model_name, - args.dataset_path]): - msg = "\n================================================================================" - msg += "\nA model from qml.benchmarks.model and dataset path are required. 
E.g., \n \n" - msg += "python run_hyperparameter_search \ \n--model RBM \ \n--dataset-path train.csv\n" - msg += "\nCheck all arguments for the script with \n" - msg += "python run_hyperparameter_search --help\n" - msg += "================================================================================" - raise ValueError(msg) - - # Add model specific arguments to override the default hyperparameter grid - hyperparam_grid = construct_hyperparameter_grid( - hyper_parameter_settings, args.model_name - ) - for hyperparam in hyperparam_grid: - hp_type = type(hyperparam_grid[hyperparam][0]) - parser.add_argument(f'--{hyperparam}', - type=hp_type, - nargs="+", - default=hyperparam_grid[hyperparam], - help=f'{hyperparam} grid values for {args.model_name}') - - args = parser.parse_args(unknown_args, namespace=args) - - for hyperparam in hyperparam_grid: - override = getattr(args, hyperparam) - if override is not None: - hyperparam_grid[hyperparam] = override - logging.info( - "Running hyperparameter search experiment with the following settings\n" - ) - logging.info(args.model_name) - logging.info(args.dataset_path) - logging.info(" ".join(args.hyperparameter_scoring)) - logging.info(args.hyperparameter_refit) - logging.info("Hyperparam grid:"+" ".join([(str(key)+str(":")+str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) - - experiment_path = args.results_path - results_path = os.path.join(experiment_path, "results") - - if not os.path.exists(results_path): - os.makedirs(results_path) - - ################################################################### - # Get the model, dataset and search methods from the arguments - ################################################################### - model = getattr( - import_module("qml_benchmarks.models"), - args.model_name - ) - model_name = model.__name__ - - # Run the experiments save the results - train_dataset_filename = os.path.join(args.dataset_path) - X, y = read_data(train_dataset_filename) - - dataset_path_obj = Path(args.dataset_path) - results_filename_stem = " ".join( - [model.__name__ + "_" + dataset_path_obj.stem - + "_GridSearchCV"]) - - # If we have already run this experiment then continue - if os.path.isfile(os.path.join(results_path, results_filename_stem + ".csv")): - if args.clean is False: - msg = "\n=================================================================================" - msg += "\nResults exist in " + os.path.join(results_path, results_filename_stem + ".csv") - msg += "\nSpecify --clean True to override results or new --results-path" - msg += "\n=================================================================================" - logging.warning(msg) - sys.exit(msg) - else: - logging.warning("Cleaning existing results for ", os.path.join(results_path, results_filename_stem + ".csv")) - - - ########################################################################### - # Single fit to check everything works - ########################################################################### - model = model() - a = time.time() - model.fit(X, y) - b = time.time() - acc_train = model.score(X, y) - logging.info(" ".join( - [model_name, - "Dataset path", - args.dataset_path, - "Train acc:", - str(acc_train), - "Time single run", - str(b - a)]) - ) - if hasattr(model, "loss_history_"): - if args.plot_loss: - plt.plot(model.loss_history_) - plt.xlabel("Iterations") - plt.ylabel("Loss") - plt.show() - - if hasattr(model, "n_qubits_"): - logging.info(" ".join(["Num qubits", f"{model.n_qubits_}"])) - - 
########################################################################### - # Hyperparameter search - ########################################################################### - gs = GridSearchCV(estimator=model, param_grid=hyperparam_grid, - refit=args.hyperparameter_refit, - verbose=3, - n_jobs=-1).fit( - X, y - ) - logging.info("Best hyperparams") - logging.info(gs.best_params_) - - df = pd.DataFrame.from_dict(gs.cv_results_) - df.to_csv(os.path.join(results_path, results_filename_stem + ".csv")) - - best_df = pd.DataFrame(list(gs.best_params_.items()), columns=['hyperparameter', 'best_value']) - - # Save best hyperparameters to a CSV file - best_df.to_csv(os.path.join(results_path, - results_filename_stem + '-best-hyperparameters.csv'), index=False) \ No newline at end of file From da5d68170c9938daf0a92bb099a5cbcee60f8df4 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 10 Oct 2024 17:28:38 +0200 Subject: [PATCH 41/54] update --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2793ff41..1610ae5c 100644 --- a/README.md +++ b/README.md @@ -265,7 +265,8 @@ python run_hyperparameter_search.py --model "DataReuploadingClassifier" --datase where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should correspond to a feature and the last column to the target. For generative learning, each row -should correspond to a binary string that specifies a unique data sample. +should correspond to a binary string that specifies a unique data sample, and the model should implement a `score` +method. Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`. One can override the default grid of hyperparameters by specifying the hyperparameter list, From 0b7d23978c51f5251acee6cc0c4290bd4deb1084 Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Fri, 11 Oct 2024 10:27:17 +0200 Subject: [PATCH 42/54] Delete src/qml_benchmarks/models/restricted_boltzmann_machine.py --- .../models/restricted_boltzmann_machine.py | 153 ------------------ 1 file changed, 153 deletions(-) delete mode 100644 src/qml_benchmarks/models/restricted_boltzmann_machine.py diff --git a/src/qml_benchmarks/models/restricted_boltzmann_machine.py b/src/qml_benchmarks/models/restricted_boltzmann_machine.py deleted file mode 100644 index 6e12c08a..00000000 --- a/src/qml_benchmarks/models/restricted_boltzmann_machine.py +++ /dev/null @@ -1,153 +0,0 @@ -import numpy as np -import jax -import jax.numpy as jnp -from qml_benchmarks.model_utils import train -import optax -import copy - -class RestrictedBoltzmannMachineOld(): - """ - A restricted Boltzmann machine generative model. The model is trained with the k-contrastive divergence (CD-k) - algorithm. - Args: - n_hidden (int): The number of hidden neurons - learning_rate (float): The learning rate for the CD-k updates - cdiv_steps (int): The number of gibbs sampling steps used in contrastive divergence - jit (bool): Whether to use just-in-time complilation - batch_size (int): Size of batches used for computing parameter updates - max_steps (int): Maximum number of training steps. - reg (float): The L2 regularisation strength (larger implies stronger) - convergence_interval (int or None): The number of loss values to consider to decide convergence. - If None, training runs until the maximum number of steps. - random_state (int): Seed used for pseudorandom number generation. 
- - """ - - def __init__(self, n_hidden, learning_rate=0.001, cdiv_steps=1, jit=True, batch_size=32, - max_steps=200, reg=0.0, convergence_interval=200, random_state=42): - - self.n_hidden = n_hidden - self.learning_rate = learning_rate - self.random_state = random_state - self.rng = np.random.default_rng(random_state) - self.jit = jit - self.batch_size = batch_size - self.max_steps = max_steps - self.reg = reg - self.convergence_interval = convergence_interval - self.cdiv_steps = cdiv_steps - self.vmap = True - self.max_vmap = None - - # data depended attributes - self.params_ = None - self.n_visible_ = None - - self.gibbs_step = jax.jit(self.gibbs_step) if jit else self.gibbs_step - - def generate_key(self): - return jax.random.PRNGKey(self.rng.integers(1000000)) - - def energy(self, params, x, h): - """ - The RBM energy function - Args: - params: parameter dictionay of weights and biases - x: visible configuration - h: hidden configuration - Returns: - energy (float): The energy - """ - return -x.T @ params['W'] @ h - params['a'].T @ x - params['b'].T @ h - - def initialize(self, n_features): - self.n_visible_ = n_features - W = jax.random.normal(self.generate_key(), shape=(self.n_visible_, self.n_hidden)) / jnp.sqrt(self.n_visible_) - a = jax.random.normal(self.generate_key(), shape=(self.n_visible_,)) / jnp.sqrt(self.n_visible_) - b = jax.random.normal(self.generate_key(), shape=(self.n_hidden,)) / jnp.sqrt(self.n_visible_) - self.params_ = {'a': a, 'b': b, 'W': W} - - def gibbs_step(self, args, i): - """ - Perform one Gibbs steps. The format is such that it can be used with jax.lax.scan for fast compilation. - """ - params = args[0] - key = args[1] - x = args[2] - key1, key2, key3 = jax.random.split(key, 3) - # get hidden units probs - prob_h = jax.nn.sigmoid(x.T @ params['W'] + params['b']) - h = jnp.array(jax.random.bernoulli(key1, p=prob_h), dtype=int) - # get visible units probs - prob_x = jax.nn.sigmoid(params['W'] @ h + params['a']) - x_new = jnp.array(jax.random.bernoulli(key2, p=prob_x), dtype=int) - return [params, key3, x_new], [x, h] - - def gibbs_sample(self, params, x_init, n_samples, key): - """ - Sample a chain of visible and hidden configurations from a starting visible configuration x_init - """ - carry = [params, key, x_init] - carry, configs = jax.lax.scan(self.gibbs_step, carry, jnp.arange(n_samples)) - return configs - - def sample(self, n_samples): - """ - sample only the visible units starting from a random configuration. 
- """ - key = self.generate_key() - x_init = jnp.array(jax.random.bernoulli(key, p=0.5, shape=(self.n_visible_,)), dtype=int) - samples = self.gibbs_sample(self.params_, x_init, n_samples, self.generate_key()) - return jnp.array(samples[0]) - - def fit(self, X): - """ - Fit the parameters using contrastive divergence - """ - self.initialize(X.shape[-1]) - X = jnp.array(X, dtype=int) - - # batch the relevant functions - batched_gibbs_sample = jax.vmap(self.gibbs_sample, in_axes=(None, 0, None, 0)) - batched_energy = jax.vmap(self.energy, in_axes=(None, 0, 0)) - - def c_div_loss(params, X, y, key): - """ - contrastive divergence loss - Args: - params (dict): parameter dictionary - X (array): batch of training examples - y (array): not used; should be set to None when training - key: jax PRNG key - """ - keys = jax.random.split(key, X.shape[0]) - - # we do not take the gradient wrt the sampling, so decouple the param dict here - params_copy = copy.deepcopy(params) - for key in params_copy.keys(): - params_copy[key] = jax.lax.stop_gradient(params_copy[key]) - - configs = batched_gibbs_sample(params_copy, X, self.cdiv_steps + 1, keys) - x0 = configs[0][:, 0, :] - h0 = configs[1][:, 0, :] - x1 = configs[0][:, -1, :] - h1 = configs[1][:, -1, :] - - # taking the gradient of this loss is equivalent to the CD-k update - loss = batched_energy(params, x0, h0) - batched_energy(params, x1, h1) - - return jnp.mean(loss) + self.reg * jnp.sqrt(jnp.sum(params['W'] ** 2)) - - c_div_loss = jax.jit(c_div_loss) if self.jit else c_div_loss - - self.params_ = train(self, c_div_loss, optax.sgd, X, None, self.generate_key, - convergence_interval=self.convergence_interval) - - - - - - - - - From bd863c5326c88a3bd08592f74a772b355c590d55 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 15 Oct 2024 12:21:37 +0200 Subject: [PATCH 43/54] remove_joblib --- src/qml_benchmarks/models/energy_based_model.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 8aa93526..26ed3361 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -15,7 +15,6 @@ import flax.linen as nn from qml_benchmarks.models.base import EnergyBasedModel, BaseGenerator from sklearn.neural_network import BernoulliRBM -from joblib import Parallel, delayed from qml_benchmarks.model_utils import mmd_loss, median_heuristic import numpy as np @@ -172,7 +171,7 @@ def fit(self, X, y=None): super().fit(X, y) # Gibbs sampling: - def _sample(self, num_steps=1000): + def _sample(self, init_config, num_steps=1000): """ Sample the model for given number of steps via the .gibbs method of sklean's RBM. The initial configuration is sampled randomly. @@ -185,9 +184,7 @@ def _sample(self, num_steps=1000): """ if self.dim is None: raise ValueError("Model must be initialized before sampling") - v = self.rng.choice( - [0, 1], size=(self.dim,) - ) # Assuming `N` is `self.n_components` + v = init_config for _ in range(num_steps): v = self.gibbs(v) # Assuming `gibbs` is an instance method return v @@ -201,9 +198,8 @@ def sample(self, num_samples: int, num_steps: int = 1000, n_jobs=-1) -> np.ndarr num_steps (int): number of Gibbs sampling steps for each sample n_jobs (int): number of parallel jobs to be sent via joblib. By default, uses all avaliable cores. 
""" - samples_t = Parallel(n_jobs=-1)( - delayed(self._sample)(num_steps=num_steps) for _ in range(num_samples) - ) + init_configs = [self.rng.choice([0, 1], size=(self.dim,)) for __ in range(num_samples)] + samples_t = [self._sample(init_config, num_steps=num_steps) for init_config in init_configs] samples_t = np.array(samples_t, dtype=int) return samples_t From 3b2d230ba379c12cb0859f8b9cbe4f3830a492ca Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:23:55 +0200 Subject: [PATCH 44/54] Update README.md Co-authored-by: Maria Schuld --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1610ae5c..cbee29f7 100644 --- a/README.md +++ b/README.md @@ -222,10 +222,9 @@ class MyModel(BaseEstimator): ``` If the model samples binary data, it is recommended to construct models that sample binary strings (rather than $\pm1$ valued strings) -to align with the datasets designed for generative models. The repository currently contains two classical generative models: -a restricted Boltzmann machine and a simple energy based model (called DeepEBM) that uses a multi-layer perceptron as its energy function. -Energy based models with more structure can easily be constructed by replacing the multilayer perception neural network by -any other differentiable network written in flax. +to align with the datasets designed for generative models. +Energy based models can easily be constructed by replacing the multilayer perceptron neural network in `DeepEBM` by +any other differentiable network written in `flax`. ## Datasets From 28b70bae8b8dc9251c9896c2e93d00078c4c5ded Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:24:09 +0200 Subject: [PATCH 45/54] Update scripts/score_with_best_hyperparameters.py Co-authored-by: Maria Schuld --- scripts/score_with_best_hyperparameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/score_with_best_hyperparameters.py b/scripts/score_with_best_hyperparameters.py index 281424bf..90e42a63 100644 --- a/scripts/score_with_best_hyperparameters.py +++ b/scripts/score_with_best_hyperparameters.py @@ -14,7 +14,7 @@ """ Score a model using the best hyperparameters, using a command-line script. -Note this is only compatible with classifier models. +Note this is only compatible with supervised models for classification. """ From d230ddd048ef364620adc8a3d2f517c00ef8efe2 Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:24:18 +0200 Subject: [PATCH 46/54] Update src/qml_benchmarks/data/ising.py Co-authored-by: Maria Schuld --- src/qml_benchmarks/data/ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py index e35818ca..6f3d1906 100644 --- a/src/qml_benchmarks/data/ising.py +++ b/src/qml_benchmarks/data/ising.py @@ -196,7 +196,7 @@ def generate_ising(N: int, num_warmup=1000, key=42): r""" - Generating function for ising datasets. + Generating function for Ising datasets. The dataset is generated by sampling an ising distrbution of a specified interaction matrix. 
The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings From 517ac98667d1d73e9bb66f97287c4d0a20b37c62 Mon Sep 17 00:00:00 2001 From: josephbowles <54283511+josephbowles@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:24:26 +0200 Subject: [PATCH 47/54] Update src/qml_benchmarks/data/ising.py Co-authored-by: Maria Schuld --- src/qml_benchmarks/data/ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py index 6f3d1906..3dc547df 100644 --- a/src/qml_benchmarks/data/ising.py +++ b/src/qml_benchmarks/data/ising.py @@ -199,7 +199,7 @@ def generate_ising(N: int, Generating function for Ising datasets. The dataset is generated by sampling an ising distrbution of a specified interaction - matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings + matrix. The distribution is sampled via Markov Chain Monte Carlo via the Metrolopis Hastings algorithm. In the case of perfect sampling, a spin configuration s is sampled with probabability From 7677b415bd768ee66455ed2248b04f4255907447 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 15 Oct 2024 15:08:19 +0200 Subject: [PATCH 48/54] cleanup --- src/qml_benchmarks/models/base.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 0a9cc24f..4a7eb7fe 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -266,20 +266,3 @@ def score(self, X, y=None) -> any: y: labels (set to None for generative models to interface with sklearn functionality) """ pass - - # def score(self, X, y=None): - # """Score the model on the given data. - # - # Higher is better. 
- # """ - # if self.params_ is None: - # self.initialize(X.shape[1]) - # - # c_div_loss = ( - # jax.jit(self.contrastive_divergence_loss) - # if self.jit - # else self.contrastive_divergence_loss - # ) - # - # return 1 - c_div_loss(self.params_, X, y, self.generate_key()) - From 1789506c8668a145d097b08a2520d7692255db0e Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 15 Oct 2024 15:10:01 +0200 Subject: [PATCH 49/54] add new datasets --- src/qml_benchmarks/data/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/qml_benchmarks/data/__init__.py b/src/qml_benchmarks/data/__init__.py index db748a56..04bd1df5 100644 --- a/src/qml_benchmarks/data/__init__.py +++ b/src/qml_benchmarks/data/__init__.py @@ -19,4 +19,5 @@ from qml_benchmarks.data.hyperplanes import generate_hyperplanes_parity from qml_benchmarks.data.linearly_separable import generate_linearly_separable from qml_benchmarks.data.two_curves import generate_two_curves -from qml_benchmarks.data.spin_blobs import generate_8blobs +from qml_benchmarks.data.spin_blobs import generate_spin_blobs, generate_8blobs +from qml_benchmarks.data.ising import generate_ising From 3528863f4a38d4398089dfd1968a096a384df253 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 15 Oct 2024 15:21:03 +0200 Subject: [PATCH 50/54] black formatting --- src/qml_benchmarks/data/ising.py | 60 +++++++++----- src/qml_benchmarks/data/spin_blobs.py | 12 ++- src/qml_benchmarks/model_utils.py | 64 +++++++++++---- src/qml_benchmarks/models/base.py | 10 ++- .../models/energy_based_model.py | 80 +++++++++++++------ 5 files changed, 156 insertions(+), 70 deletions(-) diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py index 3dc547df..1c47003b 100644 --- a/src/qml_benchmarks/data/ising.py +++ b/src/qml_benchmarks/data/ising.py @@ -10,6 +10,7 @@ from numpyro.infer.mcmc import MCMCKernel from tqdm.auto import tqdm + @jax.jit def energy(s, J, b, J_sparse=None): """Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of @@ -27,6 +28,7 @@ def energy(s, J, b, J_sparse=None): else: return -jnp.einsum("i,j,ij->", s, s, J) / 2.0 - jnp.dot(s, b) + def initialize_spins(rng_key, num_spins, num_chains): if num_chains == 1: spins = random.bernoulli(rng_key, 0.5, (num_spins,)) @@ -44,6 +46,7 @@ def initialize_spins(rng_key, num_spins, num_chains): MHState = namedtuple("MHState", ["spins", "rng_key"]) + class MetropolisHastings(MCMCKernel): """An implementation of MCMC using Numpyro, see example in https://num.pyro.ai/en/stable/mcmc.html @@ -114,10 +117,16 @@ class object used to generate datasets by sampling an ising distrbution of a spe sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians) compute_partition_fn: Whether to compute the partition function """ + def __init__( - self, N: int, J: jnp.array, b: jnp.array, T: float, sparse=False, compute_partition_fn=False + self, + N: int, + J: jnp.array, + b: jnp.array, + T: float, + sparse=False, + compute_partition_fn=False, ) -> None: - self.N = N self.kernel = MetropolisHastings() self.J = J @@ -136,7 +145,6 @@ def __init__( def sample( self, num_samples: int, num_chains=1, thinning=1, num_warmup=1000, key=42 ) -> jnp.array: - """ Generate samples. 
Args: @@ -166,7 +174,7 @@ def sample( ) samples = mcmc.get_samples() samples.reshape((-1, self.N)) - return (samples+1)//2 + return (samples + 1) // 2 def probability(self, x: ndarray) -> float: """ @@ -177,24 +185,26 @@ def probability(self, x: ndarray) -> float: (float): the probability of sampling x according to the ising distribution """ - if not(hasattr(self, 'Z')): - raise Exception('probability requires partition fuction to have been computed') - - return ( - jnp.exp(-energy(x, self.J, self.b, self.J_sparse) / self.T) - / self.Z - ) + if not (hasattr(self, "Z")): + raise Exception( + "probability requires partition fuction to have been computed" + ) -def generate_ising(N: int, - num_samples: int, - J: jnp.array, - b: jnp.array, - T: float, - sparse=False, - num_chains=1, - thinning=1, - num_warmup=1000, - key=42): + return jnp.exp(-energy(x, self.J, self.b, self.J_sparse) / self.T) / self.Z + + +def generate_ising( + N: int, + num_samples: int, + J: jnp.array, + b: jnp.array, + T: float, + sparse=False, + num_chains=1, + thinning=1, + num_warmup=1000, + key=42, +): r""" Generating function for Ising datasets. @@ -230,5 +240,11 @@ def generate_ising(N: int, """ sampler = IsingSpins(N, J, b, T, sparse=sparse, compute_partition_fn=False) - samples = sampler.sample(num_samples, num_chains=num_chains, thinning=thinning, num_warmup=num_warmup, key=key) + samples = sampler.sample( + num_samples, + num_chains=num_chains, + thinning=thinning, + num_warmup=num_warmup, + key=key, + ) return samples, None diff --git a/src/qml_benchmarks/data/spin_blobs.py b/src/qml_benchmarks/data/spin_blobs.py index b6c36faa..c7308546 100644 --- a/src/qml_benchmarks/data/spin_blobs.py +++ b/src/qml_benchmarks/data/spin_blobs.py @@ -16,6 +16,7 @@ import numpy as np + class RandomSpinBlobs: """ Class object used to generate spin blob datasets: a binary analog of the @@ -52,7 +53,6 @@ def __init__( peak_spins: list[np.array] = None, p: float = 0.01, ) -> None: - self.N = N self.num_blobs = num_blobs @@ -119,9 +119,15 @@ def sample(self, num_samples: int, return_labels=False) -> np.array: else: return samples -def generate_spin_blobs(N: int, num_blobs: int, num_samples:int, peak_probabilities: list[float] = None, peak_spins: list[np.array] = None, - p: float = 0.01): +def generate_spin_blobs( + N: int, + num_blobs: int, + num_samples: int, + peak_probabilities: list[float] = None, + peak_spins: list[np.array] = None, + p: float = 0.01, +): """ Generator function for spin blob datasets: a binary analog of the 'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index 322d642f..c1b97576 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -27,6 +27,7 @@ import inspect from tqdm import tqdm + def train( model, loss_fn, optimizer, X, y, random_key_generator, convergence_interval=200 ): @@ -62,8 +63,10 @@ def train( # wrap a key around the function if it doesn't have one if "key" not in inspect.signature(loss_fn).parameters: + def loss_fn_wrapped(params, x, y, key): return loss_fn(params, x, y) + else: loss_fn_wrapped = loss_fn @@ -79,8 +82,14 @@ def loss_fn_wrapped(params, x, y, key): # note: assumes that the loss function is a sample mean of # some function over the input data set - chunked_grad_fn = chunk_grad(grad_fn, model.max_vmap) if model.max_vmap is not None else grad_fn - chunked_loss_fn = chunk_loss(loss_fn_wrapped, model.max_vmap) if model.max_vmap is not None else 
loss_fn_wrapped + chunked_grad_fn = ( + chunk_grad(grad_fn, model.max_vmap) if model.max_vmap is not None else grad_fn + ) + chunked_loss_fn = ( + chunk_loss(loss_fn_wrapped, model.max_vmap) + if model.max_vmap is not None + else loss_fn_wrapped + ) def update(params, opt_state, x, y, key): grads = chunked_grad_fn(params, x, y, key) @@ -97,7 +106,9 @@ def update(params, opt_state, x, y, key): key_batch = random_key_generator() key_loss = jax.random.split(key_batch, 1)[0] X_batch, y_batch = get_batch(X, y, key_batch, batch_size=model.batch_size) - params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch, key_loss) + params, opt_state, loss_val = update( + params, opt_state, X_batch, y_batch, key_loss + ) loss_history.append(loss_val) logging.debug(f"{step} - loss: {loss_val}") pbar.update(1) @@ -116,7 +127,10 @@ def update(params, opt_state, x, y, key): ) std1 = np.std(loss_history[-convergence_interval:]) # if the difference in averages is small compared to the statistical fluctuations, stop training. - if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: + if ( + np.abs(average2 - average1) + <= std1 / np.sqrt(convergence_interval) / 2 + ): logging.info( f"Model {model.__class__.__name__} converged after {step} steps." ) @@ -125,7 +139,7 @@ def update(params, opt_state, x, y, key): end = time.time() loss_history = np.array(loss_history) - model.loss_history_ = loss_history / np.max(np.abs(loss_history)) + model.loss_history_ = loss_history / np.max(np.abs(loss_history)) model.training_time_ = end - start if not converged and convergence_interval is not None: @@ -162,8 +176,6 @@ def get_batch(X, y, rnd_key, batch_size=32): return X[rnd_indices], None - - def get_from_dict(dict, key_list): """ Access a value from a nested dictionary. 
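A minimal sketch of how a loss function without a `key` argument can now be handed to `train` (`ToyModel`, `mse_loss` and `key_generator` are assumed names for illustration only, not part of the patch):

import jax.numpy as jnp
import optax
from qml_benchmarks.model_utils import train

class ToyModel:
    # attributes that `train` reads, per its docstring; the values are illustrative
    params_ = {"w": jnp.zeros(3)}
    learning_rate = 0.1
    batch_size = 16
    max_steps = 500
    max_vmap = None
    jit = True

def mse_loss(params, X, y):
    # no `key` argument: train() wraps it automatically, as in the hunk above
    return jnp.mean((X @ params["w"] - y) ** 2)

# params = train(ToyModel(), mse_loss, optax.adam, X_train, y_train, key_generator)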
@@ -312,7 +324,10 @@ def chunked_loss(params, X, y, key): return chunked_loss -def mmd_loss(ground_truth: np.ndarray, model_samples: np.ndarray, sigma: float) -> float: + +def mmd_loss( + ground_truth: np.ndarray, model_samples: np.ndarray, sigma: float +) -> float: """Calculates an unbiased estimate of the Maximum Mean Discrepancy (MMD) loss from samples see https://jmlr.org/papers/volume13/gretton12a/gretton12a.pdf for more info @@ -332,32 +347,48 @@ def mmd_loss(ground_truth: np.ndarray, model_samples: np.ndarray, sigma: float) # K_pp K_pp = jnp.zeros((ground_truth.shape[0], ground_truth.shape[0])) + def body_fun(i, val): def inner_body_fun(j, inner_val): - return inner_val.at[i, j].set(gaussian_kernel(sigma, ground_truth[i], ground_truth[j])) + return inner_val.at[i, j].set( + gaussian_kernel(sigma, ground_truth[i], ground_truth[j]) + ) + return jax.lax.fori_loop(0, ground_truth.shape[0], inner_body_fun, val) + K_pp = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pp) sum_pp = jnp.sum(K_pp) - n # K_pq K_pq = jnp.zeros((ground_truth.shape[0], model_samples.shape[0])) + def body_fun(i, val): def inner_body_fun(j, inner_val): - return inner_val.at[i, j].set(gaussian_kernel(sigma, ground_truth[i], model_samples[j])) + return inner_val.at[i, j].set( + gaussian_kernel(sigma, ground_truth[i], model_samples[j]) + ) + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + K_pq = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pq) sum_pq = jnp.sum(K_pq) # K_qq K_qq = jnp.zeros((model_samples.shape[0], model_samples.shape[0])) + def body_fun(i, val): def inner_body_fun(j, inner_val): - return inner_val.at[i, j].set(gaussian_kernel(sigma, model_samples[i], model_samples[j])) + return inner_val.at[i, j].set( + gaussian_kernel(sigma, model_samples[i], model_samples[j]) + ) + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + K_qq = jax.lax.fori_loop(0, model_samples.shape[0], body_fun, K_qq) sum_qq = jnp.sum(K_qq) - m - return 1/n/(n-1) * sum_pp - 2/n/m * sum_pq + 1/m/(m-1) * sum_qq + return 1 / n / (n - 1) * sum_pp - 2 / n / m * sum_pq + 1 / m / (m - 1) * sum_qq + def gaussian_kernel(sigma: float, x: np.ndarray, y: np.ndarray) -> float: """Calculates the value for the gaussian kernel between two vectors x, y @@ -370,7 +401,8 @@ def gaussian_kernel(sigma: float, x: np.ndarray, y: np.ndarray) -> float: Returns: float: Result value of the gaussian kernel """ - return jnp.exp(-((x-y)**2).sum()/2/sigma) + return jnp.exp(-((x - y) ** 2).sum() / 2 / sigma) + def median_heuristic(X): """ @@ -381,5 +413,7 @@ def median_heuristic(X): """ m = len(X) X = np.array(X) - med = np.median([np.sqrt(np.sum((X[i] - X[j]) ** 2)) for i in range(m) for j in range(m)]) - return med \ No newline at end of file + med = np.median( + [np.sqrt(np.sum((X[i] - X[j]) ** 2)) for i in range(m) for j in range(m)] + ) + return med diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 4a7eb7fe..6738ec53 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -155,7 +155,7 @@ def mcmc_step(self, args, i): # flip a random bit flip_idx = jax.random.choice(key1, jnp.arange(self.dim)) - flip_config = jnp.zeros(self.dim, dtype='int32') + flip_config = jnp.zeros(self.dim, dtype="int32") flip_config = flip_config.at[flip_idx].set(1) x_flip = jnp.array((x + flip_config) % 2) @@ -200,13 +200,15 @@ def sample(self, num_samples, num_steps=1000, max_chunk_size=100): ) # chunk the sampling, otherwise the vmap can blow 
the memory - num_chunks = num_steps//max_chunk_size + 1 + num_chunks = num_steps // max_chunk_size + 1 x_init = jnp.array_split(x_init, num_chunks) keys = jnp.array_split(keys, num_chunks) configs = [] for elem in zip(x_init, keys): - new_configs = self.batched_mcmc_sample(self.params_, elem[0], num_steps, elem[1]) - configs.append(new_configs[:,-1]) + new_configs = self.batched_mcmc_sample( + self.params_, elem[0], num_steps, elem[1] + ) + configs.append(new_configs[:, -1]) configs = jnp.concatenate(configs) return configs diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 26ed3361..5ae6be1f 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -55,16 +55,18 @@ class DeepEBM(EnergyBasedModel): maximum mean discrepancy. """ - def __init__(self, - learning_rate=0.001, - batch_size=32, - max_steps=10000, - cdiv_steps=1, - convergence_interval=None, - random_state=42, - jit=True, - hidden_layers=[8, 4], - mmd_kwargs = {'n_samples': 1000, 'n_steps':1000, 'sigma': 1.0}): + def __init__( + self, + learning_rate=0.001, + batch_size=32, + max_steps=10000, + cdiv_steps=1, + convergence_interval=None, + random_state=42, + jit=True, + hidden_layers=[8, 4], + mmd_kwargs={"n_samples": 1000, "n_steps": 1000, "sigma": 1.0}, + ): super().__init__( dim=None, learning_rate=learning_rate, @@ -73,7 +75,7 @@ def __init__(self, cdiv_steps=cdiv_steps, convergence_interval=convergence_interval, random_state=random_state, - jit=jit + jit=jit, ) self.hidden_layers = hidden_layers self.mmd_kwargs = mmd_kwargs @@ -107,10 +109,20 @@ def score(self, X: np.ndarray, y=None) -> float: Args: X (Array): batch of test samples to evalute the model against. """ - sigma = self.mmd_kwargs['sigma'] + sigma = self.mmd_kwargs["sigma"] sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma - score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'], - self.mmd_kwargs['n_steps']), sigma) for sigma in sigmas]) + score = np.mean( + [ + mmd_loss( + X, + self.sample( + self.mmd_kwargs["n_samples"], self.mmd_kwargs["n_steps"] + ), + sigma, + ) + for sigma in sigmas + ] + ) return float(-score) @@ -132,6 +144,7 @@ class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator): maximum mean discrepancy. """ + def __init__( self, n_components=256, @@ -140,8 +153,8 @@ def __init__( n_iter=10, verbose=0, random_state=42, - score_fn='pseudolikelihood', - mmd_kwargs ={'n_samples': 1000, 'n_steps': 1000, 'sigma': 1.0} + score_fn="pseudolikelihood", + mmd_kwargs={"n_samples": 1000, "n_steps": 1000, "sigma": 1.0}, ): super().__init__( n_components=n_components, @@ -198,22 +211,37 @@ def sample(self, num_samples: int, num_steps: int = 1000, n_jobs=-1) -> np.ndarr num_steps (int): number of Gibbs sampling steps for each sample n_jobs (int): number of parallel jobs to be sent via joblib. By default, uses all avaliable cores. 
""" - init_configs = [self.rng.choice([0, 1], size=(self.dim,)) for __ in range(num_samples)] - samples_t = [self._sample(init_config, num_steps=num_steps) for init_config in init_configs] + init_configs = [ + self.rng.choice([0, 1], size=(self.dim,)) for __ in range(num_samples) + ] + samples_t = [ + self._sample(init_config, num_steps=num_steps) + for init_config in init_configs + ] samples_t = np.array(samples_t, dtype=int) return samples_t - def score(self, X: np.ndarray, y: np.ndarray=None) -> float: + def score(self, X: np.ndarray, y: np.ndarray = None) -> float: """ Score function for hyperparameter optimization. Args: X (Array): batch of test samples to evalute the model against. - """ - if self.score_fn == 'pseudolikelihood': + """ + if self.score_fn == "pseudolikelihood": return float(np.mean(super().score_samples(X))) - elif self.score_fn == 'mmd': - sigma = self.mmd_kwargs['sigma'] + elif self.score_fn == "mmd": + sigma = self.mmd_kwargs["sigma"] sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma - score = np.mean([mmd_loss(X, self.sample(self.mmd_kwargs['n_samples'], - self.mmd_kwargs['n_steps']), sigma) for sigma in sigmas]) - return float(-score) \ No newline at end of file + score = np.mean( + [ + mmd_loss( + X, + self.sample( + self.mmd_kwargs["n_samples"], self.mmd_kwargs["n_steps"] + ), + sigma, + ) + for sigma in sigmas + ] + ) + return float(-score) From 4cba0beda081e293b637b544e83f4f793a2e56c8 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Tue, 15 Oct 2024 16:32:25 +0200 Subject: [PATCH 51/54] remove joblib args --- src/qml_benchmarks/models/energy_based_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 5ae6be1f..5fe7480e 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -202,14 +202,13 @@ def _sample(self, init_config, num_steps=1000): v = self.gibbs(v) # Assuming `gibbs` is an instance method return v - def sample(self, num_samples: int, num_steps: int = 1000, n_jobs=-1) -> np.ndarray: + def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray: """ Sample the model. Each sample is generated by sampling a random configuration and performing a number of Gibbs sampling steps. We use joblib to parallelize the sampling. Args: num_samples (int): number of samples to return num_steps (int): number of Gibbs sampling steps for each sample - n_jobs (int): number of parallel jobs to be sent via joblib. By default, uses all avaliable cores. """ init_configs = [ self.rng.choice([0, 1], size=(self.dim,)) for __ in range(num_samples) From 516c08010bb1084a7290ffa720e0a8bbada9d2df Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Wed, 16 Oct 2024 17:55:23 +0200 Subject: [PATCH 52/54] rbm batch sample --- src/qml_benchmarks/models/energy_based_model.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py index 5fe7480e..94a29759 100644 --- a/src/qml_benchmarks/models/energy_based_model.py +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -184,7 +184,7 @@ def fit(self, X, y=None): super().fit(X, y) # Gibbs sampling: - def _sample(self, init_config, num_steps=1000): + def _sample(self, init_configs, num_steps=1000): """ Sample the model for given number of steps via the .gibbs method of sklean's RBM. 
The initial configuration is sampled randomly. @@ -197,7 +197,7 @@ def _sample(self, init_config, num_steps=1000): """ if self.dim is None: raise ValueError("Model must be initialized before sampling") - v = init_config + v = init_configs for _ in range(num_steps): v = self.gibbs(v) # Assuming `gibbs` is an instance method return v @@ -210,13 +210,8 @@ def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray: num_samples (int): number of samples to return num_steps (int): number of Gibbs sampling steps for each sample """ - init_configs = [ - self.rng.choice([0, 1], size=(self.dim,)) for __ in range(num_samples) - ] - samples_t = [ - self._sample(init_config, num_steps=num_steps) - for init_config in init_configs - ] + init_configs = self.rng.choice([0, 1], size=(num_samples, self.dim,)) + samples_t = self._sample(init_configs, num_steps=num_steps) samples_t = np.array(samples_t, dtype=int) return samples_t From 7f826e09ba0e694b92208f86d7043db768b433a3 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Thu, 17 Oct 2024 14:53:52 +0200 Subject: [PATCH 53/54] increase chunk sample size --- src/qml_benchmarks/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index 6738ec53..d017b0b9 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -175,7 +175,7 @@ def mcmc_sample(self, params, x_init, num_mcmc_steps, key): carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(num_mcmc_steps)) return configs - def sample(self, num_samples, num_steps=1000, max_chunk_size=100): + def sample(self, num_samples, num_steps=1000, max_chunk_size=1000): """ Sample configurations starting from a random configuration. Each sample is generated by sampling a random configuration and perforning a number of mcmc updates. From 4ac5ec5bd392c60b1e7a59c1238fa2cc57eaf944 Mon Sep 17 00:00:00 2001 From: Joseph Bowles Date: Sun, 20 Oct 2024 13:31:22 +0200 Subject: [PATCH 54/54] smaller chunks --- src/qml_benchmarks/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py index d017b0b9..6738ec53 100644 --- a/src/qml_benchmarks/models/base.py +++ b/src/qml_benchmarks/models/base.py @@ -175,7 +175,7 @@ def mcmc_sample(self, params, x_init, num_mcmc_steps, key): carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(num_mcmc_steps)) return configs - def sample(self, num_samples, num_steps=1000, max_chunk_size=1000): + def sample(self, num_samples, num_steps=1000, max_chunk_size=100): """ Sample configurations starting from a random configuration. Each sample is generated by sampling a random configuration and perforning a number of mcmc updates.
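For reference, a minimal end-to-end sketch of the generative workflow targeted by this patch series (module paths and class names as introduced above; the couplings, dataset size and hyperparameters are arbitrary illustrative choices, and `fit` is assumed to accept unlabelled data as in the generative-model interface):

import numpy as np
from qml_benchmarks.data import generate_ising
from qml_benchmarks.models.energy_based_model import DeepEBM

# toy nearest-neighbour Ising chain on 8 spins (illustrative couplings)
N = 8
J = np.zeros((N, N))
for i in range(N - 1):
    J[i, i + 1] = J[i + 1, i] = 1.0
b = np.zeros(N)

X, _ = generate_ising(N, 500, J, b, T=1.0)   # binary samples, one spin configuration per row

model = DeepEBM(hidden_layers=[8, 4], max_steps=2000)
model.fit(X)                    # unsupervised fit via contrastive divergence
samples = model.sample(100)     # MCMC samples from the learned energy function
print(model.score(X))           # negative MMD between data and model samples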