From c1a813d3df83479527ad6751e5431a08294664a8 Mon Sep 17 00:00:00 2001 From: Vincent Michaud-Rioux Date: Tue, 6 Feb 2024 11:48:21 -0500 Subject: [PATCH] Accelerate sample to counts conversion. (#5145) ### Before submitting Please complete the following checklist when submitting a PR: - [x] All new features must include a unit test. If you've fixed a bug or added code that should be tested, add a test to the test directory! - [x] All new functions and code must be clearly commented and documented. If you do make documentation changes, make sure that the docs build and render correctly by running `make docs`. - [x] Ensure that the test suite passes, by running `make test`. - [x] Add a new entry to the `doc/releases/changelog-dev.md` file, summarizing the change, and including a link back to the PR. - [x] The PennyLane source code conforms to [PEP8 standards](https://www.python.org/dev/peps/pep-0008/). We check all of our code against [Pylint](https://www.pylint.org/). To lint modified files, simply `pip install pylint`, and then run `pylint pennylane/path/to/file.py`. When all the above are checked, delete everything above the dashed line and fill in the pull request template. ------------------------------------------------------------------------------------------------------------ **Context:** While doing work on mid-circuit measurements #5088 #5120 , I found `_samples_to_counts` to be a bottleneck in small circuits with a lot of shots (>1e5). For example, the circuit below spends: - 1.9 sec. in `_samples_to_counts` on `master` - 0.09 sec. in `_samples_to_counts` on `optim/_samples_to_counts` ``` import pennylane as qml import pennylane.numpy as np dev = qml.device("default.qubit", shots=100000) @qml.qnode(dev) @qml.defer_measurements def func2(x, y): qml.RX(x, wires=0) m0 = qml.measure(0) qml.RX(y, wires=1) m1 = qml.measure(1) return qml.counts(wires=[0, 1]) results2 = func2(np.pi / 4, np.pi / 4) ``` **Description of the Change:** Accelerate NaN pruning and int to str conversion in `_samples_to_counts`. **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: Matthew Silverman --- doc/releases/changelog-dev.md | 3 +++ pennylane/measurements/counts.py | 33 ++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b0779a3db18..82522d9d509 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -164,6 +164,9 @@

Other improvements

+* Faster `qml.probs` measurements due to an optimization in `_samples_to_counts`. + [(#5145)](https://github.com/PennyLaneAI/pennylane/pull/5145) + * Cuts down on performance bottlenecks in converting a `PauliSentence` to a `Sum`. [(#5141)](https://github.com/PennyLaneAI/pennylane/pull/5141) [(#5150)](https://github.com/PennyLaneAI/pennylane/pull/5150) diff --git a/pennylane/measurements/counts.py b/pennylane/measurements/counts.py index 66720bff760..066cd7ba30b 100644 --- a/pennylane/measurements/counts.py +++ b/pennylane/measurements/counts.py @@ -16,6 +16,7 @@ """ import warnings from typing import Sequence, Tuple, Optional +import numpy as np import pennylane as qml from pennylane.operation import Operator @@ -25,11 +26,6 @@ from .mid_measure import MeasurementValue -def _sample_to_str(sample): - """Converts a bit-array to a string. For example, ``[0, 1]`` would become '01'.""" - return "".join(map(str, sample)) - - def counts(op=None, wires=None, all_outcomes=False) -> "CountsMP": r"""Sample from the supplied observable, with the number of shots determined from the ``dev.shots`` attribute of the corresponding device, @@ -301,17 +297,30 @@ def circuit(x): if self.obs is None and not isinstance(self.mv, MeasurementValue): # convert samples and outcomes (if using) from arrays to str for dict keys - samples = qml.math.array( - [sample for sample in samples if not qml.math.any(qml.math.isnan(sample))] - ) + + # remove nans + mask = qml.math.isnan(samples) + num_wires = shape[-1] + if np.any(mask): + mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim)))) + samples = samples[mask, ...] + + # convert to string + def convert(x): + return f"{x:0{num_wires}b}" + + exp2 = 2 ** np.arange(num_wires - 1, -1, -1) + samples = np.einsum("...i,i", samples, exp2) + new_shape = samples.shape samples = qml.math.cast_like(samples, qml.math.int8(0)) - samples = qml.math.apply_along_axis(_sample_to_str, -1, samples) + samples = list(map(convert, samples.ravel())) + samples = np.array(samples).reshape(new_shape) + batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires)) if self.all_outcomes: num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1] - outcomes = list( - map(_sample_to_str, qml.QubitDevice.generate_basis_states(num_wires)) - ) + outcomes = list(map(convert, range(2**num_wires))) + elif self.all_outcomes: # This also covers statistics for mid-circuit measurements manipulated using # arithmetic operators