diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 0064f239c96..051f9d4cdf7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -193,6 +193,9 @@ same information.
Bug fixes 🐛
+* `QNode` return behaviour is now consistent for lists and tuples.
+ [(#6568)](https://github.com/PennyLaneAI/pennylane/pull/6568)
+
* `qml.QNode` now accepts arguments with types defined in libraries that are not necessarily
in the list of supported interfaces, such as the `Graph` class defined in `networkx`.
[(#6600)](https://github.com/PennyLaneAI/pennylane/pull/6600)
diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py
index 267df623385..c9d0158ff8b 100644
--- a/pennylane/workflow/qnode.py
+++ b/pennylane/workflow/qnode.py
@@ -208,8 +208,8 @@ def _to_qfunc_output_type(
return tuple(_to_qfunc_output_type(r, qfunc_output, False) for r in results)
# Special case of single Measurement in a list
- if isinstance(qfunc_output, list) and len(qfunc_output) == 1:
- results = [results]
+ if isinstance(qfunc_output, Sequence) and len(qfunc_output) == 1:
+ results = (results,)
# If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
diff --git a/tests/test_qnode.py b/tests/test_qnode.py
index aa0d7152b7b..db332ade2ab 100644
--- a/tests/test_qnode.py
+++ b/tests/test_qnode.py
@@ -148,6 +148,16 @@ def f():
class TestValidation:
"""Tests for QNode creation and validation"""
+ @pytest.mark.parametrize("return_type", (tuple, list))
+ def test_return_behaviour_consistency(self, return_type):
+ """Test that the QNode return typing stays consistent"""
+
+ @qml.qnode(qml.device("default.qubit"))
+ def circuit(return_type):
+ return return_type([qml.expval(qml.Z(0))])
+
+ assert isinstance(circuit(return_type), return_type)
+
def test_expansion_strategy_error(self):
"""Test that an error is raised if expansion_strategy is passed to the qnode."""
diff --git a/tests/transforms/test_broadcast_expand.py b/tests/transforms/test_broadcast_expand.py
index 0e0d71bb6b8..c783f5d8794 100644
--- a/tests/transforms/test_broadcast_expand.py
+++ b/tests/transforms/test_broadcast_expand.py
@@ -330,7 +330,7 @@ def cost(*params):
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac[0]))
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[1], exp_jac[1]))
else:
- assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac))
+ assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac))
@pytest.mark.slow
@pytest.mark.tf
@@ -393,6 +393,7 @@ def cost(*params):
qml.math.stack([jac[i][j] for i in range(len(obs))]) for j in range(len(params))
)
else:
- assert qml.math.allclose(res, exp_fn(*params))
+ assert qml.math.allclose(res[0], exp_fn(*params))
+ jac = jac[0]
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac))