Skip to content

Commit

Permalink
Accelerate sample to counts conversion. (#5145)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [x] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [x] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [x] Ensure that the test suite passes, by running `make test`.

- [x] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [x] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
While doing work on mid-circuit measurements #5088 #5120 , I found
`_samples_to_counts` to be a bottleneck in small circuits with a lot of
shots (>1e5). For example, the circuit below spends:

- 1.9 sec. in `_samples_to_counts` on `master`
- 0.09 sec. in `_samples_to_counts` on `optim/_samples_to_counts`
 
```
import pennylane as qml
import pennylane.numpy as np
dev = qml.device("default.qubit", shots=100000)
@qml.qnode(dev)
@qml.defer_measurements
def func2(x, y):
    qml.RX(x, wires=0)
    m0 = qml.measure(0)
    qml.RX(y, wires=1)
    m1 = qml.measure(1)
    return qml.counts(wires=[0, 1])
results2 = func2(np.pi / 4, np.pi / 4)
```

**Description of the Change:**
Accelerate NaN pruning and int to str conversion in
`_samples_to_counts`.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
vincentmr and timmysilv authored Feb 6, 2024
1 parent e73213e commit c1a813d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@

<h4>Other improvements</h4>

* Faster `qml.probs` measurements due to an optimization in `_samples_to_counts`.
[(#5145)](https://github.com/PennyLaneAI/pennylane/pull/5145)

* Cuts down on performance bottlenecks in converting a `PauliSentence` to a `Sum`.
[(#5141)](https://github.com/PennyLaneAI/pennylane/pull/5141)
[(#5150)](https://github.com/PennyLaneAI/pennylane/pull/5150)
Expand Down
33 changes: 21 additions & 12 deletions pennylane/measurements/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
import warnings
from typing import Sequence, Tuple, Optional
import numpy as np

import pennylane as qml
from pennylane.operation import Operator
Expand All @@ -25,11 +26,6 @@
from .mid_measure import MeasurementValue


def _sample_to_str(sample):
"""Converts a bit-array to a string. For example, ``[0, 1]`` would become '01'."""
return "".join(map(str, sample))


def counts(op=None, wires=None, all_outcomes=False) -> "CountsMP":
r"""Sample from the supplied observable, with the number of shots
determined from the ``dev.shots`` attribute of the corresponding device,
Expand Down Expand Up @@ -301,17 +297,30 @@ def circuit(x):

if self.obs is None and not isinstance(self.mv, MeasurementValue):
# convert samples and outcomes (if using) from arrays to str for dict keys
samples = qml.math.array(
[sample for sample in samples if not qml.math.any(qml.math.isnan(sample))]
)

# remove nans
mask = qml.math.isnan(samples)
num_wires = shape[-1]
if np.any(mask):
mask = np.logical_not(np.any(mask, axis=tuple(range(1, samples.ndim))))
samples = samples[mask, ...]

# convert to string
def convert(x):
return f"{x:0{num_wires}b}"

exp2 = 2 ** np.arange(num_wires - 1, -1, -1)
samples = np.einsum("...i,i", samples, exp2)
new_shape = samples.shape
samples = qml.math.cast_like(samples, qml.math.int8(0))
samples = qml.math.apply_along_axis(_sample_to_str, -1, samples)
samples = list(map(convert, samples.ravel()))
samples = np.array(samples).reshape(new_shape)

batched_ndims = 3 # no observable was provided, batched samples will have shape (batch_size, shots, len(wires))
if self.all_outcomes:
num_wires = len(self.wires) if len(self.wires) > 0 else shape[-1]
outcomes = list(
map(_sample_to_str, qml.QubitDevice.generate_basis_states(num_wires))
)
outcomes = list(map(convert, range(2**num_wires)))

elif self.all_outcomes:
# This also covers statistics for mid-circuit measurements manipulated using
# arithmetic operators
Expand Down

0 comments on commit c1a813d

Please sign in to comment.