diff --git a/README.md b/README.md index a66f780e..cbee29f7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Benchmarking for quantum machine learning models This repository contains tools to compare the performance of near-term quantum machine learning (QML) -as well as standard classical machine learning models on supervised learning tasks. +as well as standard classical machine learning models on supervised and generative learning tasks. It is based on pipelines using [Pennylane](https://pennylane.ai/) for the simulation of quantum circuits, [JAX](https://jax.readthedocs.io/en/latest/index.html) for training, @@ -39,12 +39,12 @@ Dependencies of this package can be installed in your environment by running pip install -r requirements.txt ``` -## Adding a custom model +## Adding a custom classifier We use the [Scikit-learn API](https://scikit-learn.org/stable/developers/develop.html) to create models and perform hyperparameter search. -A minimal template for a new quantum model is as follows, and can be stored +A minimal template for a new quantum classifier is as follows, and can be stored in `qml_benchmarks/models/my_model.py`: ```python @@ -61,18 +61,23 @@ class MyModel(BaseEstimator, ClassifierMixin): # reproducibility is ensured by creating a numpy PRNG and using it for all # subsequent random functions. - self._random_state = random_state - self._rng = np.random.default_rng(random_state) + self.random_state = random_state + self.rng = np.random.default_rng(random_state) # define data-dependent attributes self.params_ = None self.n_qubits_ = None + + def initialize(self, args): + """ + initialize the model if necessary + """ + # ... your code here ... def fit(self, X, y): """Fit the model to data X and labels y. Add your custom training loop here and store the trained model parameters in `self.params_`. - Set the data-dependent attributes, such as `self.n_qubits_`. Args: X (array_like): Data of shape (n_samples, n_features) @@ -146,9 +151,86 @@ model.fit(X_train, y_train) print(model.score(X_test, y_test)) ``` + +## Adding a custom generative model + +The minimal template for a new generative model closely follows that of the classifier models. +Labels are set to `None` throughout to maintain sci-kit learn functionality. + +```python +import numpy as np + +from sklearn.base import BaseEstimator + + +class MyModel(BaseEstimator): + def __init__(self, hyperparam1="some_value", random_state=42): + + # store hyperparameters as attributes + self.hyperparam1 = hyperparam1 + + # reproducibility is ensured by creating a numpy PRNG and using it for all + # subsequent random functions. + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + + # define data-dependent attributes + self.params_ = None + self.n_qubits_ = None + + def initialize(self, args): + """ + initialize the model if necessary + """ + # ... your code here ... + + def fit(self, X, y=None): + """Fit the model to data X. + + Add your custom training loop here and store the trained model parameters in `self.params_`. + + Args: + X (array_like): Data of shape (n_samples, n_features) + y (array_like): not used (no labels) + """ + # ... your code here ... + + def sample(self, num_samples): + """sample from the generative model + + Args: + num_samples (int): number of points to sample + + Returns: + array_like: sampled points + """ + # ... your code here ... 
+ + return samples + + def score(self, X, y=None): + """A optional custom score function to be used with hyperparameter optimization + Args: + X (array_like): Data of shape (n_samples, n_features) + y: unused (no labels for generative models) + + Returns: + (float): score for the dataset X + """ + # ... your code here ... + return score +``` + +If the model samples binary data, it is recommended to construct models that sample binary strings (rather than $\pm1$ valued strings) +to align with the datasets designed for generative models. +Energy based models can easily be constructed by replacing the multilayer perceptron neural network in `DeepEBM` by +any other differentiable network written in `flax`. + ## Datasets -The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification. +The `qml_benchmarks.data` module provides generating functions to create datasets for binary classification and +generative learning. + A generating function can be used like this: ```python @@ -158,7 +240,7 @@ X, y = generate_two_curves(n_samples=200, n_features=4, degree=3, noise=0.1, off ``` Note that some datasets might have different return data structures, for example if the train/test split -is performed by the generating function. +is performed by the generating function. If the dataset does not include labels, `y = None` is returned. The original datasets used in the paper can be generated by running the scripts in the `paper/benchmarks` folder, such as: @@ -172,15 +254,18 @@ This will create a new folder in `paper/benchmarks` containing the datasets. ## Running hyperparameter optimization In the folder `scripts` we provide an example that can be used to -generate results for a hyperparameter search for any model and dataset. The script +generate results for a hyperparameter search for any model and dataset. The script functions +for both classifier and generative models. The script can be run as ``` -python run_hyperparameter_search.py --classifier-name "DataReuploadingClassifier" --dataset-path "my_dataset.csv" +python run_hyperparameter_search.py --model "DataReuploadingClassifier" --dataset-path "my_dataset.csv" ``` -where `my_dataset.csv` is a CSV file containing the training data such that each column is a feature -and the last column is the target. +where`my_dataset.csv` is a CSV file containing the training data. For classification problems, each column should +correspond to a feature and the last column to the target. For generative learning, each row +should correspond to a binary string that specifies a unique data sample, and the model should implement a `score` +method. Unless otherwise specified, the hyperparameter grid is loaded from `qml_benchmarks/hyperparameter_settings.py`. 
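+For generative learning, a minimal sketch of preparing such a dataset and running the search, using the `generate_8blobs`
+generator and the `DeepEBM` model added in this repository (the file name below is illustrative), is:
+
+```python
+import numpy as np
+import pandas as pd
+from qml_benchmarks.data import generate_8blobs
+
+# each row is a binary string sample; generative datasets carry no label column
+X, _ = generate_8blobs(num_samples=500)
+pd.DataFrame(np.array(X, dtype=int)).to_csv("spin_blobs.csv", header=False, index=False)
+```
+
+after which the search can be run with
+
+```
+python run_hyperparameter_search.py --model "DeepEBM" --dataset-path "spin_blobs.csv"
+```
+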
One can override the default grid of hyperparameters by specifying the hyperparameter list, @@ -189,7 +274,7 @@ For example, for the `DataReuploadingClassifier` we can run: ``` python run_hyperparameter_search.py \ - --classifier-name DataReuploadingClassifier \ + --model DataReuploadingClassifier \ --dataset-path "my_dataset.csv" \ --n_layers 1 2 \ --observable_type "single" "full"\ diff --git a/requirements.txt b/requirements.txt index 9cc6c785..cdc7791a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ pyyaml~=6.0 pennyLane~=0.34 scipy~=1.11 pandas~=2.2 +numpyro~=0.14.0 \ No newline at end of file diff --git a/scripts/run_hyperparameter_search.py b/scripts/run_hyperparameter_search.py index fdd64cc6..035913cd 100644 --- a/scripts/run_hyperparameter_search.py +++ b/scripts/run_hyperparameter_search.py @@ -20,28 +20,33 @@ import time import argparse import logging + logging.getLogger().setLevel(logging.INFO) from importlib import import_module import pandas as pd from pathlib import Path import matplotlib.pyplot as plt from sklearn.model_selection import GridSearchCV +from sklearn.metrics import make_scorer +from qml_benchmarks.models.base import BaseGenerator from qml_benchmarks.hyperparam_search_utils import read_data, construct_hyperparameter_grid from qml_benchmarks.hyperparameter_settings import hyper_parameter_settings np.random.seed(42) -logging.info('cpu count:' + str(os.cpu_count())) +def custom_scorer(estimator, X, y=None): + return estimator.score(X, y) +logging.info('cpu count:' + str(os.cpu_count())) if __name__ == "__main__": # Create an argument parser parser = argparse.ArgumentParser(description="Run experiments with hyperparameter search.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( - "--classifier-name", - help="Classifier to run", + "--model", + help="Model to run", ) parser.add_argument( @@ -91,27 +96,28 @@ # Parse the arguments along with any extra arguments that might be model specific args, unknown_args = parser.parse_known_args() - if any(arg is None for arg in [args.classifier_name, + if any(arg is None for arg in [args.model, args.dataset_path]): msg = "\n================================================================================" - msg += "\nA classifier from qml.benchmarks.model and dataset path are required. E.g., \n \n" - msg += "python run_hyperparameter_search \ \n--classifier DataReuploadingClassifier \ \n--dataset-path train.csv\n" + msg += "\nA model from qml.benchmarks.models and dataset path are required. 
E.g., \n \n" + msg += "python run_hyperparameter_search \ \n--model DataReuploadingClassifier \ \n--dataset-path train.csv\n" msg += "\nCheck all arguments for the script with \n" msg += "python run_hyperparameter_search --help\n" msg += "================================================================================" raise ValueError(msg) - + # Add model specific arguments to override the default hyperparameter grid hyperparam_grid = construct_hyperparameter_grid( - hyper_parameter_settings, args.classifier_name + hyper_parameter_settings, args.model ) + for hyperparam in hyperparam_grid: hp_type = type(hyperparam_grid[hyperparam][0]) parser.add_argument(f'--{hyperparam}', type=hp_type, nargs="+", default=hyperparam_grid[hyperparam], - help=f'{hyperparam} grid values for {args.classifier_name}') + help=f'{hyperparam} grid values for {args.model}') args = parser.parse_args(unknown_args, namespace=args) @@ -122,11 +128,12 @@ logging.info( "Running hyperparameter search experiment with the following settings\n" ) - logging.info(args.classifier_name) + logging.info(args.model) logging.info(args.dataset_path) logging.info(" ".join(args.hyperparameter_scoring)) logging.info(args.hyperparameter_refit) - logging.info("Hyperparam grid:"+" ".join([(str(key)+str(":")+str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) + logging.info("Hyperparam grid:" + " ".join( + [(str(key) + str(":") + str(hyperparam_grid[key])) for key in hyperparam_grid.keys()])) experiment_path = args.results_path results_path = os.path.join(experiment_path, "results") @@ -135,22 +142,25 @@ os.makedirs(results_path) ################################################################### - # Get the classifier, dataset and search methods from the arguments + # Get the model, dataset and search methods from the arguments ################################################################### - Classifier = getattr( + Model = getattr( import_module("qml_benchmarks.models"), - args.classifier_name + args.model ) - classifier_name = Classifier.__name__ + model_name = Model.__name__ + + is_generative = isinstance(Model(), BaseGenerator) + use_labels = False if is_generative else True # Run the experiments save the results train_dataset_filename = os.path.join(args.dataset_path) - X, y = read_data(train_dataset_filename) + X, y = read_data(train_dataset_filename, labels=use_labels) dataset_path_obj = Path(args.dataset_path) results_filename_stem = " ".join( - [Classifier.__name__ + "_" + dataset_path_obj.stem - + "_GridSearchCV"]) + [Model.__name__ + "_" + dataset_path_obj.stem + + "_GridSearchCV"]) # If we have already run this experiment then continue if os.path.isfile(os.path.join(results_path, results_filename_stem + ".csv")): @@ -162,44 +172,48 @@ logging.warning(msg) sys.exit(msg) else: - logging.warning("Cleaning existing results for ", os.path.join(results_path, results_filename_stem + ".csv")) - + logging.warning("Cleaning existing results for ", + os.path.join(results_path, results_filename_stem + ".csv")) ########################################################################### # Single fit to check everything works ########################################################################### - classifier = Classifier() + model = Model() a = time.time() - classifier.fit(X, y) + model.fit(X, y) b = time.time() - acc_train = classifier.score(X, y) + default_score = model.score(X, y) logging.info(" ".join( - [classifier_name, - "Dataset path", - args.dataset_path, - "Train acc:", - str(acc_train), - "Time single run", - 
str(b - a)]) + [model_name, + "Dataset path", + args.dataset_path, + "Train score:", + str(default_score), + "Time single run", + str(b - a)]) ) - if hasattr(classifier, "loss_history_"): + if hasattr(model, "loss_history_"): if args.plot_loss: - plt.plot(classifier.loss_history_) + plt.plot(model.loss_history_) plt.xlabel("Iterations") plt.ylabel("Loss") plt.show() - if hasattr(classifier, "n_qubits_"): - logging.info(" ".join(["Num qubits", f"{classifier.n_qubits_}"])) + if hasattr(model, "n_qubits_"): + logging.info(" ".join(["Num qubits", f"{model.n_qubits_}"])) ########################################################################### # Hyperparameter search ########################################################################### - gs = GridSearchCV(estimator=classifier, param_grid=hyperparam_grid, - scoring=args.hyperparameter_scoring, - refit=args.hyperparameter_refit, - verbose=3, - n_jobs=-1).fit( + + scorer = args.hyperparameter_scoring if not is_generative else custom_scorer + refit = args.hyperparameter_refit if not is_generative else False + + gs = GridSearchCV(estimator=model, param_grid=hyperparam_grid, + scoring=scorer, + refit=refit, + verbose=3, + n_jobs=args.n_jobs).fit( X, y ) logging.info("Best hyperparams") diff --git a/scripts/score_with_best_hyperparameters.py b/scripts/score_with_best_hyperparameters.py index 47cc8e08..90e42a63 100644 --- a/scripts/score_with_best_hyperparameters.py +++ b/scripts/score_with_best_hyperparameters.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Score a model using the best hyperparameters, using a command-line script.""" +""" +Score a model using the best hyperparameters, using a command-line script. +Note this is only compatible with supervised models for classification. +""" + import numpy as np import sys diff --git a/src/qml_benchmarks/data/__init__.py b/src/qml_benchmarks/data/__init__.py index 654ac570..04bd1df5 100644 --- a/src/qml_benchmarks/data/__init__.py +++ b/src/qml_benchmarks/data/__init__.py @@ -19,4 +19,5 @@ from qml_benchmarks.data.hyperplanes import generate_hyperplanes_parity from qml_benchmarks.data.linearly_separable import generate_linearly_separable from qml_benchmarks.data.two_curves import generate_two_curves - \ No newline at end of file +from qml_benchmarks.data.spin_blobs import generate_spin_blobs, generate_8blobs +from qml_benchmarks.data.ising import generate_ising diff --git a/src/qml_benchmarks/data/ising.py b/src/qml_benchmarks/data/ising.py new file mode 100644 index 00000000..1c47003b --- /dev/null +++ b/src/qml_benchmarks/data/ising.py @@ -0,0 +1,250 @@ +"""Ising spin simulation for a classical 2D Ising model +""" + +import jax +from jax import numpy as jnp +from numpy import ndarray +from numpyro.infer import MCMC +from jax import random +from collections import namedtuple +from numpyro.infer.mcmc import MCMCKernel +from tqdm.auto import tqdm + + +@jax.jit +def energy(s, J, b, J_sparse=None): + """Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of + J to speed up the calculation. + Args: + s: spin configuration + J: interaction matrix + b: bias term + J_sparse: list of nonzero indices of J. 
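+    Returns:
+        float: the Ising energy of the configuration s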
+ """ + if J_sparse is not None: + return -jnp.einsum( + "i,i,i->", s[J_sparse[0]], s[J_sparse[1]], J[J_sparse] + ) / 2.0 - jnp.dot(s, b) + else: + return -jnp.einsum("i,j,ij->", s, s, J) / 2.0 - jnp.dot(s, b) + + +def initialize_spins(rng_key, num_spins, num_chains): + if num_chains == 1: + spins = random.bernoulli(rng_key, 0.5, (num_spins,)) + else: + spins = random.bernoulli( + rng_key, + 0.5, + ( + num_chains, + num_spins, + ), + ) + return spins * 2 - 1 + + +MHState = namedtuple("MHState", ["spins", "rng_key"]) + + +class MetropolisHastings(MCMCKernel): + """An implementation of MCMC using Numpyro, see example in + https://num.pyro.ai/en/stable/mcmc.html + """ + + sample_field = "spins" + + def init( + self, + rng_key: random.PRNGKey, + num_warmup: int, + init_params: jnp.array, + *args, + **kwargs + ): + """Initialize the state of the model.""" + return MHState(init_params, rng_key) + + def sample(self, state: jnp.array, model_args, model_kwargs): + """Sample from the model via Metropolis Hastings MCMC""" + spins, rng_key = state + num_spins = spins.size + + def mh_step(i, val): + spins, rng_key = val + rng_key, subkey = random.split(rng_key) + flip_index = random.randint(subkey, (), 0, num_spins) + spins_proposal = spins.at[flip_index].set(-spins[flip_index]) + + current_energy = energy( + spins, model_kwargs["J"], model_kwargs["b"], model_kwargs["J_sparse"] + ) + proposed_energy = energy( + spins_proposal, + model_kwargs["J"], + model_kwargs["b"], + model_kwargs["J_sparse"], + ) + delta_energy = proposed_energy - current_energy + accept_prob = jnp.exp(-delta_energy / model_kwargs["T"]) + + rng_key, subkey = random.split(rng_key) + accept = random.uniform(subkey) < accept_prob + spins = jnp.where(accept, spins_proposal, spins) + return spins, rng_key + + spins, rng_key = jax.lax.fori_loop(0, num_spins, mh_step, (spins, rng_key)) + return MHState(spins, rng_key) + + +class IsingSpins: + r""" + class object used to generate datasets by sampling an ising distrbution of a specified interaction + matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings + algorithm. + + In the case of perfect sampling, a spin configuration s is sampled with probabability + :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i` + corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued. + + The final sampled configurations are converted from a :math:`\pm1` representation to to a binary + representation via x = (s+1)//2. + + N (int): Number of spins + J (np.array): interaction matrix + b (np.array): bias terms + T (float): temperature + sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians) + compute_partition_fn: Whether to compute the partition function + """ + + def __init__( + self, + N: int, + J: jnp.array, + b: jnp.array, + T: float, + sparse=False, + compute_partition_fn=False, + ) -> None: + self.N = N + self.kernel = MetropolisHastings() + self.J = J + self.T = T + self.b = b + self.J_sparse = jnp.nonzero(J) if sparse else None + + if compute_partition_fn: + Z = 0 + for i in tqdm(range(2**self.N), desc="Computing partition function"): + lattice = (-1) ** jnp.array(jnp.unravel_index(i, [2] * self.N)) + en = energy(lattice, self.J, self.b, self.J_sparse) + Z += jnp.exp(-en / T) + self.Z = Z + + def sample( + self, num_samples: int, num_chains=1, thinning=1, num_warmup=1000, key=42 + ) -> jnp.array: + """ + Generate samples. 
+ Args: + num_samples (int): total number of samples to generate per chain + num_chains (int): number of chains + thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn ater each + 10 steps of mcmc sampling. Larger numbers result in more unbiased samples. + num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples. + key (int): random seed used to initialize sampling. + """ + rng_key = random.PRNGKey(key) + init_spins = initialize_spins(rng_key, self.N, num_chains) + mcmc = MCMC( + self.kernel, + num_warmup=num_warmup, + thinning=thinning, + num_samples=num_samples, + num_chains=num_chains, + ) + mcmc.run( + rng_key, + init_params=init_spins, + J=self.J, + b=self.b, + T=self.T, + J_sparse=self.J_sparse, + ) + samples = mcmc.get_samples() + samples.reshape((-1, self.N)) + return (samples + 1) // 2 + + def probability(self, x: ndarray) -> float: + """ + compute the probability of a binary configuration x + Args: + x: binary configuration array + Returns: + (float): the probability of sampling x according to the ising distribution + """ + + if not (hasattr(self, "Z")): + raise Exception( + "probability requires partition fuction to have been computed" + ) + + return jnp.exp(-energy(x, self.J, self.b, self.J_sparse) / self.T) / self.Z + + +def generate_ising( + N: int, + num_samples: int, + J: jnp.array, + b: jnp.array, + T: float, + sparse=False, + num_chains=1, + thinning=1, + num_warmup=1000, + key=42, +): + r""" + Generating function for Ising datasets. + + The dataset is generated by sampling an ising distrbution of a specified interaction + matrix. The distribution is sampled via Markov Chain Monte Carlo via the Metrolopis Hastings + algorithm. + + In the case of perfect sampling, a spin configuration s is sampled with probabability + :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i` + corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued. + + The final sampled configurations are converted from a :math:`\pm1` representation to to a binary + representation via x = (s+1)//2. + + Note that in order to use parallelization, the number of avaliable cores has to be specified explicitly + to numpyro. i.e. the line `numpyro.set_host_device_count(num_cores)` should appear before running the + generator, where num_cores is the number of avaliable CPU cores you want to use. + + N (int): Number of spins + num_samples (int): total number of samples to generate per chain + J (np.array): interaction matrix of shape (N,N) + b (np.array): bias array of shape (N,) + T (float): temperature + num_chains (int): number of chains, defaults to 1. + thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn after each + 10 steps of mcmc sampling. Larger numbers result in more unbiased samples. + num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples. + key (int): random seed used to initialize sampling. 
+ sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians) + + Returns: + Array of data samples, and Nonetype object (since there are no labels) + """ + + sampler = IsingSpins(N, J, b, T, sparse=sparse, compute_partition_fn=False) + samples = sampler.sample( + num_samples, + num_chains=num_chains, + thinning=thinning, + num_warmup=num_warmup, + key=key, + ) + return samples, None diff --git a/src/qml_benchmarks/data/spin_blobs.py b/src/qml_benchmarks/data/spin_blobs.py new file mode 100644 index 00000000..c7308546 --- /dev/null +++ b/src/qml_benchmarks/data/spin_blobs.py @@ -0,0 +1,257 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate a dataset of spin configurations as blobs.""" + +import numpy as np + + +class RandomSpinBlobs: + """ + Class object used to generate spin blob datasets: a binary analog of the + 'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming + distance to a set of specified configurations. + + The dataset is generated by specifying a list of configurations (peak spins) + that mark the centre of the 'blobs'. Data points are sampled by chosing one of + the peak spins (with probabilities specified by peak probabilities), and then + flipping some of the bits. Each bit is flipped with probability specified by + p, so that (for small p) datapoints are close in Hamming distance to one of + the peak probabilities. + + Args: + N (int): The number of spins. + num_blobs (int): + The number of blobs. + peak_probabilities (list[float], optional): + The probability of each spin to be selected. If not specified, + the probabilities are distributed uniformly. + peak_spins (list[np.array], optional): + The peak spin configurations. Selected randomly by default. + p (float, optional): + The value of the parameter `p` in a Binomial distribution specifying + the number of spins that are flipped each time during sampling. + Defaults to 0.01. + """ + + def __init__( + self, + N: int, + num_blobs: int, + peak_probabilities: list[float] = None, + peak_spins: list[np.array] = None, + p: float = 0.01, + ) -> None: + self.N = N + self.num_blobs = num_blobs + + if peak_probabilities is None: + peak_probabilities = [1 / num_blobs for i in range(num_blobs)] + + if len(peak_probabilities) != num_blobs: + msg = f"Specify probabilities for all {num_blobs} blobs." + raise ValueError(msg) + + if peak_spins is not None and len(peak_spins) != num_blobs: + msg = f"The number of peak spins should be the same as blobs." 
+ raise ValueError(msg) + + if peak_spins is None: + # generate some random peak spin configs + # we flip each bit with 50% prob + spin_configs = [] + while len(spin_configs) < num_blobs: + config = list((-1) ** np.random.binomial([1] * N, [0.5] * N)) + if config not in spin_configs: + spin_configs.append(config) + peak_spins = spin_configs + + self.peak_spins = peak_spins + self.peak_probabilities = peak_probabilities + self.p = p + + def sample(self, num_samples: int, return_labels=False) -> np.array: + """Generate a given number of samples. + + Args: + num_samples (int): Number of samples to generate. + return_labels (bool, optional): + Whether to return labels for each sample. Defaults to False. + + Returns: + np.array: A (num_samples, N) array of spin configurations. + """ + samples = [] + labels = [] + + for _ in range(num_samples): + # Choose a random peak + label = np.random.choice(self.num_blobs, p=self.peak_probabilities) + labels.append(label) + peak_spin = self.peak_spins[label] + + # Randomly choose a Hamming distance from the sampler + num_bits_to_flip = np.random.binomial(self.N, self.p) + + # Flip bits randomly in the peak spin configuration + indices_to_flip = np.random.choice(self.N, num_bits_to_flip, replace=False) + sampled_spin = np.copy(peak_spin) + for index in indices_to_flip: + sampled_spin[index] *= -1 + + samples.append(sampled_spin) + + samples = (np.array(samples) + 1) / 2 + + if return_labels: + return samples, np.array(labels) + else: + return samples + + +def generate_spin_blobs( + N: int, + num_blobs: int, + num_samples: int, + peak_probabilities: list[float] = None, + peak_spins: list[np.array] = None, + p: float = 0.01, +): + """ + Generator function for spin blob datasets: a binary analog of the + 'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming + distance to a set of specified configurations. + + The dataset is generated by specifying a list of configurations (peak spins) + that mark the centre of the 'blobs'. Data points are sampled by chosing one of + the peak spins (with probabilities specified by peak probabilities), and then + flipping some of the bits. Each bit is flipped with probability specified by + p, so that (for small p) datapoints are close in Hamming distance to one of + the peak probabilities. + + Args: + N (int): The number of spins. + num_blobs (int): + The number of blobs. + num_samples (int): The number of samples to generate. + peak_probabilities (list[float], optional): + The probability of each spin to be selected. If not specified, + the probabilities are distributed uniformly. + peak_spins (list[np.array], optional): + The peak spin configurations. Selected randomly by default. + p (float, optional): + The value of the parameter `p` in a Binomial distribution specifying + the number of spins that are flipped each time during sampling. + Defaults to 0.01. + + Returns: + tuple(np.ndarray): Dataset array and label array specifying the peak spin + that was used to sample each datapoint. 
+ """ + + sampler = RandomSpinBlobs( + N=N, + num_blobs=num_blobs, + peak_probabilities=peak_probabilities, + peak_spins=peak_spins, + p=p, + ) + + X, y = sampler.sample(num_samples=num_samples, return_labels=True) + X = X.reshape(-1, N) + + return X, y + + +def generate_8blobs( + num_samples: int, + p: float = 0.01, +): + """Generate 4x4 spin samples with 8 selected high-probability configurations + + Example + ------- + import matplotlib.pyplot as plt + from qml_benchmarks.data.spin_blobs import generate_8blobs + X, y = generate_8blobs(100) + num_samples = 20 + interval = len(X) // num_samples + + fig, axes = plt.subplots(1, num_samples, figsize=(20, 4)) + for i in range(num_samples): + axes[i].imshow(X[i*interval].reshape((4, 4))) + axes[i].axis('off') + plt.show() + + Args: + num_samples (int): The number of samples to generate. + p (float, optional): + The value of the parameter p in a Binomial distribution bin(N, p) + that determines how many spins are flipped during each sampling step + after choosing one of the peak configurations. Defaults to 0.01. + + Returns: + np.ndarray: A (num_samples, 16) array of spin configurations. + """ + np.random.seed(66) + N: int = 16 + num_blobs: int = 8 + + # generate a specific set + config1 = np.array( + [[1, 1, -1, -1], [1, 1, -1, -1], [-1, -1, -1, -1], [-1, -1, -1, -1]] + ) + config2 = np.array( + [[-1, -1, 1, 1], [-1, -1, 1, 1], [-1, -1, -1, -1], [-1, -1, -1, -1]] + ) + config3 = np.array( + [[-1, -1, -1, -1], [-1, -1, -1, -1], [1, 1, -1, -1], [1, 1, -1, -1]] + ) + config4 = np.array( + [[-1, -1, -1, -1], [-1, -1, -1, -1], [-1, -1, 1, 1], [-1, -1, 1, 1]] + ) + config5 = np.array( + [[-1, -1, -1, -1], [-1, 1, 1, -1], [-1, 1, 1, -1], [-1, -1, -1, -1]] + ) + config6 = np.array( + [[1, -1, -1, -1], [-1, 1, -1, -1], [-1, -1, 1, -1], [-1, -1, -1, 1]] + ) + config7 = np.array( + [[-1, -1, -1, 1], [-1, -1, 1, -1], [-1, 1, -1, -1], [1, -1, -1, -1]] + ) + config8 = np.array( + [[1, -1, -1, 1], [-1, -1, -1, -1], [-1, -1, -1, -1], [1, -1, -1, 1]] + ) + + peak_spins = [ + np.reshape(config1, -1), + np.reshape(config2, -1), + np.reshape(config3, -1), + np.reshape(config4, -1), + np.reshape(config5, -1), + np.reshape(config6, -1), + np.reshape(config7, -1), + np.reshape(config8, -1), + ] + sampler = RandomSpinBlobs( + N=N, + num_blobs=num_blobs, + peak_spins=peak_spins, + p=p, + ) + + X, y = sampler.sample(num_samples=num_samples, return_labels=True) + X = X.reshape(-1, N) + + return X, y diff --git a/src/qml_benchmarks/hyperparam_search_utils.py b/src/qml_benchmarks/hyperparam_search_utils.py index 27757d80..0d9adbaf 100644 --- a/src/qml_benchmarks/hyperparam_search_utils.py +++ b/src/qml_benchmarks/hyperparam_search_utils.py @@ -19,19 +19,24 @@ import pandas as pd -def read_data(path): +def read_data(path, labels=True): """Read data from a csv file where each row is a data sample. - The columns are the input features and the last column specifies a label. + The columns are the input features. If labels=True, the last feature is understood to be the label corresponding + to that sample. - Return a 2-d array of inputs and an array of labels, X,y. 
+ Return a 2-d array of inputs and an array of labels (if labels=True) Args: path (str): path to data """ # The data is stored on a CSV file with the last column being the label data = pd.read_csv(path, header=None) - X = data.iloc[:, :-1].values - y = data.iloc[:, -1].values + if labels: + X = data.iloc[:, :-1].values + y = data.iloc[:, -1].values + else: + X = data.iloc[:, :].values + y = None return X, y diff --git a/src/qml_benchmarks/hyperparameter_settings.py b/src/qml_benchmarks/hyperparameter_settings.py index 58a6923b..10cd26b4 100644 --- a/src/qml_benchmarks/hyperparameter_settings.py +++ b/src/qml_benchmarks/hyperparameter_settings.py @@ -191,4 +191,9 @@ "alpha": {"type": "list", "dtype": "float", "val": [0.01, 0.001, 0.0001]}, }, "Perceptron": {"eta0": {"type": "list", "dtype": "float", "val": [0.1, 1, 10]}}, + "RestrictedBoltzmannMachine": {"learning_rate": {"type": "list", "dtype": "float", "val": [0.01, 0.1, 1.]}}, + "DeepEBM": {"hidden_layers": {"type": "list", + "dtype": "tuple", + "val": ["(100,)", "(10, 10, 10, 10)", "(50, 10, 5)"], + }} } diff --git a/src/qml_benchmarks/model_utils.py b/src/qml_benchmarks/model_utils.py index 312118e2..c1b97576 100644 --- a/src/qml_benchmarks/model_utils.py +++ b/src/qml_benchmarks/model_utils.py @@ -24,6 +24,8 @@ import jax.numpy as jnp from sklearn.exceptions import ConvergenceWarning from sklearn.utils import gen_batches +import inspect +from tqdm import tqdm def train( @@ -31,7 +33,7 @@ def train( ): """ Trains a model using an optimizer and a loss function via gradient descent. We assume that the loss function - is of the form `loss(params, X, y)` and that the trainable parameters are stored in model.params_ as a dictionary + is of the form `loss(params, X, y, key)` and that the trainable parameters are stored in model.params_ as a dictionary of jnp.arrays. The optimizer should be an Optax optimizer (e.g. optax.adam). `model` must have an attribute `learning_rate` to set the initial learning rate for the gradient descent. @@ -43,10 +45,10 @@ def train( Args: model (class): Classifier class object to train. Trainable parameters must be stored in model.params_. - loss_fn (Callable): Loss function to be minimised. Must be of the form loss_fn(params, X, y). + loss_fn (Callable): Loss function to be minimised. Must be of the form loss_fn(params, X, y, key). optimizer (optax optimizer): Optax optimizer (e.g. optax.adam). X (array): Input data array of shape (n_samples, n_features) - y (array): Array of shape (n_samples) containing the labels. + y (array, optional): Array of shape (n_samples) containing the labels. random_key_generator (jax.random.PRNGKey): JAX key generator object for pseudo-randomness generation. convergence_interval (int, optional): Number of optimization steps over which to decide convergence. Larger values give a higher confidence that the model has converged but may increase training time. @@ -55,13 +57,23 @@ def train( params (dict): The new parameters after training has completed. 
""" - if not model.batch_size / model.max_vmap % 1 == 0: - raise Exception("Batch size must be multiple of max_vmap.") + if model.max_vmap is not None: + if not model.batch_size / model.max_vmap % 1 == 0: + raise Exception("Batch size must be multiple of max_vmap.") + + # wrap a key around the function if it doesn't have one + if "key" not in inspect.signature(loss_fn).parameters: + + def loss_fn_wrapped(params, x, y, key): + return loss_fn(params, x, y) + + else: + loss_fn_wrapped = loss_fn params = model.params_ opt = optimizer(learning_rate=model.learning_rate) opt_state = opt.init(params) - grad_fn = jax.grad(loss_fn) + grad_fn = jax.grad(loss_fn_wrapped) # jitting through the chunked_grad function can take a long time, # so we jit here and chunk after @@ -70,12 +82,18 @@ def train( # note: assumes that the loss function is a sample mean of # some function over the input data set - chunked_grad_fn = chunk_grad(grad_fn, model.max_vmap) - chunked_loss_fn = chunk_loss(loss_fn, model.max_vmap) + chunked_grad_fn = ( + chunk_grad(grad_fn, model.max_vmap) if model.max_vmap is not None else grad_fn + ) + chunked_loss_fn = ( + chunk_loss(loss_fn_wrapped, model.max_vmap) + if model.max_vmap is not None + else loss_fn_wrapped + ) - def update(params, opt_state, x, y): - grads = chunked_grad_fn(params, x, y) - loss_val = chunked_loss_fn(params, x, y) + def update(params, opt_state, x, y, key): + grads = chunked_grad_fn(params, x, y, key) + loss_val = chunked_loss_fn(params, x, y, key) updates, opt_state = opt.update(grads, opt_state) params = optax.apply_updates(params, updates) return params, opt_state, loss_val @@ -83,39 +101,48 @@ def update(params, opt_state, x, y): loss_history = [] converged = False start = time.time() - for step in range(model.max_steps): - key = random_key_generator() - X_batch, y_batch = get_batch(X, y, key, batch_size=model.batch_size) - params, opt_state, loss_val = update(params, opt_state, X_batch, y_batch) - loss_history.append(loss_val) - logging.debug(f"{step} - loss: {loss_val}") - - if np.isnan(loss_val): - logging.info(f"nan encountered. Training aborted.") - break - - # decide convergence - if step > 2 * convergence_interval: - # get means of last two intervals and standard deviation of last interval - average1 = np.mean(loss_history[-convergence_interval:]) - average2 = np.mean( - loss_history[-2 * convergence_interval : -convergence_interval] + with tqdm(total=model.max_steps, desc="Training Progress") as pbar: + for step in range(model.max_steps): + key_batch = random_key_generator() + key_loss = jax.random.split(key_batch, 1)[0] + X_batch, y_batch = get_batch(X, y, key_batch, batch_size=model.batch_size) + params, opt_state, loss_val = update( + params, opt_state, X_batch, y_batch, key_loss ) - std1 = np.std(loss_history[-convergence_interval:]) - # if the difference in averages is small compared to the statistical fluctuations, stop training. - if np.abs(average2 - average1) <= std1 / np.sqrt(convergence_interval) / 2: - logging.info( - f"Model {model.__class__.__name__} converged after {step} steps." - ) - converged = True + loss_history.append(loss_val) + logging.debug(f"{step} - loss: {loss_val}") + pbar.update(1) + + if np.isnan(loss_val): + logging.info(f"nan encountered. 
Training aborted.") break + # decide convergence + if convergence_interval is not None: + if step > 2 * convergence_interval: + # get means of last two intervals and standard deviation of last interval + average1 = np.mean(loss_history[-convergence_interval:]) + average2 = np.mean( + loss_history[-2 * convergence_interval : -convergence_interval] + ) + std1 = np.std(loss_history[-convergence_interval:]) + # if the difference in averages is small compared to the statistical fluctuations, stop training. + if ( + np.abs(average2 - average1) + <= std1 / np.sqrt(convergence_interval) / 2 + ): + logging.info( + f"Model {model.__class__.__name__} converged after {step} steps." + ) + converged = True + break + end = time.time() loss_history = np.array(loss_history) model.loss_history_ = loss_history / np.max(np.abs(loss_history)) model.training_time_ = end - start - if not converged: + if not converged and convergence_interval is not None: print("Loss did not converge:", loss_history) raise ConvergenceWarning( f"Model {model.__class__.__name__} has not converged after the maximum number of {model.max_steps} steps." @@ -142,7 +169,11 @@ def get_batch(X, y, rnd_key, batch_size=32): rnd_indices = jax.random.choice( key=rnd_key, a=all_indices, shape=(batch_size,), replace=True ) - return X[rnd_indices], y[rnd_indices] + + if y is not None: + return X[rnd_indices], y[rnd_indices] + else: + return X[rnd_indices], None def get_from_dict(dict, key_list): @@ -238,7 +269,7 @@ def chunk_grad(grad_fn, max_vmap): """ Convert a `jax.grad` function to an equivalent version that evaluated in chunks of size max_vmap. - `grad_fn` should be of the form `jax.grad(fn(params, X, y), argnums=0)`, where `params` is a + `grad_fn` should be of the form `jax.grad(fn(params, X, y, key), argnums=0)`, where `params` is a dictionary of `jnp.arrays`, `X, y` are `jnp.arrays` with the same-size leading axis, and `grad_fn` is a function that is vectorised along these axes (i.e. `in_axes = (None,0,0)`). @@ -253,9 +284,9 @@ def chunk_grad(grad_fn, max_vmap): chunked version of the function """ - def chunked_grad(params, X, y): + def chunked_grad(params, X, y, key): batch_slices = list(gen_batches(len(X), max_vmap)) - grads = [grad_fn(params, X[slice], y[slice]) for slice in batch_slices] + grads = [grad_fn(params, X[slice], y[slice], key) for slice in batch_slices] grad_dict = {} for key_list in get_nested_keys(params): set_in_dict( @@ -272,7 +303,7 @@ def chunked_grad(params, X, y): def chunk_loss(loss_fn, max_vmap): """ - Converts a loss function of the form `loss_fn(params, array1, array2)` to an equivalent version that + Converts a loss function of the form `loss_fn(params, array1, array2, key)` to an equivalent version that evaluates `loss_fn` in chunks of size max_vmap. `loss_fn` should batch evaluate along the leading axis of `array1, array2` (i.e. `in_axes = (None,0,0)`). 
@@ -284,11 +315,105 @@ def chunk_loss(loss_fn, max_vmap): chunked version of the function """ - def chunked_loss(params, X, y): + def chunked_loss(params, X, y, key): batch_slices = list(gen_batches(len(X), max_vmap)) res = jnp.array( - [loss_fn(params, *[X[slice], y[slice]]) for slice in batch_slices] + [loss_fn(params, *[X[slice], y[slice]], key) for slice in batch_slices] ) return jnp.mean(res) return chunked_loss + + +def mmd_loss( + ground_truth: np.ndarray, model_samples: np.ndarray, sigma: float +) -> float: + """Calculates an unbiased estimate of the Maximum Mean Discrepancy (MMD) loss from samples + see https://jmlr.org/papers/volume13/gretton12a/gretton12a.pdf for more info + + Args: + ground_truth (np.ndarray): Samples from the ground truth distribution. + model_samples (np.ndarray): Samples from the test model. + sigma (float): Sigma parameter, the width of the kernel. + + Returns: + float: The value of the MMD loss. + """ + + n = len(ground_truth) + m = len(model_samples) + ground_truth = jnp.array(ground_truth) + model_samples = jnp.array(model_samples) + + # K_pp + K_pp = jnp.zeros((ground_truth.shape[0], ground_truth.shape[0])) + + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set( + gaussian_kernel(sigma, ground_truth[i], ground_truth[j]) + ) + + return jax.lax.fori_loop(0, ground_truth.shape[0], inner_body_fun, val) + + K_pp = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pp) + sum_pp = jnp.sum(K_pp) - n + + # K_pq + K_pq = jnp.zeros((ground_truth.shape[0], model_samples.shape[0])) + + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set( + gaussian_kernel(sigma, ground_truth[i], model_samples[j]) + ) + + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + + K_pq = jax.lax.fori_loop(0, ground_truth.shape[0], body_fun, K_pq) + sum_pq = jnp.sum(K_pq) + + # K_qq + K_qq = jnp.zeros((model_samples.shape[0], model_samples.shape[0])) + + def body_fun(i, val): + def inner_body_fun(j, inner_val): + return inner_val.at[i, j].set( + gaussian_kernel(sigma, model_samples[i], model_samples[j]) + ) + + return jax.lax.fori_loop(0, model_samples.shape[0], inner_body_fun, val) + + K_qq = jax.lax.fori_loop(0, model_samples.shape[0], body_fun, K_qq) + sum_qq = jnp.sum(K_qq) - m + + return 1 / n / (n - 1) * sum_pp - 2 / n / m * sum_pq + 1 / m / (m - 1) * sum_qq + + +def gaussian_kernel(sigma: float, x: np.ndarray, y: np.ndarray) -> float: + """Calculates the value for the gaussian kernel between two vectors x, y + + Args: + sigma (float): sigma parameter, the width of the kernel + x (np.ndarray): one of the vectors + y (np.ndarray): the other vector + + Returns: + float: Result value of the gaussian kernel + """ + return jnp.exp(-((x - y) ** 2).sum() / 2 / sigma) + + +def median_heuristic(X): + """ + Computes an estimate of the median heuristic used to decide the bandwidth of the RBF kernels; see + https://arxiv.org/abs/1707.07269 + :param X (array): Dataset of interest + :return (float): median heuristic estimate + """ + m = len(X) + X = np.array(X) + med = np.median( + [np.sqrt(np.sum((X[i] - X[j]) ** 2)) for i in range(m) for j in range(m)] + ) + return med diff --git a/src/qml_benchmarks/models/__init__.py b/src/qml_benchmarks/models/__init__.py index 2cabcec6..28f57d05 100644 --- a/src/qml_benchmarks/models/__init__.py +++ b/src/qml_benchmarks/models/__init__.py @@ -20,8 +20,8 @@ ) from qml_benchmarks.models.data_reuploading import ( DataReuploadingClassifier, - 
DataReuploadingClassifierNoScaling, DataReuploadingClassifierNoCost, + DataReuploadingClassifierNoScaling, DataReuploadingClassifierNoTrainableEmbedding, DataReuploadingClassifierSeparable, ) @@ -30,27 +30,29 @@ DressedQuantumCircuitClassifierOnlyNN, DressedQuantumCircuitClassifierSeparable, ) - from qml_benchmarks.models.iqp_kernel import IQPKernelClassifier from qml_benchmarks.models.iqp_variational import IQPVariationalClassifier from qml_benchmarks.models.projected_quantum_kernel import ProjectedQuantumKernel from qml_benchmarks.models.quantum_boltzmann_machine import ( QuantumBoltzmannMachine, - QuantumBoltzmannMachineSeparable + QuantumBoltzmannMachineSeparable, ) from qml_benchmarks.models.quantum_kitchen_sinks import QuantumKitchenSinks from qml_benchmarks.models.quantum_metric_learning import QuantumMetricLearner -from qml_benchmarks.models.quanvolutional_neural_network import QuanvolutionalNeuralNetwork +from qml_benchmarks.models.quanvolutional_neural_network import ( + QuanvolutionalNeuralNetwork, +) from qml_benchmarks.models.separable import ( - SeparableVariationalClassifier, SeparableKernelClassifier, + SeparableVariationalClassifier, ) +from qml_benchmarks.models.energy_based_model import DeepEBM, RestrictedBoltzmannMachine from qml_benchmarks.models.tree_tensor import TreeTensorClassifier from qml_benchmarks.models.vanilla_qnn import VanillaQNN from qml_benchmarks.models.weinet import WeiNet -from sklearn.svm import SVC as SVC_base from sklearn.neural_network import MLPClassifier as MLP +from sklearn.svm import SVC as SVC_base __all__ = [ "CircuitCentricClassifier", @@ -78,35 +80,37 @@ "WeiNet", "MLPClassifier", "SVC", + "DeepEBM", + "RestrictedBoltzmannMachine", ] class MLPClassifier(MLP): def __init__( - self, - hidden_layer_sizes=(100, 100), - activation="relu", - solver="adam", - alpha=0.0001, - batch_size="auto", - learning_rate="constant", - learning_rate_init=0.001, - power_t=0.5, - max_iter=3000, - shuffle=True, - random_state=None, - tol=1e-4, - verbose=False, - warm_start=False, - momentum=0.9, - nesterovs_momentum=True, - early_stopping=False, - validation_fraction=0.1, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - n_iter_no_change=10, - max_fun=15000, + self, + hidden_layer_sizes=(100, 100), + activation="relu", + solver="adam", + alpha=0.0001, + batch_size="auto", + learning_rate="constant", + learning_rate_init=0.001, + power_t=0.5, + max_iter=3000, + shuffle=True, + random_state=None, + tol=1e-4, + verbose=False, + warm_start=False, + momentum=0.9, + nesterovs_momentum=True, + early_stopping=False, + validation_fraction=0.1, + beta_1=0.9, + beta_2=0.999, + epsilon=1e-8, + n_iter_no_change=10, + max_fun=15000, ): super().__init__( hidden_layer_sizes=hidden_layer_sizes, @@ -137,16 +141,16 @@ def __init__( class SVC(SVC_base): def __init__( - self, - C=1.0, - degree=3, - gamma="scale", - coef0=0.0, - shrinking=True, - probability=False, - tol=0.001, - max_iter=-1, - random_state=None, + self, + C=1.0, + degree=3, + gamma="scale", + coef0=0.0, + shrinking=True, + probability=False, + tol=0.001, + max_iter=-1, + random_state=None, ): super().__init__( C=C, diff --git a/src/qml_benchmarks/models/base.py b/src/qml_benchmarks/models/base.py new file mode 100644 index 00000000..6738ec53 --- /dev/null +++ b/src/qml_benchmarks/models/base.py @@ -0,0 +1,270 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base classes for models.""" + +import copy +from abc import abstractmethod + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from qml_benchmarks.model_utils import train +from sklearn.base import BaseEstimator + + +class BaseGenerator(BaseEstimator): + """ + A base class for generative models. + + We use Scikit-learn's `BaseEstimator` so that we can take advantage of + Scikit-learn's hyperparameter search algorithms like `GridSearchCV` or + `RandomizedSearchCV` for hyperparameter tuning. + + Args: + dim (int): dimension of the data (i.e. the number of features) + """ + + def __init__(self, dim: int) -> None: + self.dim = dim + + @abstractmethod + def initialize(self, x: any = None): + """ + Initialize the model and create the model parameters. + + Args: + x: batch of data to use to initialize the model + """ + pass + + @abstractmethod + def sample(self, num_samples: int) -> any: + """ + Sample from the model. + + Args: + num_samples: The number of samples to generate. + """ + pass + + +class EnergyBasedModel(BaseGenerator): + r""" + A base class for energy-based generative models with common functionalities. + + The class implements MCMC sampling via the energy function. This is used to sample from the model and to train + the model via k-contrastive divergence (see eqn (3) of arXiv:2101.03288). + + We use Scikit-learn's `BaseEstimator` so that we can take advantage of + Scikit-learn's hyperparameter search algorithms like `GridSearchCV` or + `RandomizedSearchCV` for hyperparameter tuning. + + The parameters of the model are stored in the `params_` attribute. The + model hyperparameters are explicitly passed to the constructor. See the + `BaseEstimator` documentation for more details. + + References: + Yang Song, Diederik P. Kingma + "How to Train Your Energy-Based Models" + arXiv:2101.03288 + + Args: + dim (int): dimension of the data (i.e. number of features) + cdiv_steps (int): number of mcmc steps to perform to estimate the constrastive divergence loss (default 1) + convergence_interval (int): The number of loss values to consider to decide convergence. + max_steps (int): Maximum number of training steps. A warning will be raised if training did not converge. + learning_rate (float): Initial learning rate for training. + batch_size (int): Size of batches used for computing parameter updates. + jit (bool): Whether to use just in time compilation. + random_state (int): Seed used for pseudorandom number generation. 
+ """ + + def __init__( + self, + dim: int = None, + learning_rate=0.001, + batch_size=32, + max_steps=10000, + cdiv_steps=1, + convergence_interval=None, + random_state=42, + jit=True, + ) -> None: + self.learning_rate = learning_rate + self.jit = jit + self.batch_size = batch_size + self.max_steps = max_steps + self.cdiv_steps = cdiv_steps + self.convergence_interval = convergence_interval + self.vmap = True + self.max_vmap = None + self.random_state = random_state + self.rng = np.random.default_rng(random_state) + + # data depended attributes + self.params_: dict[str : jnp.array] = None + self.dim = dim # initialized to None + + # Train dependent attributes that the function train in self.fit() sets. + self.history_: list[float] = None + self.training_time_: float = None + + # jax transformations of class functions + self.mcmc_step = jax.jit(self.mcmc_step) if self.jit else self.mcmc_step + self.batched_mcmc_sample = jax.vmap( + self.mcmc_sample, in_axes=(None, 0, None, 0) + ) + + def generate_key(self): + return jax.random.PRNGKey(self.rng.integers(1000000)) + + @abstractmethod + def energy(self, params: dict, x: any) -> float: + """ + The energy function for the model for a batch of configurations x. + This function should be implemented by the subclass. + + Args: + params: model parameters that determine the energy function. + x: batch of configurations of shape (n_batch, dim) for which to calculate the energy + Returns: + energy (Array): Array of energies of shape (n_batch,) + """ + pass + + def mcmc_step(self, args, i): + """ + Perform one metropolis hastings steps. + The format is such that it can be used with jax.lax.scan for fast compilation. + """ + params, key, x = args + key1, key2 = jax.random.split(key, 2) + + # flip a random bit + flip_idx = jax.random.choice(key1, jnp.arange(self.dim)) + flip_config = jnp.zeros(self.dim, dtype="int32") + flip_config = flip_config.at[flip_idx].set(1) + x_flip = jnp.array((x + flip_config) % 2) + + en = self.energy(params, jnp.expand_dims(x, 0))[0] + en_flip = self.energy(params, jnp.expand_dims(x_flip, 0))[0] + + accept_ratio = jnp.exp(-en_flip) / jnp.exp(-en) + accept = jnp.array(jax.random.bernoulli(key2, accept_ratio), dtype=int)[0] + x_new = accept * x_flip + (1 - accept) * x + return [params, key2, x_new], x_new + + def mcmc_sample(self, params, x_init, num_mcmc_steps, key): + """ + Sample a chain of configurations from a starting configuration x_init + """ + carry = [params, key, x_init] + carry, configs = jax.lax.scan(self.mcmc_step, carry, jnp.arange(num_mcmc_steps)) + return configs + + def sample(self, num_samples, num_steps=1000, max_chunk_size=100): + """ + Sample configurations starting from a random configuration. + Each sample is generated by sampling a random configuration and perforning a number of mcmc updates. + Args: + num_samples (int): number of samples to draw + num_steps (int): number of mcmc steps before drawing a sample + max_chunk_size (int): maximum number of samples to vmap the sampling for at a time (large values + use significant memory) + """ + if self.params_ is None: + raise ValueError( + "Model not initialized. Call model.initialize first with" + "example data sample." 
+ ) + keys = jax.random.split(self.generate_key(), num_samples) + + x_init = jnp.array( + jax.random.bernoulli( + self.generate_key(), p=0.5, shape=(num_samples, self.dim) + ), + dtype=int, + ) + + # chunk the sampling, otherwise the vmap can blow the memory + num_chunks = num_steps // max_chunk_size + 1 + x_init = jnp.array_split(x_init, num_chunks) + keys = jnp.array_split(keys, num_chunks) + configs = [] + for elem in zip(x_init, keys): + new_configs = self.batched_mcmc_sample( + self.params_, elem[0], num_steps, elem[1] + ) + configs.append(new_configs[:, -1]) + + configs = jnp.concatenate(configs) + return configs + + def contrastive_divergence_loss(self, params, X, y, key): + """ + Implementation of the standard contrastive divergence loss function (see eqn 3 of arXiv:2101.03288). + Args: + X (array): batch of training examples + y (array): not used; should be set to None when training + key: JAX PRNG key used for MCMC sampling + """ + keys = jax.random.split(key, X.shape[0]) + + # we do not take the gradient wrt the sampling, so decouple the param dict here + params_copy = copy.deepcopy(params) + for key in params_copy.keys(): + params_copy[key] = jax.lax.stop_gradient(params_copy[key]) + + configs = self.batched_mcmc_sample(params_copy, X, self.cdiv_steps, keys) + x0 = X + x1 = configs[:, -1] + + # taking the gradient of this loss is equivalent to the CD-k update + loss = self.energy(params, x0) - self.energy(params, x1) + + return jnp.mean(loss) + + def fit(self, X: jnp.array, y: any = None) -> None: + """ + Fit the parameters and update self.params_. + """ + self.initialize(X) + c_div_loss = ( + jax.jit(self.contrastive_divergence_loss) + if self.jit + else self.contrastive_divergence_loss + ) + + self.params_ = train( + self, + c_div_loss, + optax.adam, + X, + None, + self.generate_key, + convergence_interval=self.convergence_interval, + ) + + @abstractmethod + def score(self, X, y=None) -> any: + """ + Score function to be used with hyperparameter optimization (larger score => better) + + Args: + X: Dataset to calculate score for + y: labels (set to None for generative models to interface with sklearn functionality) + """ + pass diff --git a/src/qml_benchmarks/models/energy_based_model.py b/src/qml_benchmarks/models/energy_based_model.py new file mode 100644 index 00000000..94a29759 --- /dev/null +++ b/src/qml_benchmarks/models/energy_based_model.py @@ -0,0 +1,241 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +from qml_benchmarks.models.base import EnergyBasedModel, BaseGenerator +from sklearn.neural_network import BernoulliRBM +from qml_benchmarks.model_utils import mmd_loss, median_heuristic +import numpy as np + + +class MLP(nn.Module): + "Multilayer perceptron implemented in flax" + # Create a MLP with hidden layers and neurons specfied as a list of integers. 
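+    # e.g. hidden_layers=[8, 4] maps the input through Dense(8) and Dense(4) (each followed by tanh)
+    # before a final Dense(1) layer that outputs a single scalar energy value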
+ hidden_layers: list[int] + + @nn.compact + def __call__(self, x): + for dim in self.hidden_layers: + x = nn.Dense(dim)(x) + x = nn.tanh(x) + x = nn.Dense(1)(x) + return x + + +class DeepEBM(EnergyBasedModel): + """ + Energy-based model which uses a fully connected multi-layer perceptron neural network as its energy function. + The model is trained via k-contrastive divergence. + The score function corresponds to the (negative of the) maximum mean discrepancy distance. + + Args: + learning_rate (float): Initial learning rate for training. + batch_size (int): Size of batches used for computing parameter updates. + max_steps (int): Maximum number of training steps. A warning will be raised if training did not converge. + cdiv_steps (int): number of mcmc steps to perform to estimate the constrastive divergence loss (default 1) + convergence_interval (int): The number of loss values to consider to decide convergence. + jit (bool): Whether to use just in time compilation. + random_state (int): Seed used for pseudorandom number generation. + hidden_layers (list[int]): + The number of hidden layers and neurons in the MLP layers. e.g. [8,4] uses a three layer network where the + first layers maps to 8 neurons, the second to 4, and the last layer to 1 neuron. + mmd_kwargs (dict): arguments used for the maximum mean discrepancy score. n_samples and n_steps are the args + sent to self.sample when sampling configurations used for evaluation. sigma is the bandwidth of the + maximum mean discrepancy. + """ + + def __init__( + self, + learning_rate=0.001, + batch_size=32, + max_steps=10000, + cdiv_steps=1, + convergence_interval=None, + random_state=42, + jit=True, + hidden_layers=[8, 4], + mmd_kwargs={"n_samples": 1000, "n_steps": 1000, "sigma": 1.0}, + ): + super().__init__( + dim=None, + learning_rate=learning_rate, + batch_size=batch_size, + max_steps=max_steps, + cdiv_steps=cdiv_steps, + convergence_interval=convergence_interval, + random_state=random_state, + jit=jit, + ) + self.hidden_layers = hidden_layers + self.mmd_kwargs = mmd_kwargs + + def initialize(self, x): + dim = x.shape[1] + if not isinstance(dim, int): + raise NotImplementedError( + "The model is not yet implemented for data" + "with arbitrary dimensions. `dim` must be an integer." + ) + + self.dim = dim + self.model = MLP(hidden_layers=self.hidden_layers) + self.params_ = self.model.init(self.generate_key(), x) + + def energy(self, params, x): + """ + energy function + Args: + params: parameters of the neural network to be passed to flax + x: batch of configurations of shape (n_batch, dim) + returns: + batch of energy values + """ + return self.model.apply(params, x) + + def score(self, X: np.ndarray, y=None) -> float: + """ + Maximum mean discrepancy score function + Args: + X (Array): batch of test samples to evalute the model against. + """ + sigma = self.mmd_kwargs["sigma"] + sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma + score = np.mean( + [ + mmd_loss( + X, + self.sample( + self.mmd_kwargs["n_samples"], self.mmd_kwargs["n_steps"] + ), + sigma, + ) + for sigma in sigmas + ] + ) + return float(-score) + + +class RestrictedBoltzmannMachine(BernoulliRBM, BaseGenerator): + """ + Implementation of a restricted Boltzmann machine. The model wraps the scikit-learn BernoulliRBM class and is + trained via constrastive divergence. 
+ Args: + n_components (int): Number of hidden units in the RBM + learning_rate (float): learning rate for training + batch_size (int): batch size for training + n_iter (int): number of epochs of training + verbose (int): verbosity level + random_state (int): random seed used for reproducibility + score_fn (str): determinies the score function used in hyperparameter optimization. If 'pseudolikelihood' + sklearn's pseudolikelihood function is used, if 'mmd' the (negative of) the maximum mean discrepancy is used. + mmd_kwargs (dict): arguments used for the maximum mean discrepancy score. n_samples and n_steps are the args + sent to self.sample when sampling configurations used for evaluation. sigma is the bandwidth of the + maximum mean discrepancy. + + """ + + def __init__( + self, + n_components=256, + learning_rate=0.0001, + batch_size=10, + n_iter=10, + verbose=0, + random_state=42, + score_fn="pseudolikelihood", + mmd_kwargs={"n_samples": 1000, "n_steps": 1000, "sigma": 1.0}, + ): + super().__init__( + n_components=n_components, + learning_rate=learning_rate, + batch_size=batch_size, + n_iter=n_iter, + verbose=verbose, + random_state=random_state, + ) + self.score_fn = score_fn + self.mmd_kwargs = mmd_kwargs + self.rng = np.random.default_rng(random_state) + + def initialize(self, X: any = None): + if len(X.shape) > 2: + raise ValueError("Input data must be 2D") + self.dim = X.shape[1] + + def fit(self, X, y=None): + """ + fit the model using k-contrastive divergence. simply wraps the sklearn fit function. + Args: + X (np.array): training data + y: not used; set to None to interface with sklearn correctly. + """ + self.initialize(X) + super().fit(X, y) + + # Gibbs sampling: + def _sample(self, init_configs, num_steps=1000): + """ + Sample the model for given number of steps via the .gibbs method of sklean's RBM. The initial configuration + is sampled randomly. + + Args: + num_steps (int): Number of Gibbs sample steps + + Returns: + np.array: The sampled configurations + """ + if self.dim is None: + raise ValueError("Model must be initialized before sampling") + v = init_configs + for _ in range(num_steps): + v = self.gibbs(v) # Assuming `gibbs` is an instance method + return v + + def sample(self, num_samples: int, num_steps: int = 1000) -> np.ndarray: + """ + Sample the model. Each sample is generated by sampling a random configuration and performing a number of + Gibbs sampling steps. We use joblib to parallelize the sampling. + Args: + num_samples (int): number of samples to return + num_steps (int): number of Gibbs sampling steps for each sample + """ + init_configs = self.rng.choice([0, 1], size=(num_samples, self.dim,)) + samples_t = self._sample(init_configs, num_steps=num_steps) + samples_t = np.array(samples_t, dtype=int) + return samples_t + + def score(self, X: np.ndarray, y: np.ndarray = None) -> float: + """ + Score function for hyperparameter optimization. + Args: + X (Array): batch of test samples to evalute the model against. + """ + if self.score_fn == "pseudolikelihood": + return float(np.mean(super().score_samples(X))) + elif self.score_fn == "mmd": + sigma = self.mmd_kwargs["sigma"] + sigmas = [sigma] if isinstance(sigma, (int, float)) else sigma + score = np.mean( + [ + mmd_loss( + X, + self.sample( + self.mmd_kwargs["n_samples"], self.mmd_kwargs["n_steps"] + ), + sigma, + ) + for sigma in sigmas + ] + ) + return float(-score)
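
A rough end-to-end sketch of how the new generative pieces fit together (the couplings, temperature, and training settings below are arbitrary illustrative choices, not values used in the repository):

```python
import numpy as np
from qml_benchmarks.data import generate_ising
from qml_benchmarks.models import DeepEBM

# toy 16-spin Ising dataset with random symmetric couplings and zero bias
N = 16
rng = np.random.default_rng(0)
J = rng.normal(scale=0.5, size=(N, N))
J = (J + J.T) / 2
b = np.zeros(N)

X, _ = generate_ising(N, num_samples=200, J=J, b=b, T=2.0)  # binary samples, no labels

model = DeepEBM(hidden_layers=[16, 8], max_steps=2000)
model.fit(X)                      # trained via contrastive divergence
new_samples = model.sample(100)   # draw binary configurations from the fitted model
print(model.score(X))             # negative maximum mean discrepancy (larger is better)
```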