diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b0779a3db18..82522d9d509 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -164,6 +164,9 @@

Other improvements

+* Faster `qml.probs` measurements due to an optimization in `_samples_to_counts`. + [(#5145)](https://github.com/PennyLaneAI/pennylane/pull/5145) + * Cuts down on performance bottlenecks in converting a `PauliSentence` to a `Sum`. [(#5141)](https://github.com/PennyLaneAI/pennylane/pull/5141) [(#5150)](https://github.com/PennyLaneAI/pennylane/pull/5150) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 66720bff760..066cd7ba30b 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -16,6 +16,7 @@ """ import warnings from typing import Sequence, Tuple, Optional +import numpy as np import pennylane as qml from pennylane.operation import Operator @@ -25,11 +26,6 @@ from .mid_measure import MeasurementValue -def _sample_to_str(sample): - """Converts a bit-array to a string. For example, ``[0, 1]`` would become '01'.""" - return "".join(map(str, sample)) - - def counts(op=None, wires=None, all_outcomes=False) -> "CountsMP": r"""Sample from the supplied observable, with the number of shots determined from the ``dev.shots`` attribute of the corresponding device, @@ -301,17 +297,30 @@ def circuit(x): if self.obs is None and not isinstance(self.mv, MeasurementValue): # convert samples and outcomes (if using) from arrays to str for dict keys - samples = qml.math.array( - [sample for sample in samples if not qml.math.any(qml.math.isnan(sample))] - ) + + # remove nans + mask = qml.math.isnan(samples) + num_wires = shape[-1] + if np.any(mask): + mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim)))) + samples = samples[mask, ...] + + # convert to string + def convert(x): + return f"{x:0{num_wires}b}" + + exp2 = 2 ** np.arange(num_wires - 1, -1, -1) + samples = np.einsum("...i,i", samples, exp2) + new_shape = samples.shape samples = qml.math.cast_like(samples, qml.math.int8(0)) - samples = qml.math.apply_along_axis(_sample_to_str, -1, samples) + samples = list(map(convert, samples.ravel())) + samples = np.array(samples).reshape(new_shape) + batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires)) if self.all_outcomes: num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1] - outcomes = list( - map(_sample_to_str, qml.QubitDevice.generate_basis_states(num_wires)) - ) + outcomes = list(map(convert, range(2**num_wires))) + elif self.all_outcomes: # This also covers statistics for mid-circuit measurements manipulated using # arithmetic operators