Skip to content

Commit

Permalink
Fix gradient transforms with default qubit and overridden shot vectors (
Browse files Browse the repository at this point in the history
#4795)

[sc-44395]

With this change, we can now do:
```
dev = qml.device("default.qubit")

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = qml.numpy.array(0.543, requires_grad=True)

qml.gradients.param_shift(circuit)(x, shots=(10000, 10000, 1000) )
```
```
(array(-0.5209), array(-0.5258), array(-0.522))
```

---------

Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
albi3ro and timmysilv authored Nov 9, 2023
1 parent f845ae9 commit deaf387
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@

<h3>Bug fixes 🐛</h3>

* Gradient transforms now work with overridden shot vectors and default qubit.
[(#4795)](https://github.com/PennyLaneAI/pennylane/pull/4795)

* `qml.defer_measurements` now correctly transforms circuits when terminal measurements include wires
used in mid-circuit measurements.
[(#4787)](https://github.com/PennyLaneAI/pennylane/pull/4787)
Expand Down
2 changes: 2 additions & 0 deletions pennylane/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the QNode class and qnode decorator.
"""
# pylint: disable=too-many-instance-attributes,too-many-arguments,protected-access,unnecessary-lambda-assignment, too-many-branches, too-many-statements
import copy
import functools
import inspect
import warnings
Expand Down Expand Up @@ -842,6 +843,7 @@ def tape(self) -> QuantumTape:

def construct(self, args, kwargs): # pylint: disable=too-many-branches
"""Call the quantum function with a tape context, ensuring the operations get queued."""
kwargs = copy.copy(kwargs)
old_interface = self.interface

if self._qfunc_uses_shots_arg:
Expand Down
4 changes: 2 additions & 2 deletions tests/gradients/core/test_gradient_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_setting_shots(self):
"""Test that setting the number of shots works correctly for
a gradient transform"""

dev = qml.device("default.qubit.legacy", wires=1, shots=1000)
dev = qml.device("default.qubit", wires=1, shots=1000)

@qml.qnode(dev)
def circuit(x):
Expand All @@ -617,7 +617,7 @@ def circuit(x):
# the gradient function can be called with different shot values
grad_fn = qml.gradients.param_shift(circuit)
assert grad_fn(x).shape == ()
assert grad_fn(x, shots=[(1, 1000)]).shape == (1000,)
assert len(grad_fn(x, shots=[(1, 1000)])) == 1000

# the original QNode is unaffected
assert circuit(x).shape == tuple()
Expand Down

0 comments on commit deaf387

Please sign in to comment.