Skip to content

Commit

Permalink
deprecate set_shots (#6250)
Browse files Browse the repository at this point in the history
[sc-71546]

We no longer interact with the legacy device interface during out
workflow. Therefore, we should always set shots via the new methods and
not use the `set_shots` context manager.
  • Loading branch information
albi3ro authored Sep 13, 2024
1 parent 87c1bc3 commit e484ba2
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 12 deletions.
10 changes: 9 additions & 1 deletion pennylane/devices/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,11 @@ def classical_shadow(self, obs, circuit):
n_snapshots = self.shots
seed = obs.seed

with qml.workflow.set_shots(self, shots=1):
original_shots = self.shots
original_shot_vector = self._shot_vector

try:
self.shots = 1
# slow implementation but works for all devices
n_qubits = len(wires)
mapped_wires = np.array(self.map_wires(wires))
Expand Down Expand Up @@ -1139,6 +1143,10 @@ def classical_shadow(self, obs, circuit):
)

outcomes[t] = self.generate_samples()[0][mapped_wires]
finally:
self.shots = original_shots
# pylint: disable=attribute-defined-outside-init
self._shot_vector = original_shot_vector

return self._cast(self._stack([outcomes, recipes]), dtype=np.int8)

Expand Down
9 changes: 8 additions & 1 deletion pennylane/measurements/classical_shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,11 @@ def process(self, tape, device):
n_snapshots = device.shots
seed = self.seed

with qml.workflow.set_shots(device, shots=1):
original_shots = device.shots
original_shot_vector = device._shot_vector # pylint: disable=protected-access

try:
device.shots = 1
# slow implementation but works for all devices
n_qubits = len(wires)
mapped_wires = np.array(device.map_wires(wires))
Expand All @@ -311,6 +315,9 @@ def process(self, tape, device):
device.apply(tape.operations, rotations=tape.diagonalizing_gates + rotations)

outcomes[t] = device.generate_samples()[0][mapped_wires]
finally:
device.shots = original_shots
device._shot_vector = original_shot_vector # pylint: disable=protected-access

return qml.math.cast(qml.math.stack([outcomes, recipes]), dtype=np.int8)

Expand Down
21 changes: 21 additions & 0 deletions pennylane/workflow/set_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
# pylint: disable=protected-access
import contextlib
import warnings

import pennylane as qml
from pennylane.measurements import Shots
Expand All @@ -26,6 +27,20 @@
def set_shots(device, shots):
r"""Context manager to temporarily change the shots of a device.
.. warning::
``set_shots`` is deprecated and will be removed in PennyLane version v0.40.
To dynamically update the shots on the workflow, shots can be manually set on a ``QNode`` call:
>>> circuit(shots=my_new_shots)
When working with the internal tapes, shots should be set on each tape.
>>> tape = qml.tape.QuantumScript([], [qml.sample()], shots=50)
This context manager can be used in two ways.
As a standard context manager:
Expand All @@ -47,6 +62,12 @@ def set_shots(device, shots):
"The new device interface is not compatible with `set_shots`. "
"Set shots when calling the qnode or put the shots on the QuantumTape."
)
warnings.warn(
"set_shots is deprecated.\n"
"Please dyanmically update shots via keyword argument when calling a QNode "
" or set shots on the tape.",
qml.PennyLaneDeprecationWarning,
)
if isinstance(shots, Shots):
shots = shots.shot_vector if shots.has_partitioned_shots else shots.total_shots
if shots == device.shots:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""
Tests for workflow.set_shots
"""

import pytest

import pennylane as qml
from pennylane.measurements import Shots
Expand All @@ -24,16 +24,18 @@
def test_set_with_shots_class():
"""Test that shots can be set on the old device interface with a Shots class."""

dev = qml.devices.DefaultQubitLegacy(wires=1)
with set_shots(dev, Shots(10)):
assert dev.shots == 10
dev = qml.devices.DefaultMixed(wires=1)
with pytest.warns(qml.PennyLaneDeprecationWarning):
with set_shots(dev, Shots(10)):
assert dev.shots == 10

assert dev.shots is None

shot_tuples = Shots((10, 10))
with set_shots(dev, shot_tuples):
assert dev.shots == 20
assert dev.shot_vector == list(shot_tuples.shot_vector)
with pytest.warns(qml.PennyLaneDeprecationWarning):
with set_shots(dev, shot_tuples):
assert dev.shots == 20
assert dev.shot_vector == list(shot_tuples.shot_vector)

assert dev.shots is None

Expand All @@ -42,6 +44,7 @@ def test_shots_not_altered_if_False():
"""Test a value of False can be passed to shots, indicating to not override
shots on the device."""

dev = qml.devices.DefaultQubitLegacy(wires=1)
with set_shots(dev, False):
assert dev.shots is None
dev = qml.devices.DefaultMixed(wires=1)
with pytest.warns(qml.PennyLaneDeprecationWarning):
with set_shots(dev, False):
assert dev.shots is None

0 comments on commit e484ba2

Please sign in to comment.