diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 0064f239c96..051f9d4cdf7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -193,6 +193,9 @@ same information.

Bug fixes 🐛

+* `QNode` return behaviour is now consistent for lists and tuples. + [(#6568)](https://github.com/PennyLaneAI/pennylane/pull/6568) + * `qml.QNode` now accepts arguments with types defined in libraries that are not necessarily in the list of supported interfaces, such as the `Graph` class defined in `networkx`. [(#6600)](https://github.com/PennyLaneAI/pennylane/pull/6600) diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 267df623385..c9d0158ff8b 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -208,8 +208,8 @@ def _to_qfunc_output_type( return tuple(_to_qfunc_output_type(r, qfunc_output, False) for r in results) # Special case of single Measurement in a list - if isinstance(qfunc_output, list) and len(qfunc_output) == 1: - results = [results] + if isinstance(qfunc_output, Sequence) and len(qfunc_output) == 1: + results = (results,) # If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed) if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)): diff --git a/tests/test_qnode.py b/tests/test_qnode.py index aa0d7152b7b..db332ade2ab 100644 --- a/tests/test_qnode.py +++ b/tests/test_qnode.py @@ -148,6 +148,16 @@ def f(): class TestValidation: """Tests for QNode creation and validation""" + @pytest.mark.parametrize("return_type", (tuple, list)) + def test_return_behaviour_consistency(self, return_type): + """Test that the QNode return typing stays consistent""" + + @qml.qnode(qml.device("default.qubit")) + def circuit(return_type): + return return_type([qml.expval(qml.Z(0))]) + + assert isinstance(circuit(return_type), return_type) + def test_expansion_strategy_error(self): """Test that an error is raised if expansion_strategy is passed to the qnode.""" diff --git a/tests/transforms/test_broadcast_expand.py b/tests/transforms/test_broadcast_expand.py index 0e0d71bb6b8..c783f5d8794 100644 --- a/tests/transforms/test_broadcast_expand.py +++ b/tests/transforms/test_broadcast_expand.py @@ -330,7 +330,7 @@ def cost(*params): assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac[0])) assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[1], exp_jac[1])) else: - assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac)) + assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac)) @pytest.mark.slow @pytest.mark.tf @@ -393,6 +393,7 @@ def cost(*params): qml.math.stack([jac[i][j] for i in range(len(obs))]) for j in range(len(params)) ) else: - assert qml.math.allclose(res, exp_fn(*params)) + assert qml.math.allclose(res[0], exp_fn(*params)) + jac = jac[0] assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac))