From e484ba21c7a46b50b020d2026ba824a96b418d83 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 13 Sep 2024 10:36:10 -0400 Subject: [PATCH 1/2] deprecate set_shots (#6250) [sc-71546] We no longer interact with the legacy device interface during out workflow. Therefore, we should always set shots via the new methods and not use the `set_shots` context manager. --- pennylane/devices/_qubit_device.py | 10 +++++++- pennylane/measurements/classical_shadow.py | 9 +++++++- pennylane/workflow/set_shots.py | 21 +++++++++++++++++ .../test_set_shots_legacy.py | 23 +++++++++++-------- 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/pennylane/devices/_qubit_device.py b/pennylane/devices/_qubit_device.py index a54c2985290..ef9f41f0378 100644 --- a/pennylane/devices/_qubit_device.py +++ b/pennylane/devices/_qubit_device.py @@ -1110,7 +1110,11 @@ def classical_shadow(self, obs, circuit): n_snapshots = self.shots seed = obs.seed - with qml.workflow.set_shots(self, shots=1): + original_shots = self.shots + original_shot_vector = self._shot_vector + + try: + self.shots = 1 # slow implementation but works for all devices n_qubits = len(wires) mapped_wires = np.array(self.map_wires(wires)) @@ -1139,6 +1143,10 @@ def classical_shadow(self, obs, circuit): ) outcomes[t] = self.generate_samples()[0][mapped_wires] + finally: + self.shots = original_shots + # pylint: disable=attribute-defined-outside-init + self._shot_vector = original_shot_vector return self._cast(self._stack([outcomes, recipes]), dtype=np.int8) diff --git a/pennylane/measurements/classical_shadow.py b/pennylane/measurements/classical_shadow.py index 6c502858cf8..200a48a25a5 100644 --- a/pennylane/measurements/classical_shadow.py +++ b/pennylane/measurements/classical_shadow.py @@ -284,7 +284,11 @@ def process(self, tape, device): n_snapshots = device.shots seed = self.seed - with qml.workflow.set_shots(device, shots=1): + original_shots = device.shots + original_shot_vector = device._shot_vector # pylint: disable=protected-access + + try: + device.shots = 1 # slow implementation but works for all devices n_qubits = len(wires) mapped_wires = np.array(device.map_wires(wires)) @@ -311,6 +315,9 @@ def process(self, tape, device): device.apply(tape.operations, rotations=tape.diagonalizing_gates + rotations) outcomes[t] = device.generate_samples()[0][mapped_wires] + finally: + device.shots = original_shots + device._shot_vector = original_shot_vector # pylint: disable=protected-access return qml.math.cast(qml.math.stack([outcomes, recipes]), dtype=np.int8) diff --git a/pennylane/workflow/set_shots.py b/pennylane/workflow/set_shots.py index 1bc7dff7f33..aaaed1a1a44 100644 --- a/pennylane/workflow/set_shots.py +++ b/pennylane/workflow/set_shots.py @@ -17,6 +17,7 @@ """ # pylint: disable=protected-access import contextlib +import warnings import pennylane as qml from pennylane.measurements import Shots @@ -26,6 +27,20 @@ def set_shots(device, shots): r"""Context manager to temporarily change the shots of a device. + + .. warning:: + + ``set_shots`` is deprecated and will be removed in PennyLane version v0.40. + + To dynamically update the shots on the workflow, shots can be manually set on a ``QNode`` call: + + >>> circuit(shots=my_new_shots) + + When working with the internal tapes, shots should be set on each tape. + + >>> tape = qml.tape.QuantumScript([], [qml.sample()], shots=50) + + This context manager can be used in two ways. As a standard context manager: @@ -47,6 +62,12 @@ def set_shots(device, shots): "The new device interface is not compatible with `set_shots`. " "Set shots when calling the qnode or put the shots on the QuantumTape." ) + warnings.warn( + "set_shots is deprecated.\n" + "Please dyanmically update shots via keyword argument when calling a QNode " + " or set shots on the tape.", + qml.PennyLaneDeprecationWarning, + ) if isinstance(shots, Shots): shots = shots.shot_vector if shots.has_partitioned_shots else shots.total_shots if shots == device.shots: diff --git a/tests/interfaces/legacy_devices_integration/test_set_shots_legacy.py b/tests/interfaces/legacy_devices_integration/test_set_shots_legacy.py index 6e9739a631f..619bd74a066 100644 --- a/tests/interfaces/legacy_devices_integration/test_set_shots_legacy.py +++ b/tests/interfaces/legacy_devices_integration/test_set_shots_legacy.py @@ -14,7 +14,7 @@ """ Tests for workflow.set_shots """ - +import pytest import pennylane as qml from pennylane.measurements import Shots @@ -24,16 +24,18 @@ def test_set_with_shots_class(): """Test that shots can be set on the old device interface with a Shots class.""" - dev = qml.devices.DefaultQubitLegacy(wires=1) - with set_shots(dev, Shots(10)): - assert dev.shots == 10 + dev = qml.devices.DefaultMixed(wires=1) + with pytest.warns(qml.PennyLaneDeprecationWarning): + with set_shots(dev, Shots(10)): + assert dev.shots == 10 assert dev.shots is None shot_tuples = Shots((10, 10)) - with set_shots(dev, shot_tuples): - assert dev.shots == 20 - assert dev.shot_vector == list(shot_tuples.shot_vector) + with pytest.warns(qml.PennyLaneDeprecationWarning): + with set_shots(dev, shot_tuples): + assert dev.shots == 20 + assert dev.shot_vector == list(shot_tuples.shot_vector) assert dev.shots is None @@ -42,6 +44,7 @@ def test_shots_not_altered_if_False(): """Test a value of False can be passed to shots, indicating to not override shots on the device.""" - dev = qml.devices.DefaultQubitLegacy(wires=1) - with set_shots(dev, False): - assert dev.shots is None + dev = qml.devices.DefaultMixed(wires=1) + with pytest.warns(qml.PennyLaneDeprecationWarning): + with set_shots(dev, False): + assert dev.shots is None From 060bb9a7479dcf6735671d809da19d500c4b7a89 Mon Sep 17 00:00:00 2001 From: Will Date: Fri, 13 Sep 2024 11:53:28 -0400 Subject: [PATCH 2/2] Fix FABLE template to return the correct result in JIT mode (#6263) This PR fixes bug #6262 --- doc/releases/changelog-dev.md | 3 +++ pennylane/templates/subroutines/fable.py | 20 ++++++++++--------- .../templates/test_subroutines/test_fable.py | 11 ++++++---- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 73c266f2833..90dbe86a773 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -105,6 +105,9 @@ * The ``qml.Qubitization`` template now orders the ``control`` wires first and the ``hamiltonian`` wires second, which is the expected according to other templates. [(#6229)](https://github.com/PennyLaneAI/pennylane/pull/6229) +* The ``qml.FABLE`` template now returns the correct value when JIT is enabled. + [(#6263)](https://github.com/PennyLaneAI/pennylane/pull/6263) + *

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/templates/subroutines/fable.py b/pennylane/templates/subroutines/fable.py index 16f1160ddd0..d9637676738 100644 --- a/pennylane/templates/subroutines/fable.py +++ b/pennylane/templates/subroutines/fable.py @@ -166,17 +166,19 @@ def compute_decomposition(input_matrix, wires, tol=0): # pylint:disable=argumen for c_wire in nots: op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) op_list.append(qml.RY(2 * theta, wires=ancilla)) + nots = {} nots[wire_map[control_index]] = 1 + continue + + if qml.math.abs(2 * theta) > tol: + for c_wire in nots: + op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) + op_list.append(qml.RY(2 * theta, wires=ancilla)) + nots = {} + if wire_map[control_index] in nots: + del nots[wire_map[control_index]] else: - if abs(2 * theta) > tol: - for c_wire in nots: - op_list.append(qml.CNOT(wires=[c_wire] + ancilla)) - op_list.append(qml.RY(2 * theta, wires=ancilla)) - nots = {} - if wire_map[control_index] in nots: - del nots[wire_map[control_index]] - else: - nots[wire_map[control_index]] = 1 + nots[wire_map[control_index]] = 1 for c_wire in nots: op_list.append(qml.CNOT([c_wire] + ancilla)) diff --git a/tests/templates/test_subroutines/test_fable.py b/tests/templates/test_subroutines/test_fable.py index 8649fe71748..d2ba5f2496a 100644 --- a/tests/templates/test_subroutines/test_fable.py +++ b/tests/templates/test_subroutines/test_fable.py @@ -235,7 +235,7 @@ def circuit_jax(input_matrix): assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) @pytest.mark.jax - def test_fable_grad_jax_jit(self, input_matrix): + def test_fable_jax_jit(self, input_matrix): """Test that FABLE is differentiable when using jax.""" import jax import jax.numpy as jnp @@ -272,18 +272,21 @@ def test_fable_grad_jax_jit(self, input_matrix): input_jax_negative_delta = jnp.array(input_negative_delta) input_matrix_jax = jnp.array(input_matrix) - @jax.jit @qml.qnode(dev, diff_method="backprop") def circuit_jax(input_matrix): qml.FABLE(input_matrix, wires=range(5), tol=0) return qml.expval(qml.PauliZ(wires=0)) - grad_fn = jax.grad(circuit_jax) + jitted_fn = jax.jit(circuit_jax) + + grad_fn = jax.grad(jitted_fn) gradient_numeric = ( circuit_jax(input_jax_positive_delta) - circuit_jax(input_jax_negative_delta) ) / (2 * delta) gradient_jax = grad_fn(input_matrix_jax) - assert np.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + + assert qml.math.allclose(gradient_numeric, gradient_jax[0, 0], rtol=0.001) + assert qml.math.allclose(jitted_fn(input_matrix), circuit_jax(input_matrix)) @pytest.mark.jax def test_fable_grad_jax_jit_error(self, input_matrix):