diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index ee1ec88f389..d1f17ae1dae 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -291,6 +291,9 @@
Bug fixes 🐛
+* The `QNode` interface now resets if an error occurs during execution.
+ [(#5449)](https://github.com/PennyLaneAI/pennylane/pull/5449)
+
* Fix failing tests due to changes with Lightning's adjoint diff pipeline.
[(#5450)](https://github.com/PennyLaneAI/pennylane/pull/5450)
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 8e57af1ff62..1e446de482d 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -83,6 +83,21 @@ def _make_execution_config(
)
+def _to_qfunc_output_type(
+ results: qml.typing.Result, qfunc_output, has_partitioned_shots
+) -> qml.typing.Result:
+ # Special case of single Measurement in a list
+ if isinstance(qfunc_output, list) and len(qfunc_output) == 1:
+ results = [results]
+
+ # If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
+ if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
+ return results
+ if has_partitioned_shots:
+ return tuple(type(qfunc_output)(r) for r in results)
+ return type(qfunc_output)(results)
+
+
class QNode:
"""Represents a quantum node in the hybrid computational graph.
@@ -972,25 +987,18 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
if old_interface == "auto":
self.interface = "auto"
- def __call__(self, *args, **kwargs) -> qml.typing.Result:
+ def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml.typing.Result:
+ """Construct the transform program and execute the tapes. Helper function for ``__call__``
- old_interface = self.interface
- if old_interface == "auto":
- interface = qml.math.get_interface(*args, *list(kwargs.values()))
- self._interface = INTERFACE_MAP[interface]
-
- if self._qfunc_uses_shots_arg:
- override_shots = False
- else:
- if "shots" not in kwargs:
- kwargs["shots"] = _get_device_shots(self._original_device)
- override_shots = kwargs["shots"]
+ Args:
+ args (tuple): the arguments the QNode is called with
+ kwargs (dict): the keyword arguments the QNode is called with
+ override_shots : the shots to use for the execution.
- # construct the tape
- self.construct(args, kwargs)
+ Returns:
+ Result
- original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device]
- self._update_gradient_fn(shots=override_shots, tape=self._tape)
+ """
cache = self.execute_kwargs.get("cache", False)
using_custom_cache = (
@@ -1049,24 +1057,39 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
):
res = _convert_to_interface(res, self.interface)
- # Special case of single Measurement in a list
- if isinstance(self._qfunc_output, list) and len(self._qfunc_output) == 1:
- res = [res]
+ return _to_qfunc_output_type(
+ res, self._qfunc_output, self._tape.shots.has_partitioned_shots
+ )
- # If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
- if not isinstance(self._qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
- has_partitioned_shots = self.tape.shots.has_partitioned_shots
- if has_partitioned_shots:
- res = tuple(type(self._qfunc_output)(r) for r in res)
- else:
- res = type(self._qfunc_output)(res)
+ def __call__(self, *args, **kwargs) -> qml.typing.Result:
+ old_interface = self.interface
if old_interface == "auto":
- self._interface = "auto"
+ interface = qml.math.get_interface(*args, *list(kwargs.values()))
+ self._interface = INTERFACE_MAP[interface]
+
+ if self._qfunc_uses_shots_arg:
+ override_shots = False
+ else:
+ if "shots" not in kwargs:
+ kwargs["shots"] = _get_device_shots(self._original_device)
+ override_shots = kwargs["shots"]
+
+ # construct the tape
+ self.construct(args, kwargs)
+
+ original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device]
+ self._update_gradient_fn(shots=override_shots, tape=self._tape)
+
+ try:
+ res = self._execution_component(args, kwargs, override_shots=override_shots)
+ finally:
+ if old_interface == "auto":
+ self._interface = "auto"
- self._update_original_device()
+ self._update_original_device()
- _, self.gradient_kwargs, self.device = original_grad_fn
+ _, self.gradient_kwargs, self.device = original_grad_fn
return res
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index d027a3ae763..5d18fc79a8e 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -49,6 +49,8 @@ def execute(self, circuits, execution_config=None):
class CustomDeviceWithDiffMethod(qml.devices.Device):
+ """A device that defines a derivative."""
+
def execute(self, circuits, execution_config=None):
return 0
@@ -79,6 +81,8 @@ def test_copy():
class TestInitialization:
+ """Testing the initialization of the qnode."""
+
def test_cache_initialization_maxdiff_1(self):
"""Test that when max_diff = 1, the cache initializes to false."""
@@ -929,6 +933,9 @@ def conditional_ry_qnode(x, y):
@pytest.mark.parametrize("dev_name", ["default.qubit.legacy", "default.mixed"])
def test_dynamic_one_shot_if_mcm_unsupported(self, dev_name):
+ """Test an error is raised if the dynamic one shot transform is a applied to a qnode with a device that
+ does not support mid circuit measurements.
+ """
dev = qml.device(dev_name, wires=2, shots=100)
with pytest.raises(
@@ -1363,6 +1370,8 @@ def qn2(x, y):
class TestTransformProgramIntegration:
+ """Tests for the integration of the transform program with the qnode."""
+
def test_transform_program_modifies_circuit(self):
"""Test qnode integration with a transform that turns the circuit into just a pauli x."""
dev = qml.device("default.qubit", wires=1)
@@ -1638,6 +1647,8 @@ def test_device_with_custom_diff_method_name(self):
"""Test a device that has its own custom diff method."""
class CustomDeviceWithDiffMethod2(qml.devices.DefaultQubit):
+ """A device with a custom derivative named hello."""
+
def supports_derivatives(self, execution_config=None, circuit=None):
return getattr(execution_config, "gradient_method", None) == "hello"
@@ -1860,3 +1871,21 @@ def circuit(x):
):
x = pnp.array([0.5, 0.4, 0.3], requires_grad=True)
circuit.construct([x], {})
+
+
+def test_resets_after_execution_error():
+ """Test that the interface is reset to ``"auto"`` if an error occurs during execution."""
+
+ # pylint: disable=too-few-public-methods
+ class BadOp(qml.operation.Operator):
+ """An operator that will cause an error during execution."""
+
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(x):
+ BadOp(x, wires=0)
+ return qml.state()
+
+ with pytest.raises(qml.DeviceError):
+ circuit(qml.numpy.array(0.1))
+
+ assert circuit.interface == "auto"