Skip to content

Commit

Permalink
[sc-69429] SPSA keyword argument bug fix (#6027)
Browse files Browse the repository at this point in the history
**Context:** `SPSAOptimizer`'s `step_and_cost` method was ignoring
keyword arguments in the objective function.

**Description of the Change:** `SPSAOptimizer.step_and_cost` no longer
ignores kwargs.

**Benefits:** Your keyword arguments are heard <3 

**Possible Drawbacks:** None

**Related github issue:**
#6028
  • Loading branch information
isaacdevlugt authored Jul 23, 2024
1 parent ce3b4b6 commit f454f59
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@

<h3>Bug fixes 🐛</h3>

* Fixed a bug in `qml.SPSAOptimizer` that ignored keyword arguments in the objective function.
[(#6027)](https://github.com/PennyLaneAI/pennylane/pull/6027)

* `dynamic_one_shot` was broken for old-API devices since `override_shots` was deprecated.
[(#6024)](https://github.com/PennyLaneAI/pennylane/pull/6024)

Expand Down
4 changes: 3 additions & 1 deletion pennylane/optimize/spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def step_and_cost(self, objective_fn, *args, **kwargs):
objective function output prior to the step.
"""
g = self.compute_grad(objective_fn, args, kwargs)

new_args = self.apply_grad(g, args)

self.k += 1
Expand Down Expand Up @@ -270,7 +271,8 @@ def compute_grad(self, objective_fn, args, kwargs):
shots = Shots(objective_fn.device._raw_shot_sequence) # pragma: no cover
else:
shots = Shots(None)
if np.prod(objective_fn.func(*args).shape(objective_fn.device, shots)) > 1:

if np.prod(objective_fn.func(*args, **kwargs).shape(objective_fn.device, shots)) > 1:
raise ValueError(
"The objective function must be a scalar function for the gradient "
"to be computed."
Expand Down
7 changes: 5 additions & 2 deletions tests/optimize/test_spsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def cost(params):

@pytest.mark.usefixtures("use_legacy_opmath")
@pytest.mark.slow
def test_lighting_device_legacy_opmath(self):
def test_lightning_device_legacy_opmath(self):
"""Test SPSAOptimizer implementation with lightning.qubit device."""
coeffs = [0.2, -0.543, 0.4514]
obs = [
Expand Down Expand Up @@ -479,7 +479,7 @@ def cost_fun(params, num_qubits=1):
assert energy < init_energy

@pytest.mark.slow
def test_lighting_device(self):
def test_lightning_device(self):
"""Test SPSAOptimizer implementation with lightning.qubit device."""
coeffs = [0.2, -0.543, 0.4514]
obs = [
Expand All @@ -494,6 +494,9 @@ def test_lighting_device(self):
@qml.qnode(dev)
def cost_fun(params, num_qubits=1):
qml.BasisState([1, 1, 0, 0], wires=range(num_qubits))

assert num_qubits == 4

for i in range(num_qubits):
qml.Rot(*params[i], wires=0)
qml.CNOT(wires=[2, 3])
Expand Down

0 comments on commit f454f59

Please sign in to comment.