Skip to content

Commit

Permalink
Fix zero-jvps with shot vectors (#6219)
Browse files Browse the repository at this point in the history
**Context:**
If a tape has no trainable parameters, generic zero-valued JVPs are
created for it in JVP calculations.
However, these generic calculations are not taking shot vectors
correctly into account, also because they do not use
`MeasurementProcess.shape` correctly.

**Description of the Change:**
Change the generic zero-valued JVPs so they are compatible with shot
vectors.

**Benefits:**
One bug less.

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
Fixes #6220 

[sc-73033]

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
2 people authored and mudit2812 committed Sep 10, 2024
1 parent 2e15022 commit 065d240
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 20 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@

<h3>Bug fixes 🐛</h3>

* Fix a bug where zero-valued JVPs were calculated wrongly in the presence of shot vectors.
[(#6219)](https://github.com/PennyLaneAI/pennylane/pull/6219)

* Fix `qml.PrepSelPrep` template to work with `torch`:
[(#6191)](https://github.com/PennyLaneAI/pennylane/pull/6191)

Expand Down
13 changes: 10 additions & 3 deletions pennylane/gradients/jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,18 @@ def jvp(tape, tangent, gradient_fn, gradient_kwargs=None):
if len(tape.trainable_params) == 0:
# The tape has no trainable parameters; the JVP
# is simply none.
def zero_vjp(_):
res = tuple(np.zeros(mp.shape(None, tape.shots)) for mp in tape.measurements)
def zero_jvp_for_single_shots(s):
res = tuple(
np.zeros(mp.shape(shots=s), dtype=mp.numeric_type) for mp in tape.measurements
)
return res[0] if len(tape.measurements) == 1 else res

return tuple(), zero_vjp
def zero_jvp(_):
if tape.shots.has_partitioned_shots:
return tuple(zero_jvp_for_single_shots(s) for s in tape.shots)
return zero_jvp_for_single_shots(tape.shots.total_shots)

return tuple(), zero_jvp

multi_m = len(tape.measurements) > 1

Expand Down
22 changes: 12 additions & 10 deletions pennylane/workflow/jacobian_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def _compute_vjps(jacs, dys, tapes):
return tuple(vjps)


def _zero_jvp_single_shots(shots, tape):
jvp = tuple(np.zeros(mp.shape(shots=shots), dtype=mp.numeric_type) for mp in tape.measurements)
return jvp[0] if len(tape.measurements) == 1 else jvp


def _zero_jvp(tape):
if tape.shots.has_partitioned_shots:
return tuple(_zero_jvp_single_shots(s, tape) for s in tape.shots)
return _zero_jvp_single_shots(tape.shots.total_shots, tape)


def _compute_jvps(jacs, tangents, tapes):
"""Compute the jvps of multiple tapes, directly for a Jacobian and tangents."""
f = {True: qml.gradients.compute_jvp_multi, False: qml.gradients.compute_jvp_single}
Expand All @@ -54,16 +65,7 @@ def _compute_jvps(jacs, tangents, tapes):
for jac, dx, t in zip(jacs, tangents, tapes):
multi = len(t.measurements) > 1
if len(t.trainable_params) == 0:
empty_shots = qml.measurements.Shots(None)
zeros_jvp = tuple(
np.zeros(mp.shape(None, empty_shots), dtype=mp.numeric_type)
for mp in t.measurements
)
zeros_jvp = zeros_jvp[0] if len(t.measurements) == 1 else zeros_jvp
if t.shots.has_partitioned_shots:
jvps.append(tuple(zeros_jvp for _ in range(t.shots.num_copies)))
else:
jvps.append(zeros_jvp)
jvps.append(_zero_jvp(t))
elif t.shots.has_partitioned_shots:
jvps.append(tuple(f[multi](dx, j) for j in jac))
else:
Expand Down
17 changes: 10 additions & 7 deletions tests/gradients/core/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pennylane as qml
from pennylane import numpy as np
from pennylane.gradients import param_shift
from pennylane.measurements.shots import Shots

_x = np.arange(12).reshape((2, 3, 2))

Expand Down Expand Up @@ -799,7 +800,8 @@ def cost_fn(params, tangent):
class TestBatchJVP:
"""Tests for the batch JVP function"""

def test_one_tape_no_trainable_parameters(self):
@pytest.mark.parametrize("shots", [Shots(None), Shots(10), Shots([20, 10])])
def test_one_tape_no_trainable_parameters(self, shots):
"""A tape with no trainable parameters will simply return None"""
dev = qml.device("default.qubit", wires=2)

Expand All @@ -808,46 +810,47 @@ def test_one_tape_no_trainable_parameters(self):
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape1 = qml.tape.QuantumScript.from_queue(q1)
tape1 = qml.tape.QuantumScript.from_queue(q1, shots=shots)
with qml.queuing.AnnotatedQueue() as q2:
qml.RX(0.4, wires=0)
qml.RX(0.6, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape2 = qml.tape.QuantumScript.from_queue(q2)
tape2 = qml.tape.QuantumScript.from_queue(q2, shots=shots)
tape1.trainable_params = {}
tape2.trainable_params = {0, 1}

tapes = [tape1, tape2]
tangents = [np.array([1.0, 1.0]), np.array([1.0, 1.0])]

v_tapes, fn = qml.gradients.batch_jvp(tapes, tangents, param_shift)
assert len(v_tapes) == 4

# Even though there are 3 parameters, only two contribute
# to the JVP, so only 2*2=4 quantum evals
assert len(v_tapes) == 4
res = fn(dev.execute(v_tapes))

assert qml.math.allclose(res[0], np.array(0.0))
assert res[1] is not None

def test_all_tapes_no_trainable_parameters(self):
@pytest.mark.parametrize("shots", [Shots(None), Shots(10), Shots([20, 10])])
def test_all_tapes_no_trainable_parameters(self, shots):
"""If all tapes have no trainable parameters all outputs will be None"""

with qml.queuing.AnnotatedQueue() as q1:
qml.RX(0.4, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape1 = qml.tape.QuantumScript.from_queue(q1)
tape1 = qml.tape.QuantumScript.from_queue(q1, shots=shots)
with qml.queuing.AnnotatedQueue() as q2:
qml.RX(0.4, wires=0)
qml.RX(0.6, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape2 = qml.tape.QuantumScript.from_queue(q2)
tape2 = qml.tape.QuantumScript.from_queue(q2, shots=shots)
tape1.trainable_params = set()
tape2.trainable_params = set()

Expand Down

0 comments on commit 065d240

Please sign in to comment.