diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index db92dcf79b3..df7004f00f9 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,13 @@

New features since last release

+* Introduced `sample_probs` function for the `qml.devices.qubit` and `qml.devices.qutrit_mixed` modules: + - This function takes probability distributions as input and returns sampled outcomes. + - Simplifies the sampling process by separating it from other operations in the measurement chain. + - Improves modularity: The same code can be easily adapted for other devices (e.g., a potential `default_mixed` device). + - Enhances maintainability by isolating the sampling logic. + [(#6354)](https://github.com/PennyLaneAI/pennylane/pull/6354) + * `qml.transforms.decompose` is added for stepping through decompositions to a target gate set. [(#6334)](https://github.com/PennyLaneAI/pennylane/pull/6334) diff --git a/pennylane/devices/qubit/__init__.py b/pennylane/devices/qubit/__init__.py index 860d91953c4..b12ad79640a 100644 --- a/pennylane/devices/qubit/__init__.py +++ b/pennylane/devices/qubit/__init__.py @@ -26,15 +26,16 @@ measure measure_with_samples sample_state + sample_probs simulate adjoint_jacobian adjoint_jvp adjoint_vjp """ -from .apply_operation import apply_operation from .adjoint_jacobian import adjoint_jacobian, adjoint_jvp, adjoint_vjp +from .apply_operation import apply_operation from .initialize_state import create_initial_state from .measure import measure -from .sampling import sample_state, measure_with_samples -from .simulate import simulate, get_final_state, measure_final_state +from .sampling import measure_with_samples, sample_probs, sample_state +from .simulate import get_final_state, measure_final_state, simulate diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py index be6090a0354..85c89de3701 100644 --- a/pennylane/devices/qubit/sampling.py +++ b/pennylane/devices/qubit/sampling.py @@ -471,51 +471,70 @@ def sample_state( Returns: ndarray[int]: Sample values of the shape (shots, num_wires) """ - if prng_key is not None or qml.math.get_interface(state) == "jax": - return _sample_state_jax( - state, shots, prng_key, is_state_batched=is_state_batched, wires=wires, seed=rng - ) - - rng = np.random.default_rng(rng) total_indices = len(state.shape) - is_state_batched state_wires = qml.wires.Wires(range(total_indices)) wires_to_sample = wires or state_wires num_wires = len(wires_to_sample) - basis_states = np.arange(2**num_wires) flat_state = flatten_state(state, total_indices) with qml.queuing.QueuingManager.stop_recording(): probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires) + # Keep same interface (e.g. jax) as in the device - # when using the torch interface with float32 as default dtype, - # probabilities must be renormalized as they may not sum to one - # see https://github.com/PennyLaneAI/pennylane/issues/5444 - norm = qml.math.sum(probs, axis=-1) - abs_diff = qml.math.abs(norm - 1.0) - cutoff = 1e-07 + return sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key) - if is_state_batched: - normalize_condition = False - for s in abs_diff: - if s != 0: - normalize_condition = True - if s > cutoff: - normalize_condition = False - break +def sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key=None): + """ + Sample from given probabilities, dispatching between JAX and NumPy implementations. + + Args: + probs (array): The probabilities to sample from + shots (int): The number of samples to take + num_wires (int): The number of wires to sample + is_state_batched (bool): whether the state is batched or not + rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): + A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. + If no value is provided, a default RNG will be used + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + """ + if qml.math.get_interface(probs) == "jax" or prng_key is not None: + return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, seed=rng) + + return _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng) + + +def _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng): + """ + Sample from given probabilities using NumPy's random number generator. + + Args: + probs (array): The probabilities to sample from + shots (int): The number of samples to take + num_wires (int): The number of wires to sample + is_state_batched (bool): whether the state is batched or not + rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): + A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``. + If no value is provided, a default RNG will be used + """ + rng = np.random.default_rng(rng) + norm = qml.math.sum(probs, axis=-1) + norm_err = qml.math.abs(norm - 1.0) + cutoff = 1e-07 - if normalize_condition: - probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm + norm_err = norm_err[..., np.newaxis] if not is_state_batched else norm_err + if qml.math.any(norm_err > cutoff): + raise ValueError("probabilities do not sum to 1") - # rng.choice doesn't support broadcasting + basis_states = np.arange(2**num_wires) + if is_state_batched: + probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs]) else: - if not 0 < abs_diff < cutoff: - norm = 1.0 probs = probs / norm - samples = rng.choice(basis_states, shots, p=probs) powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1] @@ -523,26 +542,19 @@ def sample_state( return (states_sampled_base_ten > 0).astype(np.int64) -# pylint:disable = unused-argument -def _sample_state_jax( - state, - shots: int, - prng_key, - is_state_batched: bool = False, - wires=None, - seed=None, -) -> np.ndarray: +def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None, seed=None): """ Returns a series of samples of a state for the JAX interface based on the PRNG. Args: - state (array[complex]): A state vector to be sampled + probs (array): The probabilities to sample from shots (int): The number of samples to take - prng_key (jax.random.PRNGKey): A``jax.random.PRNGKey``. This is - the key to the JAX pseudo random number generator. + num_wires (int): The number of wires to sample is_state_batched (bool): whether the state is batched or not - wires (Sequence[int]): The wires to sample - seed (numpy.random.Generator): seed to use to generate a key if a ``prng_key`` is not present. ``None`` by default. + prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is + the key to the JAX pseudo random number generator. Only for simulation using JAX. + seed (Optional[int]): A seed for the random number generator. This is only used if ``prng_key`` + is not provided. Returns: ndarray[int]: Sample values of the shape (shots, num_wires) @@ -554,19 +566,10 @@ def _sample_state_jax( if prng_key is None: prng_key = jax.random.PRNGKey(np.random.default_rng(seed).integers(100000)) - total_indices = len(state.shape) - is_state_batched - state_wires = qml.wires.Wires(range(total_indices)) - - wires_to_sample = wires or state_wires - num_wires = len(wires_to_sample) - basis_states = np.arange(2**num_wires) - - flat_state = flatten_state(state, total_indices) - with qml.queuing.QueuingManager.stop_recording(): - probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires) + basis_states = jnp.arange(2**num_wires) if is_state_batched: - keys = jax_random_split(prng_key, num=len(state)) + keys = jax_random_split(prng_key, num=probs.shape[0]) samples = jnp.array( [ jax.random.choice(_key, basis_states, shape=(shots,), p=prob) @@ -577,6 +580,6 @@ def _sample_state_jax( _, key = jax_random_split(prng_key) samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs) - powers_of_two = 1 << np.arange(num_wires, dtype=int)[::-1] + powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1] states_sampled_base_ten = samples[..., None] & powers_of_two - return (states_sampled_base_ten > 0).astype(int) + return (states_sampled_base_ten > 0).astype(jnp.int64) diff --git a/pennylane/devices/qutrit_mixed/__init__.py b/pennylane/devices/qutrit_mixed/__init__.py index 192b5a1b65a..fb7377287cc 100644 --- a/pennylane/devices/qutrit_mixed/__init__.py +++ b/pennylane/devices/qutrit_mixed/__init__.py @@ -33,5 +33,5 @@ from .apply_operation import apply_operation from .initialize_state import create_initial_state from .measure import measure -from .sampling import sample_state, measure_with_samples +from .sampling import sample_state, measure_with_samples, sample_probs from .simulate import simulate, get_final_state, measure_final_state diff --git a/pennylane/devices/qutrit_mixed/sampling.py b/pennylane/devices/qutrit_mixed/sampling.py index c5395e80d48..0e5e661f085 100644 --- a/pennylane/devices/qutrit_mixed/sampling.py +++ b/pennylane/devices/qutrit_mixed/sampling.py @@ -250,25 +250,74 @@ def _sample_state_jax( ndarray[int]: Sample values of the shape (shots, num_wires) """ # pylint: disable=import-outside-toplevel - import jax - import jax.numpy as jnp - - key = prng_key total_indices = get_num_wires(state, is_state_batched) state_wires = qml.wires.Wires(range(total_indices)) wires_to_sample = wires or state_wires num_wires = len(wires_to_sample) - basis_states = np.arange(QUDIT_DIM**num_wires) with qml.queuing.QueuingManager.stop_recording(): probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors) + state_len = len(state) + + return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len) + + +def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len): + """ + Sample from a probability distribution for a qutrit system using JAX. + + This function generates samples based on the given probability distribution + for a qutrit system with a specified number of wires. It can handle both + batched and non-batched probability distributions. This function uses JAX + for potential GPU acceleration and improved performance. + + Args: + probs (jnp.ndarray): Probability distribution to sample from. For non-batched + input, this should be a 1D array of length QUDIT_DIM**num_wires. For + batched input, this should be a 2D array where each row is a separate + probability distribution. + shots (int): Number of samples to generate. + num_wires (int): Number of wires in the qutrit system. + is_state_batched (bool): Whether the input probabilities are batched. + prng_key (jax.random.PRNGKey): JAX PRNG key for random number generation. + state_len (int): Length of the state (relevant for batched inputs). + + Returns: + jnp.ndarray: An array of samples. For non-batched input, the shape is + (shots, num_wires). For batched input, the shape is + (batch_size, shots, num_wires). + + Example: + >>> import jax + >>> import jax.numpy as jnp + >>> probs = jnp.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system + >>> shots = 1000 + >>> num_wires = 1 + >>> is_state_batched = False + >>> prng_key = jax.random.PRNGKey(42) + >>> state_len = 1 + >>> samples = _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len) + >>> samples.shape + (1000, 1) + + Note: + This function requires JAX to be installed. It internally imports JAX + and its numpy module (jnp). + """ + # pylint: disable=import-outside-toplevel + import jax + import jax.numpy as jnp + + key = prng_key + + basis_states = np.arange(QUDIT_DIM**num_wires) if is_state_batched: # Produce separate keys for each of the probabilities along the broadcasted axis keys = [] - for _ in state: + for _ in range(state_len): key, subkey = jax.random.split(key) keys.append(subkey) samples = jnp.array( @@ -323,18 +372,54 @@ def sample_state( readout_errors=readout_errors, ) - rng = np.random.default_rng(rng) - total_indices = get_num_wires(state, is_state_batched) state_wires = qml.wires.Wires(range(total_indices)) wires_to_sample = wires or state_wires num_wires = len(wires_to_sample) - basis_states = np.arange(QUDIT_DIM**num_wires) with qml.queuing.QueuingManager.stop_recording(): probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors) + return sample_probs(probs, shots, num_wires, is_state_batched, rng) + + +def sample_probs(probs, shots, num_wires, is_state_batched, rng): + """ + Sample from a probability distribution for a qutrit system. + + This function generates samples based on the given probability distribution + for a qutrit system with a specified number of wires. It can handle both + batched and non-batched probability distributions. + + Args: + probs (ndarray): Probability distribution to sample from. For non-batched + input, this should be a 1D array of length QUDIT_DIM**num_wires. For + batched input, this should be a 2D array where each row is a separate + probability distribution. + shots (int): Number of samples to generate. + num_wires (int): Number of wires in the qutrit system. + is_state_batched (bool): Whether the input probabilities are batched. + rng (Optional[Generator]): Random number generator to use. If None, a new + generator will be created. + + Returns: + ndarray: An array of samples. For non-batched input, the shape is + (shots, num_wires). For batched input, the shape is + (batch_size, shots, num_wires). + + Example: + >>> probs = np.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system + >>> shots = 1000 + >>> num_wires = 1 + >>> is_state_batched = False + >>> rng = np.random.default_rng(42) + >>> samples = sample_probs(probs, shots, num_wires, is_state_batched, rng) + >>> samples.shape + (1000, 1) + """ + rng = np.random.default_rng(rng) + basis_states = np.arange(QUDIT_DIM**num_wires) if is_state_batched: # rng.choice doesn't support broadcasting samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs]) diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py index e36c69c26a3..f0ebc8aa3ac 100644 --- a/tests/devices/qubit/test_sampling.py +++ b/tests/devices/qubit/test_sampling.py @@ -20,7 +20,7 @@ import pennylane as qml from pennylane.devices.qubit import measure_with_samples, sample_state, simulate -from pennylane.devices.qubit.sampling import _sample_state_jax +from pennylane.devices.qubit.sampling import sample_probs from pennylane.devices.qubit.simulate import _FlexShots from pennylane.measurements import Shots @@ -84,27 +84,27 @@ def test_sample_state_basic(self, interface): @pytest.mark.jax def test_prng_key_as_seed_uses_sample_state_jax(self, mocker): - """Tests that sample_state calls _sample_state_jax if the seed is a JAX PRNG key""" + """Tests that sample_state calls _sample_probs_jax if the seed is a JAX PRNG key""" import jax jax.config.update("jax_enable_x64", True) - spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax") + spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax") state = qml.math.array(two_qubit_state, like="jax") - # prng_key specified, should call _sample_state_jax + # prng_key specified, should call _sample_probs_jax _ = sample_state(state, 10, prng_key=jax.random.PRNGKey(15)) spy.assert_called_once() @pytest.mark.jax def test_sample_state_jax(self): - """Tests that the returned samples are as expected when explicitly calling _sample_state_jax.""" + """Tests that the returned samples are as expected when explicitly calling sample_state.""" import jax state = qml.math.array(two_qubit_state, like="jax") - samples = _sample_state_jax(state, 10, prng_key=jax.random.PRNGKey(84)) + samples = sample_state(state, 10, prng_key=jax.random.PRNGKey(84)) assert samples.shape == (10, 2) assert samples.dtype == np.int64 @@ -112,14 +112,14 @@ def test_sample_state_jax(self): @pytest.mark.jax def test_prng_key_determines_sample_state_jax_results(self): - """Test that setting the seed as a JAX PRNG key determines the results for _sample_state_jax""" + """Test that setting the seed as a JAX PRNG key determines the results for sample_state""" import jax state = qml.math.array(two_qubit_state, like="jax") - samples = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(12)) - samples2 = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(12)) - samples3 = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(13)) + samples = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(12)) + samples2 = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(12)) + samples3 = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(13)) assert np.all(samples == samples2) assert not np.allclose(samples, samples3) @@ -934,7 +934,7 @@ def test_nonsample_measure_shot_vector(self, shots, measurement, expected): @pytest.mark.jax class TestBroadcastingPRNG: - """Test that measurements work and use _sample_state_jax when the state has a batch dim + """Test that measurements work and use sample_state when the state has a batch dim and a PRNG key is provided""" def test_sample_measure(self, mocker): @@ -943,7 +943,7 @@ def test_sample_measure(self, mocker): jax.config.update("jax_enable_x64", True) - spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax") + spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax") rng = np.random.default_rng(123) shots = qml.measurements.Shots(100) @@ -997,7 +997,7 @@ def test_nonsample_measure(self, mocker, measurement, expected): """Test that broadcasting works for the other sample measurements and single shots""" import jax - spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax") + spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax") rng = np.random.default_rng(123) shots = qml.measurements.Shots(10000) @@ -1036,7 +1036,7 @@ def test_sample_measure_shot_vector(self, mocker, shots): import jax - spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax") + spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax") rng = np.random.default_rng(123) shots = qml.measurements.Shots(shots) @@ -1112,7 +1112,7 @@ def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expecte import jax - spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax") + spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax") rng = np.random.default_rng(123) shots = qml.measurements.Shots(shots) @@ -1321,3 +1321,53 @@ def test_complex_hamiltonian(self): expected = simulate(qs_exp) assert np.allclose(res, expected, atol=0.001) + + +class TestSampleProbs: + # pylint: disable=attribute-defined-outside-init + @pytest.fixture(autouse=True) + def setup(self): + self.rng = np.random.default_rng(42) # Fixed seed for reproducibility + + def test_basic_sampling(self): + """One Qubit, two outcomes""" + probs = np.array([0.3, 0.7]) + samples = sample_probs(probs, shots=1000, num_wires=1, is_state_batched=False, rng=self.rng) + assert samples.shape == (1000, 1) + # Check if the distribution is roughly correct (allowing for some variance) + zeros = np.sum(samples == 0) + assert 250 <= zeros <= 350 # Approx 30% of 1000, with some leeway + + def test_multi_qubit_sampling(self): + """Two Qubit, four outcomes""" + probs = np.array([0.1, 0.2, 0.3, 0.4]) + samples = sample_probs(probs, shots=1000, num_wires=2, is_state_batched=False, rng=self.rng) + assert samples.shape == (1000, 2) + # Check if all possible states are present + unique_samples = set(map(tuple, samples)) + assert len(unique_samples) == 4 + + def test_batched_sampling(self): + """A batch of two circuits, each with two outcomes""" + probs = np.array([[0.5, 0.5], [0.3, 0.7]]) + samples = sample_probs(probs, shots=1000, num_wires=1, is_state_batched=True, rng=self.rng) + assert samples.shape == (2, 1000, 1) + + def test_cutoff_edge_case_failure(self): + """Test sampling with probabilities just outside the cutoff.""" + cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs + probs = np.array([0.5, 0.5 - 2 * cutoff]) + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): + sample_probs(probs, shots=1000, num_wires=1, is_state_batched=False, rng=self.rng) + + def test_batched_cutoff_edge_case_failure(self): + """Test sampling with probabilities just outside the cutoff.""" + cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs + probs = np.array( + [ + [0.5, 0.5 - 2 * cutoff], + [0.5, 0.5 - 2 * cutoff], + ] + ) + with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"): + sample_probs(probs, shots=1000, num_wires=1, is_state_batched=True, rng=self.rng) diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py index ecd2fbbcca8..8471ceb48fb 100644 --- a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py +++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py @@ -25,7 +25,11 @@ measure_with_samples, sample_state, ) -from pennylane.devices.qutrit_mixed.sampling import _sample_state_jax +from pennylane.devices.qutrit_mixed.sampling import ( + _sample_probs_jax, + _sample_state_jax, + sample_probs, +) from pennylane.measurements import Shots APPROX_ATOL = 0.05 @@ -34,6 +38,9 @@ TWO_QUTRITS = 2 THREE_QUTRITS = 3 +MISMATCH_ERROR = "a and p must have same size" +MISMATCH_ERROR_JAX = "p must be None or match the shape of a" + ml_frameworks_list = [ "numpy", pytest.param("autograd", marks=pytest.mark.autograd), @@ -606,7 +613,6 @@ def test_sample_measure_shot_vector(self, mocker, shots, batched_two_qutrit_pure ) def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expected): """Test that broadcasting works for the other sample measurements and shot vectors""" - import jax spy = mocker.spy(qml.devices.qutrit_mixed.sampling, "_sample_state_jax") @@ -693,3 +699,165 @@ def test_hamiltonian_expval_shot_vector(self, obs): assert isinstance(res, tuple) assert np.allclose(res[0], expected, atol=APPROX_ATOL) assert np.allclose(res[1], expected, atol=APPROX_ATOL) + + +class TestSampleProbs: + # pylint: disable=attribute-defined-outside-init + @pytest.fixture(autouse=True) + def setup(self): + self.rng = np.random.default_rng(42) + self.shots = 1000 + + def test_sample_probs_basic(self): + probs = np.array([0.2, 0.3, 0.5]) + num_wires = 1 + is_state_batched = False + + result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng) + + assert result.shape == (self.shots, num_wires) + assert np.all(result >= 0) and np.all(result < QUDIT_DIM) + + _, counts = np.unique(result, return_counts=True) + observed_probs = counts / self.shots + np.testing.assert_allclose(observed_probs, probs, atol=0.05) + + def test_sample_probs_multi_wire(self): + probs = np.array( + [0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02] + ) # 3^2 = 9 probabilities for 2 wires + num_wires = 2 + is_state_batched = False + + result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng) + + assert result.shape == (self.shots, num_wires) + assert np.all(result >= 0) and np.all(result < QUDIT_DIM) + + def test_sample_probs_batched(self): + probs = np.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]) + num_wires = 1 + is_state_batched = True + + result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng) + + assert result.shape == (2, self.shots, num_wires) + assert np.all(result >= 0) and np.all(result < QUDIT_DIM) + + @pytest.mark.parametrize( + "probs,num_wires,is_state_batched,expected_shape", + [ + (np.array([0.2, 0.3, 0.5]), 1, False, (1000, 1)), + (np.array([0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]), 2, False, (1000, 2)), + (np.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]), 1, True, (2, 1000, 1)), + ], + ) + def test_sample_probs_shapes(self, probs, num_wires, is_state_batched, expected_shape): + result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng) + assert result.shape == expected_shape + + def test_invalid_probs(self): + probs = np.array( + [0.1, 0.2, 0.3, 0.4] + ) # 4 probabilities, which is invalid for qutrit system + num_wires = 2 + is_state_batched = False + + with pytest.raises(ValueError, match=MISMATCH_ERROR): + sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng) + + +class TestSampleProbsJax: + # pylint: disable=attribute-defined-outside-init + @pytest.fixture(autouse=True) + def setup(self): + import jax + + self.jax_key = jax.random.PRNGKey(42) + self.shots = 1000 + + @pytest.mark.jax + def test_sample_probs_jax_basic(self): + probs = np.array([0.2, 0.3, 0.5]) + num_wires = 1 + is_state_batched = False + state_len = 1 + + result = _sample_probs_jax( + probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len + ) + + assert result.shape == (self.shots, num_wires) + assert np.all(result >= 0) and qml.math.all(result < QUDIT_DIM) + + _, counts = qml.math.unique(result, return_counts=True) + observed_probs = counts / self.shots + np.testing.assert_allclose(observed_probs, probs, atol=0.05) + + @pytest.mark.jax + def test_sample_probs_jax_multi_wire(self): + probs = qml.math.array( + [0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02] + ) # 3^2 = 9 probabilities for 2 wires + num_wires = 2 + is_state_batched = False + state_len = 1 + + result = _sample_probs_jax( + probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len + ) + + assert result.shape == (self.shots, num_wires) + assert qml.math.all(result >= 0) and qml.math.all(result < QUDIT_DIM) + + @pytest.mark.jax + def test_sample_probs_jax_batched(self): + probs = qml.math.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]) + num_wires = 1 + is_state_batched = True + state_len = 2 + + result = _sample_probs_jax( + probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len + ) + + assert result.shape == (2, self.shots, num_wires) + assert qml.math.all(result >= 0) and qml.math.all(result < QUDIT_DIM) + + # pylint: disable=too-many-arguments + @pytest.mark.parametrize( + "probs,num_wires,is_state_batched,expected_shape,state_len", + [ + (qml.math.array([0.2, 0.3, 0.5]), 1, False, (1000, 1), 1), + ( + qml.math.array([0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]), + 2, + False, + (1000, 2), + 1, + ), + (qml.math.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]), 1, True, (2, 1000, 1), 2), + ], + ) + @pytest.mark.jax + def test_sample_probs_jax_shapes( + self, probs, num_wires, is_state_batched, expected_shape, state_len + ): + result = _sample_probs_jax( + probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len + ) + assert result.shape == expected_shape + + @pytest.mark.jax + def test_invalid_probs_jax(self): + probs = qml.math.array( + [0.1, 0.2, 0.3, 0.4] + ) # 4 probabilities, which is invalid for qutrit system + num_wires = 2 + is_state_batched = False + state_len = 1 + + with pytest.raises(ValueError, match=MISMATCH_ERROR_JAX): + _sample_probs_jax( + probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len + )