diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ee1ec88f389..d1f17ae1dae 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -291,6 +291,9 @@

Bug fixes 🐛

+* The `QNode` interface now resets if an error occurs during execution. + [(#5449)](https://github.com/PennyLaneAI/pennylane/pull/5449) + * Fix failing tests due to changes with Lightning's adjoint diff pipeline. [(#5450)](https://github.com/PennyLaneAI/pennylane/pull/5450) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 8e57af1ff62..1e446de482d 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -83,6 +83,21 @@ def _make_execution_config( ) +def _to_qfunc_output_type( + results: qml.typing.Result, qfunc_output, has_partitioned_shots +) -> qml.typing.Result: + # Special case of single Measurement in a list + if isinstance(qfunc_output, list) and len(qfunc_output) == 1: + results = [results] + + # If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed) + if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)): + return results + if has_partitioned_shots: + return tuple(type(qfunc_output)(r) for r in results) + return type(qfunc_output)(results) + + class QNode: """Represents a quantum node in the hybrid computational graph. @@ -972,25 +987,18 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches if old_interface == "auto": self.interface = "auto" - def __call__(self, *args, **kwargs) -> qml.typing.Result: + def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml.typing.Result: + """Construct the transform program and execute the tapes. Helper function for ``__call__`` - old_interface = self.interface - if old_interface == "auto": - interface = qml.math.get_interface(*args, *list(kwargs.values())) - self._interface = INTERFACE_MAP[interface] - - if self._qfunc_uses_shots_arg: - override_shots = False - else: - if "shots" not in kwargs: - kwargs["shots"] = _get_device_shots(self._original_device) - override_shots = kwargs["shots"] + Args: + args (tuple): the arguments the QNode is called with + kwargs (dict): the keyword arguments the QNode is called with + override_shots : the shots to use for the execution. - # construct the tape - self.construct(args, kwargs) + Returns: + Result - original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device] - self._update_gradient_fn(shots=override_shots, tape=self._tape) + """ cache = self.execute_kwargs.get("cache", False) using_custom_cache = ( @@ -1049,24 +1057,39 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result: ): res = _convert_to_interface(res, self.interface) - # Special case of single Measurement in a list - if isinstance(self._qfunc_output, list) and len(self._qfunc_output) == 1: - res = [res] + return _to_qfunc_output_type( + res, self._qfunc_output, self._tape.shots.has_partitioned_shots + ) - # If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed) - if not isinstance(self._qfunc_output, (tuple, qml.measurements.MeasurementProcess)): - has_partitioned_shots = self.tape.shots.has_partitioned_shots - if has_partitioned_shots: - res = tuple(type(self._qfunc_output)(r) for r in res) - else: - res = type(self._qfunc_output)(res) + def __call__(self, *args, **kwargs) -> qml.typing.Result: + old_interface = self.interface if old_interface == "auto": - self._interface = "auto" + interface = qml.math.get_interface(*args, *list(kwargs.values())) + self._interface = INTERFACE_MAP[interface] + + if self._qfunc_uses_shots_arg: + override_shots = False + else: + if "shots" not in kwargs: + kwargs["shots"] = _get_device_shots(self._original_device) + override_shots = kwargs["shots"] + + # construct the tape + self.construct(args, kwargs) + + original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device] + self._update_gradient_fn(shots=override_shots, tape=self._tape) + + try: + res = self._execution_component(args, kwargs, override_shots=override_shots) + finally: + if old_interface == "auto": + self._interface = "auto" - self._update_original_device() + self._update_original_device() - _, self.gradient_kwargs, self.device = original_grad_fn + _, self.gradient_kwargs, self.device = original_grad_fn return res diff --git a/tests/test_qnode.py b/tests/test_qnode.py index d027a3ae763..5d18fc79a8e 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -49,6 +49,8 @@ def execute(self, circuits, execution_config=None): class CustomDeviceWithDiffMethod(qml.devices.Device): + """A device that defines a derivative.""" + def execute(self, circuits, execution_config=None): return 0 @@ -79,6 +81,8 @@ def test_copy(): class TestInitialization: + """Testing the initialization of the qnode.""" + def test_cache_initialization_maxdiff_1(self): """Test that when max_diff = 1, the cache initializes to false.""" @@ -929,6 +933,9 @@ def conditional_ry_qnode(x, y): @pytest.mark.parametrize("dev_name", ["default.qubit.legacy", "default.mixed"]) def test_dynamic_one_shot_if_mcm_unsupported(self, dev_name): + """Test an error is raised if the dynamic one shot transform is a applied to a qnode with a device that + does not support mid circuit measurements. + """ dev = qml.device(dev_name, wires=2, shots=100) with pytest.raises( @@ -1363,6 +1370,8 @@ def qn2(x, y): class TestTransformProgramIntegration: + """Tests for the integration of the transform program with the qnode.""" + def test_transform_program_modifies_circuit(self): """Test qnode integration with a transform that turns the circuit into just a pauli x.""" dev = qml.device("default.qubit", wires=1) @@ -1638,6 +1647,8 @@ def test_device_with_custom_diff_method_name(self): """Test a device that has its own custom diff method.""" class CustomDeviceWithDiffMethod2(qml.devices.DefaultQubit): + """A device with a custom derivative named hello.""" + def supports_derivatives(self, execution_config=None, circuit=None): return getattr(execution_config, "gradient_method", None) == "hello" @@ -1860,3 +1871,21 @@ def circuit(x): ): x = pnp.array([0.5, 0.4, 0.3], requires_grad=True) circuit.construct([x], {}) + + +def test_resets_after_execution_error(): + """Test that the interface is reset to ``"auto"`` if an error occurs during execution.""" + + # pylint: disable=too-few-public-methods + class BadOp(qml.operation.Operator): + """An operator that will cause an error during execution.""" + + @qml.qnode(qml.device("default.qubit")) + def circuit(x): + BadOp(x, wires=0) + return qml.state() + + with pytest.raises(qml.DeviceError): + circuit(qml.numpy.array(0.1)) + + assert circuit.interface == "auto"