Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero-jvps with shot vectors #6219

Merged
merged 9 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@

<h3>Bug fixes 🐛</h3>

* Fix a bug where zero-valued JVPs were calculated wrongly in the presence of shot vectors.
[(#6219)](https://github.com/PennyLaneAI/pennylane/pull/6219)

* Fix `qml.PrepSelPrep` template to work with `torch`:
[(#6191)](https://github.com/PennyLaneAI/pennylane/pull/6191)

Expand All @@ -71,4 +74,5 @@ Utkarsh Azad,
Lillian M. A. Frederiksen,
Christina Lee,
William Maxwell,
Lee J. O'Riordan,
Lee J. O'Riordan,
David Wierichs,
11 changes: 8 additions & 3 deletions pennylane/gradients/jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,16 @@ def jvp(tape, tangent, gradient_fn, gradient_kwargs=None):
if len(tape.trainable_params) == 0:
# The tape has no trainable parameters; the JVP
# is simply none.
def zero_vjp(_):
res = tuple(np.zeros(mp.shape(None, tape.shots)) for mp in tape.measurements)
def zero_jvp_for_single_shots(s):
res = tuple(np.zeros(mp.shape(shots=s), dtype=mp.numeric_type) for mp in tape.measurements)
return res[0] if len(tape.measurements) == 1 else res

return tuple(), zero_vjp
def zero_jvp(_):
if tape.shots.has_partitioned_shots:
return tuple(zero_jvp_for_single_shots(s) for s in tape.shots)
return zero_jvp_for_single_shots(tape.shots.total_shots)

return tuple(), zero_jvp

multi_m = len(tape.measurements) > 1

Expand Down
22 changes: 12 additions & 10 deletions pennylane/workflow/jacobian_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def _compute_vjps(jacs, dys, tapes):
return tuple(vjps)


def _zero_jvp_single_shots(shots, tape):
jvp = tuple(np.zeros(mp.shape(shots=shots), dtype=mp.numeric_type) for mp in tape.measurements)
return jvp[0] if len(tape.measurements) == 1 else jvp


def _zero_jvp(tape):
if tape.shots.has_partitioned_shots:
return tuple(_zero_jvp_single_shots(s, tape) for s in tape.shots)
return _zero_jvp_single_shots(tape.shots.total_shots, tape)


def _compute_jvps(jacs, tangents, tapes):
"""Compute the jvps of multiple tapes, directly for a Jacobian and tangents."""
f = {True: qml.gradients.compute_jvp_multi, False: qml.gradients.compute_jvp_single}
Expand All @@ -54,16 +65,7 @@ def _compute_jvps(jacs, tangents, tapes):
for jac, dx, t in zip(jacs, tangents, tapes):
multi = len(t.measurements) > 1
if len(t.trainable_params) == 0:
empty_shots = qml.measurements.Shots(None)
zeros_jvp = tuple(
np.zeros(mp.shape(None, empty_shots), dtype=mp.numeric_type)
for mp in t.measurements
)
zeros_jvp = zeros_jvp[0] if len(t.measurements) == 1 else zeros_jvp
if t.shots.has_partitioned_shots:
jvps.append(tuple(zeros_jvp for _ in range(t.shots.num_copies)))
else:
jvps.append(zeros_jvp)
jvps.append(_zero_jvp(t))
elif t.shots.has_partitioned_shots:
jvps.append(tuple(f[multi](dx, j) for j in jac))
else:
Expand Down
17 changes: 10 additions & 7 deletions tests/gradients/core/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pennylane as qml
from pennylane import numpy as np
from pennylane.gradients import param_shift
from pennylane.measurements.shots import Shots

_x = np.arange(12).reshape((2, 3, 2))

Expand Down Expand Up @@ -799,7 +800,8 @@ def cost_fn(params, tangent):
class TestBatchJVP:
"""Tests for the batch JVP function"""

def test_one_tape_no_trainable_parameters(self):
@pytest.mark.parametrize("shots", [Shots(None), Shots(10), Shots([20, 10])])
def test_one_tape_no_trainable_parameters(self, shots):
"""A tape with no trainable parameters will simply return None"""
dev = qml.device("default.qubit", wires=2)

Expand All @@ -808,46 +810,47 @@ def test_one_tape_no_trainable_parameters(self):
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape1 = qml.tape.QuantumScript.from_queue(q1)
tape1 = qml.tape.QuantumScript.from_queue(q1, shots=shots)
with qml.queuing.AnnotatedQueue() as q2:
qml.RX(0.4, wires=0)
qml.RX(0.6, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape2 = qml.tape.QuantumScript.from_queue(q2)
tape2 = qml.tape.QuantumScript.from_queue(q2, shots=shots)
tape1.trainable_params = {}
tape2.trainable_params = {0, 1}

tapes = [tape1, tape2]
tangents = [np.array([1.0, 1.0]), np.array([1.0, 1.0])]

v_tapes, fn = qml.gradients.batch_jvp(tapes, tangents, param_shift)
assert len(v_tapes) == 4

# Even though there are 3 parameters, only two contribute
# to the JVP, so only 2*2=4 quantum evals
assert len(v_tapes) == 4
res = fn(dev.execute(v_tapes))

assert qml.math.allclose(res[0], np.array(0.0))
assert res[1] is not None

def test_all_tapes_no_trainable_parameters(self):
@pytest.mark.parametrize("shots", [Shots(None), Shots(10), Shots([20, 10])])
def test_all_tapes_no_trainable_parameters(self, shots):
"""If all tapes have no trainable parameters all outputs will be None"""

with qml.queuing.AnnotatedQueue() as q1:
qml.RX(0.4, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape1 = qml.tape.QuantumScript.from_queue(q1)
tape1 = qml.tape.QuantumScript.from_queue(q1, shots=shots)
with qml.queuing.AnnotatedQueue() as q2:
qml.RX(0.4, wires=0)
qml.RX(0.6, wires=0)
qml.CNOT(wires=[0, 1])
qml.expval(qml.PauliZ(0))

tape2 = qml.tape.QuantumScript.from_queue(q2)
tape2 = qml.tape.QuantumScript.from_queue(q2, shots=shots)
tape1.trainable_params = set()
tape2.trainable_params = set()

Expand Down
Loading