Skip to content

Commit

Permalink
Fix consistency of QNode results processing (#6568)
Browse files Browse the repository at this point in the history
**Context:**

The output of QNode execution was not consistent based on the type of
the qfunc output,
```python
@qml.qnode(qml.device('default.qubit'))
def circuit(t):
    return t([qml.expval(qml.Z(0))])
>>> circuit(tuple)
1.0
>>> circuit(list)
[1.0]
```

**Description of the Change:**

Updates how the results of the QNode execution gets "type-converted"
according to the `qfunc_output`.

After this fix we get consistency,
```python
@qml.qnode(qml.device('default.qubit'))
def circuit(t):
    return t([qml.expval(qml.Z(0))])
>>> circuit(tuple)
(1.0,)
>>> circuit(list)
[1.0]
```

**Benefits:** Consistency

**Possible Drawbacks:** Output of QNode execution may look different
than people are used to. A direct example is needed to change some tests
to squeeze out the new dimension that has been introduced.

**Related GitHub Issue:** #6540 

[sc-77682]
  • Loading branch information
andrijapau authored Nov 21, 2024
1 parent cb994be commit 76f9d7c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 4 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ same information.

<h3>Bug fixes 🐛</h3>

* `QNode` return behaviour is now consistent for lists and tuples.
[(#6568)](https://github.com/PennyLaneAI/pennylane/pull/6568)

* `qml.QNode` now accepts arguments with types defined in libraries that are not necessarily
in the list of supported interfaces, such as the `Graph` class defined in `networkx`.
[(#6600)](https://github.com/PennyLaneAI/pennylane/pull/6600)
Expand Down
4 changes: 2 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def _to_qfunc_output_type(
return tuple(_to_qfunc_output_type(r, qfunc_output, False) for r in results)

# Special case of single Measurement in a list
if isinstance(qfunc_output, list) and len(qfunc_output) == 1:
results = [results]
if isinstance(qfunc_output, Sequence) and len(qfunc_output) == 1:
results = (results,)

# If the return type is not tuple (list or ndarray) (Autograd and TF backprop removed)
if isinstance(qfunc_output, (tuple, qml.measurements.MeasurementProcess)):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ def f():
class TestValidation:
"""Tests for QNode creation and validation"""

@pytest.mark.parametrize("return_type", (tuple, list))
def test_return_behaviour_consistency(self, return_type):
"""Test that the QNode return typing stays consistent"""

@qml.qnode(qml.device("default.qubit"))
def circuit(return_type):
return return_type([qml.expval(qml.Z(0))])

assert isinstance(circuit(return_type), return_type)

def test_expansion_strategy_error(self):
"""Test that an error is raised if expansion_strategy is passed to the qnode."""

Expand Down
5 changes: 3 additions & 2 deletions tests/transforms/test_broadcast_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def cost(*params):
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac[0]))
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[1], exp_jac[1]))
else:
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac))
assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac[0], exp_jac))

@pytest.mark.slow
@pytest.mark.tf
Expand Down Expand Up @@ -393,6 +393,7 @@ def cost(*params):
qml.math.stack([jac[i][j] for i in range(len(obs))]) for j in range(len(params))
)
else:
assert qml.math.allclose(res, exp_fn(*params))
assert qml.math.allclose(res[0], exp_fn(*params))
jac = jac[0]

assert all(qml.math.allclose(_jac, e_jac) for _jac, e_jac in zip(jac, exp_jac))

0 comments on commit 76f9d7c

Please sign in to comment.