diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index db92dcf79b3..df7004f00f9 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -4,6 +4,13 @@
New features since last release
+* Introduced `sample_probs` function for the `qml.devices.qubit` and `qml.devices.qutrit_mixed` modules:
+ - This function takes probability distributions as input and returns sampled outcomes.
+ - Simplifies the sampling process by separating it from other operations in the measurement chain.
+ - Improves modularity: The same code can be easily adapted for other devices (e.g., a potential `default_mixed` device).
+ - Enhances maintainability by isolating the sampling logic.
+ [(#6354)](https://github.com/PennyLaneAI/pennylane/pull/6354)
+
* `qml.transforms.decompose` is added for stepping through decompositions to a target gate set.
[(#6334)](https://github.com/PennyLaneAI/pennylane/pull/6334)
diff --git a/pennylane/devices/qubit/__init__.py b/pennylane/devices/qubit/__init__.py
index 860d91953c4..b12ad79640a 100644
--- a/pennylane/devices/qubit/__init__.py
+++ b/pennylane/devices/qubit/__init__.py
@@ -26,15 +26,16 @@
measure
measure_with_samples
sample_state
+ sample_probs
simulate
adjoint_jacobian
adjoint_jvp
adjoint_vjp
"""
-from .apply_operation import apply_operation
from .adjoint_jacobian import adjoint_jacobian, adjoint_jvp, adjoint_vjp
+from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
-from .sampling import sample_state, measure_with_samples
-from .simulate import simulate, get_final_state, measure_final_state
+from .sampling import measure_with_samples, sample_probs, sample_state
+from .simulate import get_final_state, measure_final_state, simulate
diff --git a/pennylane/devices/qubit/sampling.py b/pennylane/devices/qubit/sampling.py
index be6090a0354..85c89de3701 100644
--- a/pennylane/devices/qubit/sampling.py
+++ b/pennylane/devices/qubit/sampling.py
@@ -471,51 +471,70 @@ def sample_state(
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
"""
- if prng_key is not None or qml.math.get_interface(state) == "jax":
- return _sample_state_jax(
- state, shots, prng_key, is_state_batched=is_state_batched, wires=wires, seed=rng
- )
-
- rng = np.random.default_rng(rng)
total_indices = len(state.shape) - is_state_batched
state_wires = qml.wires.Wires(range(total_indices))
wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
- basis_states = np.arange(2**num_wires)
flat_state = flatten_state(state, total_indices)
with qml.queuing.QueuingManager.stop_recording():
probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires)
+ # Keep same interface (e.g. jax) as in the device
- # when using the torch interface with float32 as default dtype,
- # probabilities must be renormalized as they may not sum to one
- # see https://github.com/PennyLaneAI/pennylane/issues/5444
- norm = qml.math.sum(probs, axis=-1)
- abs_diff = qml.math.abs(norm - 1.0)
- cutoff = 1e-07
+ return sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key)
- if is_state_batched:
- normalize_condition = False
- for s in abs_diff:
- if s != 0:
- normalize_condition = True
- if s > cutoff:
- normalize_condition = False
- break
+def sample_probs(probs, shots, num_wires, is_state_batched, rng, prng_key=None):
+ """
+ Sample from given probabilities, dispatching between JAX and NumPy implementations.
+
+ Args:
+ probs (array): The probabilities to sample from
+ shots (int): The number of samples to take
+ num_wires (int): The number of wires to sample
+ is_state_batched (bool): whether the state is batched or not
+ rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]):
+ A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
+ If no value is provided, a default RNG will be used
+ prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
+ the key to the JAX pseudo random number generator. Only for simulation using JAX.
+ """
+ if qml.math.get_interface(probs) == "jax" or prng_key is not None:
+ return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, seed=rng)
+
+ return _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng)
+
+
+def _sample_probs_numpy(probs, shots, num_wires, is_state_batched, rng):
+ """
+ Sample from given probabilities using NumPy's random number generator.
+
+ Args:
+ probs (array): The probabilities to sample from
+ shots (int): The number of samples to take
+ num_wires (int): The number of wires to sample
+ is_state_batched (bool): whether the state is batched or not
+ rng (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]):
+ A seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
+ If no value is provided, a default RNG will be used
+ """
+ rng = np.random.default_rng(rng)
+ norm = qml.math.sum(probs, axis=-1)
+ norm_err = qml.math.abs(norm - 1.0)
+ cutoff = 1e-07
- if normalize_condition:
- probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm
+ norm_err = norm_err[..., np.newaxis] if not is_state_batched else norm_err
+ if qml.math.any(norm_err > cutoff):
+ raise ValueError("probabilities do not sum to 1")
- # rng.choice doesn't support broadcasting
+ basis_states = np.arange(2**num_wires)
+ if is_state_batched:
+ probs = probs / norm[:, np.newaxis] if norm.shape else probs / norm
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
else:
- if not 0 < abs_diff < cutoff:
- norm = 1.0
probs = probs / norm
-
samples = rng.choice(basis_states, shots, p=probs)
powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1]
@@ -523,26 +542,19 @@ def sample_state(
return (states_sampled_base_ten > 0).astype(np.int64)
-# pylint:disable = unused-argument
-def _sample_state_jax(
- state,
- shots: int,
- prng_key,
- is_state_batched: bool = False,
- wires=None,
- seed=None,
-) -> np.ndarray:
+def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key=None, seed=None):
"""
Returns a series of samples of a state for the JAX interface based on the PRNG.
Args:
- state (array[complex]): A state vector to be sampled
+ probs (array): The probabilities to sample from
shots (int): The number of samples to take
- prng_key (jax.random.PRNGKey): A``jax.random.PRNGKey``. This is
- the key to the JAX pseudo random number generator.
+ num_wires (int): The number of wires to sample
is_state_batched (bool): whether the state is batched or not
- wires (Sequence[int]): The wires to sample
- seed (numpy.random.Generator): seed to use to generate a key if a ``prng_key`` is not present. ``None`` by default.
+ prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is
+ the key to the JAX pseudo random number generator. Only for simulation using JAX.
+ seed (Optional[int]): A seed for the random number generator. This is only used if ``prng_key``
+ is not provided.
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
@@ -554,19 +566,10 @@ def _sample_state_jax(
if prng_key is None:
prng_key = jax.random.PRNGKey(np.random.default_rng(seed).integers(100000))
- total_indices = len(state.shape) - is_state_batched
- state_wires = qml.wires.Wires(range(total_indices))
-
- wires_to_sample = wires or state_wires
- num_wires = len(wires_to_sample)
- basis_states = np.arange(2**num_wires)
-
- flat_state = flatten_state(state, total_indices)
- with qml.queuing.QueuingManager.stop_recording():
- probs = qml.probs(wires=wires_to_sample).process_state(flat_state, state_wires)
+ basis_states = jnp.arange(2**num_wires)
if is_state_batched:
- keys = jax_random_split(prng_key, num=len(state))
+ keys = jax_random_split(prng_key, num=probs.shape[0])
samples = jnp.array(
[
jax.random.choice(_key, basis_states, shape=(shots,), p=prob)
@@ -577,6 +580,6 @@ def _sample_state_jax(
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)
- powers_of_two = 1 << np.arange(num_wires, dtype=int)[::-1]
+ powers_of_two = 1 << jnp.arange(num_wires, dtype=jnp.int64)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
- return (states_sampled_base_ten > 0).astype(int)
+ return (states_sampled_base_ten > 0).astype(jnp.int64)
diff --git a/pennylane/devices/qutrit_mixed/__init__.py b/pennylane/devices/qutrit_mixed/__init__.py
index 192b5a1b65a..fb7377287cc 100644
--- a/pennylane/devices/qutrit_mixed/__init__.py
+++ b/pennylane/devices/qutrit_mixed/__init__.py
@@ -33,5 +33,5 @@
from .apply_operation import apply_operation
from .initialize_state import create_initial_state
from .measure import measure
-from .sampling import sample_state, measure_with_samples
+from .sampling import sample_state, measure_with_samples, sample_probs
from .simulate import simulate, get_final_state, measure_final_state
diff --git a/pennylane/devices/qutrit_mixed/sampling.py b/pennylane/devices/qutrit_mixed/sampling.py
index c5395e80d48..0e5e661f085 100644
--- a/pennylane/devices/qutrit_mixed/sampling.py
+++ b/pennylane/devices/qutrit_mixed/sampling.py
@@ -250,25 +250,74 @@ def _sample_state_jax(
ndarray[int]: Sample values of the shape (shots, num_wires)
"""
# pylint: disable=import-outside-toplevel
- import jax
- import jax.numpy as jnp
-
- key = prng_key
total_indices = get_num_wires(state, is_state_batched)
state_wires = qml.wires.Wires(range(total_indices))
wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
- basis_states = np.arange(QUDIT_DIM**num_wires)
with qml.queuing.QueuingManager.stop_recording():
probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors)
+ state_len = len(state)
+
+ return _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)
+
+
+def _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len):
+ """
+ Sample from a probability distribution for a qutrit system using JAX.
+
+ This function generates samples based on the given probability distribution
+ for a qutrit system with a specified number of wires. It can handle both
+ batched and non-batched probability distributions. This function uses JAX
+ for potential GPU acceleration and improved performance.
+
+ Args:
+ probs (jnp.ndarray): Probability distribution to sample from. For non-batched
+ input, this should be a 1D array of length QUDIT_DIM**num_wires. For
+ batched input, this should be a 2D array where each row is a separate
+ probability distribution.
+ shots (int): Number of samples to generate.
+ num_wires (int): Number of wires in the qutrit system.
+ is_state_batched (bool): Whether the input probabilities are batched.
+ prng_key (jax.random.PRNGKey): JAX PRNG key for random number generation.
+ state_len (int): Length of the state (relevant for batched inputs).
+
+ Returns:
+ jnp.ndarray: An array of samples. For non-batched input, the shape is
+ (shots, num_wires). For batched input, the shape is
+ (batch_size, shots, num_wires).
+
+ Example:
+ >>> import jax
+ >>> import jax.numpy as jnp
+ >>> probs = jnp.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system
+ >>> shots = 1000
+ >>> num_wires = 1
+ >>> is_state_batched = False
+ >>> prng_key = jax.random.PRNGKey(42)
+ >>> state_len = 1
+ >>> samples = _sample_probs_jax(probs, shots, num_wires, is_state_batched, prng_key, state_len)
+ >>> samples.shape
+ (1000, 1)
+
+ Note:
+ This function requires JAX to be installed. It internally imports JAX
+ and its numpy module (jnp).
+ """
+ # pylint: disable=import-outside-toplevel
+ import jax
+ import jax.numpy as jnp
+
+ key = prng_key
+
+ basis_states = np.arange(QUDIT_DIM**num_wires)
if is_state_batched:
# Produce separate keys for each of the probabilities along the broadcasted axis
keys = []
- for _ in state:
+ for _ in range(state_len):
key, subkey = jax.random.split(key)
keys.append(subkey)
samples = jnp.array(
@@ -323,18 +372,54 @@ def sample_state(
readout_errors=readout_errors,
)
- rng = np.random.default_rng(rng)
-
total_indices = get_num_wires(state, is_state_batched)
state_wires = qml.wires.Wires(range(total_indices))
wires_to_sample = wires or state_wires
num_wires = len(wires_to_sample)
- basis_states = np.arange(QUDIT_DIM**num_wires)
with qml.queuing.QueuingManager.stop_recording():
probs = measure(qml.probs(wires=wires_to_sample), state, is_state_batched, readout_errors)
+ return sample_probs(probs, shots, num_wires, is_state_batched, rng)
+
+
+def sample_probs(probs, shots, num_wires, is_state_batched, rng):
+ """
+ Sample from a probability distribution for a qutrit system.
+
+ This function generates samples based on the given probability distribution
+ for a qutrit system with a specified number of wires. It can handle both
+ batched and non-batched probability distributions.
+
+ Args:
+ probs (ndarray): Probability distribution to sample from. For non-batched
+ input, this should be a 1D array of length QUDIT_DIM**num_wires. For
+ batched input, this should be a 2D array where each row is a separate
+ probability distribution.
+ shots (int): Number of samples to generate.
+ num_wires (int): Number of wires in the qutrit system.
+ is_state_batched (bool): Whether the input probabilities are batched.
+ rng (Optional[Generator]): Random number generator to use. If None, a new
+ generator will be created.
+
+ Returns:
+ ndarray: An array of samples. For non-batched input, the shape is
+ (shots, num_wires). For batched input, the shape is
+ (batch_size, shots, num_wires).
+
+ Example:
+ >>> probs = np.array([0.2, 0.3, 0.5]) # For a single-wire qutrit system
+ >>> shots = 1000
+ >>> num_wires = 1
+ >>> is_state_batched = False
+ >>> rng = np.random.default_rng(42)
+ >>> samples = sample_probs(probs, shots, num_wires, is_state_batched, rng)
+ >>> samples.shape
+ (1000, 1)
+ """
+ rng = np.random.default_rng(rng)
+ basis_states = np.arange(QUDIT_DIM**num_wires)
if is_state_batched:
# rng.choice doesn't support broadcasting
samples = np.stack([rng.choice(basis_states, shots, p=p) for p in probs])
diff --git a/tests/devices/qubit/test_sampling.py b/tests/devices/qubit/test_sampling.py
index e36c69c26a3..f0ebc8aa3ac 100644
--- a/tests/devices/qubit/test_sampling.py
+++ b/tests/devices/qubit/test_sampling.py
@@ -20,7 +20,7 @@
import pennylane as qml
from pennylane.devices.qubit import measure_with_samples, sample_state, simulate
-from pennylane.devices.qubit.sampling import _sample_state_jax
+from pennylane.devices.qubit.sampling import sample_probs
from pennylane.devices.qubit.simulate import _FlexShots
from pennylane.measurements import Shots
@@ -84,27 +84,27 @@ def test_sample_state_basic(self, interface):
@pytest.mark.jax
def test_prng_key_as_seed_uses_sample_state_jax(self, mocker):
- """Tests that sample_state calls _sample_state_jax if the seed is a JAX PRNG key"""
+ """Tests that sample_state calls _sample_probs_jax if the seed is a JAX PRNG key"""
import jax
jax.config.update("jax_enable_x64", True)
- spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax")
+ spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax")
state = qml.math.array(two_qubit_state, like="jax")
- # prng_key specified, should call _sample_state_jax
+ # prng_key specified, should call _sample_probs_jax
_ = sample_state(state, 10, prng_key=jax.random.PRNGKey(15))
spy.assert_called_once()
@pytest.mark.jax
def test_sample_state_jax(self):
- """Tests that the returned samples are as expected when explicitly calling _sample_state_jax."""
+ """Tests that the returned samples are as expected when explicitly calling sample_state."""
import jax
state = qml.math.array(two_qubit_state, like="jax")
- samples = _sample_state_jax(state, 10, prng_key=jax.random.PRNGKey(84))
+ samples = sample_state(state, 10, prng_key=jax.random.PRNGKey(84))
assert samples.shape == (10, 2)
assert samples.dtype == np.int64
@@ -112,14 +112,14 @@ def test_sample_state_jax(self):
@pytest.mark.jax
def test_prng_key_determines_sample_state_jax_results(self):
- """Test that setting the seed as a JAX PRNG key determines the results for _sample_state_jax"""
+ """Test that setting the seed as a JAX PRNG key determines the results for sample_state"""
import jax
state = qml.math.array(two_qubit_state, like="jax")
- samples = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(12))
- samples2 = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(12))
- samples3 = _sample_state_jax(state, shots=10, prng_key=jax.random.PRNGKey(13))
+ samples = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(12))
+ samples2 = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(12))
+ samples3 = sample_state(state, shots=10, prng_key=jax.random.PRNGKey(13))
assert np.all(samples == samples2)
assert not np.allclose(samples, samples3)
@@ -934,7 +934,7 @@ def test_nonsample_measure_shot_vector(self, shots, measurement, expected):
@pytest.mark.jax
class TestBroadcastingPRNG:
- """Test that measurements work and use _sample_state_jax when the state has a batch dim
+ """Test that measurements work and use sample_state when the state has a batch dim
and a PRNG key is provided"""
def test_sample_measure(self, mocker):
@@ -943,7 +943,7 @@ def test_sample_measure(self, mocker):
jax.config.update("jax_enable_x64", True)
- spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax")
+ spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax")
rng = np.random.default_rng(123)
shots = qml.measurements.Shots(100)
@@ -997,7 +997,7 @@ def test_nonsample_measure(self, mocker, measurement, expected):
"""Test that broadcasting works for the other sample measurements and single shots"""
import jax
- spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax")
+ spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax")
rng = np.random.default_rng(123)
shots = qml.measurements.Shots(10000)
@@ -1036,7 +1036,7 @@ def test_sample_measure_shot_vector(self, mocker, shots):
import jax
- spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax")
+ spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax")
rng = np.random.default_rng(123)
shots = qml.measurements.Shots(shots)
@@ -1112,7 +1112,7 @@ def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expecte
import jax
- spy = mocker.spy(qml.devices.qubit.sampling, "_sample_state_jax")
+ spy = mocker.spy(qml.devices.qubit.sampling, "_sample_probs_jax")
rng = np.random.default_rng(123)
shots = qml.measurements.Shots(shots)
@@ -1321,3 +1321,53 @@ def test_complex_hamiltonian(self):
expected = simulate(qs_exp)
assert np.allclose(res, expected, atol=0.001)
+
+
+class TestSampleProbs:
+ # pylint: disable=attribute-defined-outside-init
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ self.rng = np.random.default_rng(42) # Fixed seed for reproducibility
+
+ def test_basic_sampling(self):
+ """One Qubit, two outcomes"""
+ probs = np.array([0.3, 0.7])
+ samples = sample_probs(probs, shots=1000, num_wires=1, is_state_batched=False, rng=self.rng)
+ assert samples.shape == (1000, 1)
+ # Check if the distribution is roughly correct (allowing for some variance)
+ zeros = np.sum(samples == 0)
+ assert 250 <= zeros <= 350 # Approx 30% of 1000, with some leeway
+
+ def test_multi_qubit_sampling(self):
+ """Two Qubit, four outcomes"""
+ probs = np.array([0.1, 0.2, 0.3, 0.4])
+ samples = sample_probs(probs, shots=1000, num_wires=2, is_state_batched=False, rng=self.rng)
+ assert samples.shape == (1000, 2)
+ # Check if all possible states are present
+ unique_samples = set(map(tuple, samples))
+ assert len(unique_samples) == 4
+
+ def test_batched_sampling(self):
+ """A batch of two circuits, each with two outcomes"""
+ probs = np.array([[0.5, 0.5], [0.3, 0.7]])
+ samples = sample_probs(probs, shots=1000, num_wires=1, is_state_batched=True, rng=self.rng)
+ assert samples.shape == (2, 1000, 1)
+
+ def test_cutoff_edge_case_failure(self):
+ """Test sampling with probabilities just outside the cutoff."""
+ cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
+ probs = np.array([0.5, 0.5 - 2 * cutoff])
+ with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"):
+ sample_probs(probs, shots=1000, num_wires=1, is_state_batched=False, rng=self.rng)
+
+ def test_batched_cutoff_edge_case_failure(self):
+ """Test sampling with probabilities just outside the cutoff."""
+ cutoff = 1e-7 # Assuming this is the cutoff used in sample_probs
+ probs = np.array(
+ [
+ [0.5, 0.5 - 2 * cutoff],
+ [0.5, 0.5 - 2 * cutoff],
+ ]
+ )
+ with pytest.raises(ValueError, match=r"(?i)probabilities do not sum to 1"):
+ sample_probs(probs, shots=1000, num_wires=1, is_state_batched=True, rng=self.rng)
diff --git a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py
index ecd2fbbcca8..8471ceb48fb 100644
--- a/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py
+++ b/tests/devices/qutrit_mixed/test_qutrit_mixed_sampling.py
@@ -25,7 +25,11 @@
measure_with_samples,
sample_state,
)
-from pennylane.devices.qutrit_mixed.sampling import _sample_state_jax
+from pennylane.devices.qutrit_mixed.sampling import (
+ _sample_probs_jax,
+ _sample_state_jax,
+ sample_probs,
+)
from pennylane.measurements import Shots
APPROX_ATOL = 0.05
@@ -34,6 +38,9 @@
TWO_QUTRITS = 2
THREE_QUTRITS = 3
+MISMATCH_ERROR = "a and p must have same size"
+MISMATCH_ERROR_JAX = "p must be None or match the shape of a"
+
ml_frameworks_list = [
"numpy",
pytest.param("autograd", marks=pytest.mark.autograd),
@@ -606,7 +613,6 @@ def test_sample_measure_shot_vector(self, mocker, shots, batched_two_qutrit_pure
)
def test_nonsample_measure_shot_vector(self, mocker, shots, measurement, expected):
"""Test that broadcasting works for the other sample measurements and shot vectors"""
-
import jax
spy = mocker.spy(qml.devices.qutrit_mixed.sampling, "_sample_state_jax")
@@ -693,3 +699,165 @@ def test_hamiltonian_expval_shot_vector(self, obs):
assert isinstance(res, tuple)
assert np.allclose(res[0], expected, atol=APPROX_ATOL)
assert np.allclose(res[1], expected, atol=APPROX_ATOL)
+
+
+class TestSampleProbs:
+ # pylint: disable=attribute-defined-outside-init
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ self.rng = np.random.default_rng(42)
+ self.shots = 1000
+
+ def test_sample_probs_basic(self):
+ probs = np.array([0.2, 0.3, 0.5])
+ num_wires = 1
+ is_state_batched = False
+
+ result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng)
+
+ assert result.shape == (self.shots, num_wires)
+ assert np.all(result >= 0) and np.all(result < QUDIT_DIM)
+
+ _, counts = np.unique(result, return_counts=True)
+ observed_probs = counts / self.shots
+ np.testing.assert_allclose(observed_probs, probs, atol=0.05)
+
+ def test_sample_probs_multi_wire(self):
+ probs = np.array(
+ [0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]
+ ) # 3^2 = 9 probabilities for 2 wires
+ num_wires = 2
+ is_state_batched = False
+
+ result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng)
+
+ assert result.shape == (self.shots, num_wires)
+ assert np.all(result >= 0) and np.all(result < QUDIT_DIM)
+
+ def test_sample_probs_batched(self):
+ probs = np.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]])
+ num_wires = 1
+ is_state_batched = True
+
+ result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng)
+
+ assert result.shape == (2, self.shots, num_wires)
+ assert np.all(result >= 0) and np.all(result < QUDIT_DIM)
+
+ @pytest.mark.parametrize(
+ "probs,num_wires,is_state_batched,expected_shape",
+ [
+ (np.array([0.2, 0.3, 0.5]), 1, False, (1000, 1)),
+ (np.array([0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]), 2, False, (1000, 2)),
+ (np.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]), 1, True, (2, 1000, 1)),
+ ],
+ )
+ def test_sample_probs_shapes(self, probs, num_wires, is_state_batched, expected_shape):
+ result = sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng)
+ assert result.shape == expected_shape
+
+ def test_invalid_probs(self):
+ probs = np.array(
+ [0.1, 0.2, 0.3, 0.4]
+ ) # 4 probabilities, which is invalid for qutrit system
+ num_wires = 2
+ is_state_batched = False
+
+ with pytest.raises(ValueError, match=MISMATCH_ERROR):
+ sample_probs(probs, self.shots, num_wires, is_state_batched, self.rng)
+
+
+class TestSampleProbsJax:
+ # pylint: disable=attribute-defined-outside-init
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ import jax
+
+ self.jax_key = jax.random.PRNGKey(42)
+ self.shots = 1000
+
+ @pytest.mark.jax
+ def test_sample_probs_jax_basic(self):
+ probs = np.array([0.2, 0.3, 0.5])
+ num_wires = 1
+ is_state_batched = False
+ state_len = 1
+
+ result = _sample_probs_jax(
+ probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len
+ )
+
+ assert result.shape == (self.shots, num_wires)
+ assert np.all(result >= 0) and qml.math.all(result < QUDIT_DIM)
+
+ _, counts = qml.math.unique(result, return_counts=True)
+ observed_probs = counts / self.shots
+ np.testing.assert_allclose(observed_probs, probs, atol=0.05)
+
+ @pytest.mark.jax
+ def test_sample_probs_jax_multi_wire(self):
+ probs = qml.math.array(
+ [0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]
+ ) # 3^2 = 9 probabilities for 2 wires
+ num_wires = 2
+ is_state_batched = False
+ state_len = 1
+
+ result = _sample_probs_jax(
+ probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len
+ )
+
+ assert result.shape == (self.shots, num_wires)
+ assert qml.math.all(result >= 0) and qml.math.all(result < QUDIT_DIM)
+
+ @pytest.mark.jax
+ def test_sample_probs_jax_batched(self):
+ probs = qml.math.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]])
+ num_wires = 1
+ is_state_batched = True
+ state_len = 2
+
+ result = _sample_probs_jax(
+ probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len
+ )
+
+ assert result.shape == (2, self.shots, num_wires)
+ assert qml.math.all(result >= 0) and qml.math.all(result < QUDIT_DIM)
+
+ # pylint: disable=too-many-arguments
+ @pytest.mark.parametrize(
+ "probs,num_wires,is_state_batched,expected_shape,state_len",
+ [
+ (qml.math.array([0.2, 0.3, 0.5]), 1, False, (1000, 1), 1),
+ (
+ qml.math.array([0.1, 0.2, 0.3, 0.15, 0.1, 0.05, 0.05, 0.03, 0.02]),
+ 2,
+ False,
+ (1000, 2),
+ 1,
+ ),
+ (qml.math.array([[0.2, 0.3, 0.5], [0.4, 0.1, 0.5]]), 1, True, (2, 1000, 1), 2),
+ ],
+ )
+ @pytest.mark.jax
+ def test_sample_probs_jax_shapes(
+ self, probs, num_wires, is_state_batched, expected_shape, state_len
+ ):
+ result = _sample_probs_jax(
+ probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len
+ )
+ assert result.shape == expected_shape
+
+ @pytest.mark.jax
+ def test_invalid_probs_jax(self):
+ probs = qml.math.array(
+ [0.1, 0.2, 0.3, 0.4]
+ ) # 4 probabilities, which is invalid for qutrit system
+ num_wires = 2
+ is_state_batched = False
+ state_len = 1
+
+ with pytest.raises(ValueError, match=MISMATCH_ERROR_JAX):
+ _sample_probs_jax(
+ probs, self.shots, num_wires, is_state_batched, self.jax_key, state_len
+ )