Skip to content

Commit

Permalink
Reset qnode after execution error (#5449)
Browse files Browse the repository at this point in the history
**Context:**

While investigating Issue #5442, I was getting really confused by why my
tensorflow execution was giving a torch error. I realized that when the
execution fails, we do not properly reset `QNode.interface` back to
`"auto"`.

**Description of the Change:**

Pulls a block of the code to a helper method so that we can wrap it in a
`try-finally` block.

**Benefits:**

We can continue using a qnode after a failure has occured during
execution.

**Possible Drawbacks:**

**Related GitHub Issues:**

[sc-60047]

---------

Co-authored-by: Astral Cai <[email protected]>
  • Loading branch information
albi3ro and astralcai authored Apr 1, 2024
1 parent 08fad8c commit 05e266d
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 29 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@

<h3>Bug fixes 🐛</h3>

* The `QNode` interface now resets if an error occurs during execution.
[(#5449)](https://github.com/PennyLaneAI/pennylane/pull/5449)

* Fix failing tests due to changes with Lightning's adjoint diff pipeline.
[(#5450)](https://github.com/PennyLaneAI/pennylane/pull/5450)

Expand Down
81 changes: 52 additions & 29 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ def _make_execution_config(
)


def _to_qfunc_output_type(
results: qml.typing.Result, qfunc_output, has_partitioned_shots
) -> qml.typing.Result:
# Special case of single Measurement in a list
if isinstance(qfunc_output, list) and len(qfunc_output) == 1:
results = [results]

# If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
return results
if has_partitioned_shots:
return tuple(type(qfunc_output)(r) for r in results)
return type(qfunc_output)(results)


class QNode:
"""Represents a quantum node in the hybrid computational graph.
Expand Down Expand Up @@ -972,25 +987,18 @@ def construct(self, args, kwargs): # pylint: disable=too-many-branches
if old_interface == "auto":
self.interface = "auto"

def __call__(self, *args, **kwargs) -> qml.typing.Result:
def _execution_component(self, args: tuple, kwargs: dict, override_shots) -> qml.typing.Result:
"""Construct the transform program and execute the tapes. Helper function for ``__call__``
old_interface = self.interface
if old_interface == "auto":
interface = qml.math.get_interface(*args, *list(kwargs.values()))
self._interface = INTERFACE_MAP[interface]

if self._qfunc_uses_shots_arg:
override_shots = False
else:
if "shots" not in kwargs:
kwargs["shots"] = _get_device_shots(self._original_device)
override_shots = kwargs["shots"]
Args:
args (tuple): the arguments the QNode is called with
kwargs (dict): the keyword arguments the QNode is called with
override_shots : the shots to use for the execution.
# construct the tape
self.construct(args, kwargs)
Returns:
Result
original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device]
self._update_gradient_fn(shots=override_shots, tape=self._tape)
"""

cache = self.execute_kwargs.get("cache", False)
using_custom_cache = (
Expand Down Expand Up @@ -1049,24 +1057,39 @@ def __call__(self, *args, **kwargs) -> qml.typing.Result:
):
res = _convert_to_interface(res, self.interface)

# Special case of single Measurement in a list
if isinstance(self._qfunc_output, list) and len(self._qfunc_output) == 1:
res = [res]
return _to_qfunc_output_type(
res, self._qfunc_output, self._tape.shots.has_partitioned_shots
)

# If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
if not isinstance(self._qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
has_partitioned_shots = self.tape.shots.has_partitioned_shots
if has_partitioned_shots:
res = tuple(type(self._qfunc_output)(r) for r in res)
else:
res = type(self._qfunc_output)(res)
def __call__(self, *args, **kwargs) -> qml.typing.Result:

old_interface = self.interface
if old_interface == "auto":
self._interface = "auto"
interface = qml.math.get_interface(*args, *list(kwargs.values()))
self._interface = INTERFACE_MAP[interface]

if self._qfunc_uses_shots_arg:
override_shots = False
else:
if "shots" not in kwargs:
kwargs["shots"] = _get_device_shots(self._original_device)
override_shots = kwargs["shots"]

# construct the tape
self.construct(args, kwargs)

original_grad_fn = [self.gradient_fn, self.gradient_kwargs, self.device]
self._update_gradient_fn(shots=override_shots, tape=self._tape)

try:
res = self._execution_component(args, kwargs, override_shots=override_shots)
finally:
if old_interface == "auto":
self._interface = "auto"

self._update_original_device()
self._update_original_device()

_, self.gradient_kwargs, self.device = original_grad_fn
_, self.gradient_kwargs, self.device = original_grad_fn

return res

Expand Down
29 changes: 29 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def execute(self, circuits, execution_config=None):


class CustomDeviceWithDiffMethod(qml.devices.Device):
"""A device that defines a derivative."""

def execute(self, circuits, execution_config=None):
return 0

Expand Down Expand Up @@ -79,6 +81,8 @@ def test_copy():


class TestInitialization:
"""Testing the initialization of the qnode."""

def test_cache_initialization_maxdiff_1(self):
"""Test that when max_diff = 1, the cache initializes to false."""

Expand Down Expand Up @@ -929,6 +933,9 @@ def conditional_ry_qnode(x, y):

@pytest.mark.parametrize("dev_name", ["default.qubit.legacy", "default.mixed"])
def test_dynamic_one_shot_if_mcm_unsupported(self, dev_name):
"""Test an error is raised if the dynamic one shot transform is a applied to a qnode with a device that
does not support mid circuit measurements.
"""
dev = qml.device(dev_name, wires=2, shots=100)

with pytest.raises(
Expand Down Expand Up @@ -1363,6 +1370,8 @@ def qn2(x, y):


class TestTransformProgramIntegration:
"""Tests for the integration of the transform program with the qnode."""

def test_transform_program_modifies_circuit(self):
"""Test qnode integration with a transform that turns the circuit into just a pauli x."""
dev = qml.device("default.qubit", wires=1)
Expand Down Expand Up @@ -1638,6 +1647,8 @@ def test_device_with_custom_diff_method_name(self):
"""Test a device that has its own custom diff method."""

class CustomDeviceWithDiffMethod2(qml.devices.DefaultQubit):
"""A device with a custom derivative named hello."""

def supports_derivatives(self, execution_config=None, circuit=None):
return getattr(execution_config, "gradient_method", None) == "hello"

Expand Down Expand Up @@ -1860,3 +1871,21 @@ def circuit(x):
):
x = pnp.array([0.5, 0.4, 0.3], requires_grad=True)
circuit.construct([x], {})


def test_resets_after_execution_error():
"""Test that the interface is reset to ``"auto"`` if an error occurs during execution."""

# pylint: disable=too-few-public-methods
class BadOp(qml.operation.Operator):
"""An operator that will cause an error during execution."""

@qml.qnode(qml.device("default.qubit"))
def circuit(x):
BadOp(x, wires=0)
return qml.state()

with pytest.raises(qml.DeviceError):
circuit(qml.numpy.array(0.1))

assert circuit.interface == "auto"

0 comments on commit 05e266d

Please sign in to comment.