diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index b0779a3db18..82522d9d509 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -164,6 +164,9 @@
Other improvements
+* Faster `qml.probs` measurements due to an optimization in `_samples_to_counts`.
+ [(#5145)](https://github.com/PennyLaneAI/pennylane/pull/5145)
+
* Cuts down on performance bottlenecks in converting a `PauliSentence` to a `Sum`.
[(#5141)](https://github.com/PennyLaneAI/pennylane/pull/5141)
[(#5150)](https://github.com/PennyLaneAI/pennylane/pull/5150)
diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py
index 66720bff760..066cd7ba30b 100644
--- a/pennylane/measurements/counts.py
+++ b/pennylane/measurements/counts.py
@@ -16,6 +16,7 @@
"""
import warnings
from typing import Sequence, Tuple, Optional
+import numpy as np
import pennylane as qml
from pennylane.operation import Operator
@@ -25,11 +26,6 @@
from .mid_measure import MeasurementValue
-def _sample_to_str(sample):
- """Converts a bit-array to a string. For example, ``[0, 1]`` would become '01'."""
- return "".join(map(str, sample))
-
-
def counts(op=None, wires=None, all_outcomes=False) -> "CountsMP":
r"""Sample from the supplied observable, with the number of shots
determined from the ``dev.shots`` attribute of the corresponding device,
@@ -301,17 +297,30 @@ def circuit(x):
if self.obs is None and not isinstance(self.mv, MeasurementValue):
# convert samples and outcomes (if using) from arrays to str for dict keys
- samples = qml.math.array(
- [sample for sample in samples if not qml.math.any(qml.math.isnan(sample))]
- )
+
+ # remove nans
+ mask = qml.math.isnan(samples)
+ num_wires = shape[-1]
+ if np.any(mask):
+ mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim))))
+ samples = samples[mask, ...]
+
+ # convert to string
+ def convert(x):
+ return f"{x:0{num_wires}b}"
+
+ exp2 = 2 ** np.arange(num_wires - 1, -1, -1)
+ samples = np.einsum("...i,i", samples, exp2)
+ new_shape = samples.shape
samples = qml.math.cast_like(samples, qml.math.int8(0))
- samples = qml.math.apply_along_axis(_sample_to_str, -1, samples)
+ samples = list(map(convert, samples.ravel()))
+ samples = np.array(samples).reshape(new_shape)
+
batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))
if self.all_outcomes:
num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1]
- outcomes = list(
- map(_sample_to_str, qml.QubitDevice.generate_basis_states(num_wires))
- )
+ outcomes = list(map(convert, range(2**num_wires)))
+
elif self.all_outcomes:
# This also covers statistics for mid-circuit measurements manipulated using
# arithmetic operators