Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] iterative_qpe uses qml control flow functions #6680

Merged
merged 9 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ such as `shots`, `rng` and `prng_key`.

<h4>Capturing and representing hybrid programs</h4>

* The `qml.iterative_qpe` function can now be compactly captured into jaxpr.
[(#6680)](https://github.com/PennyLaneAI/pennylane/pull/6680)

* Functions and plxpr can now be natively transformed using the new `qml.capture.transforms.DecomposeInterpreter`
when program capture is enabled. This class decomposes pennylane operators following the same API as
`qml.transforms.decompose`.
Expand Down
1 change: 1 addition & 0 deletions pennylane/devices/qubit/dq_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def interpret_operation(self, op):
self.state = apply_operation(op, self.state, is_state_batched=self.is_state_batched)
if op.batch_size:
self.is_state_batched = True
return op

def interpret_measurement_eqn(self, eqn: "jax.core.JaxprEqn"):
if "mcm" in eqn.primitive.name:
Expand Down
51 changes: 35 additions & 16 deletions pennylane/ops/functions/iterative_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def iterative_qpe(base, aux_wire, iters):
estimation and returns a list of mid-circuit measurements with qubit reset.

Args:
base (Operator): the phase estimation unitary, specified as an :class:`~.Operator`
aux_wire (Union[Wires, int, str]): the wire to be used for the estimation
iters (int): the number of measurements to be performed
base (Operator): the phase estimation unitary, specified as an :class:`~.Operator`
aux_wire (Union[Wires, int, str]): the wire to be used for the estimation
iters (int): the number of measurements to be performed

Returns:
list[MidMeasureMP]: the list of measurements performed
list[MeasurementValue]: the abstract results of the mid-circuit measurements

.. seealso:: :class:`~.QuantumPhaseEstimation`, :func:`~.measure`

Expand All @@ -46,13 +46,13 @@ def iterative_qpe(base, aux_wire, iters):
@qml.qnode(dev)
def circuit():

# Initial state
qml.X(0)
# Initial state
qml.X(0)

# Iterative QPE
measurements = qml.iterative_qpe(qml.RZ(2.0, wires=[0]), aux_wire=1, iters=3)
# Iterative QPE
measurements = qml.iterative_qpe(qml.RZ(2.0, wires=[0]), aux_wire=1, iters=3)

return qml.sample(measurements)
return qml.sample(measurements)

.. code-block:: pycon

Expand All @@ -75,16 +75,35 @@ def circuit():
╚══════════════════════╩═════════════════════════║═══════╡ ├Sample[MCM]
╚═══════╡ ╰Sample[MCM]
"""
measurements = []
if qml.capture.enabled():
measurements = qml.math.zeros(iters, dtype=int, like="jax")
else:
measurements = [0] * iters

def measurement_loop(i, measurements, target):
# closure: aux_wire, iters, target

for i in range(iters):
qml.Hadamard(wires=aux_wire)
qml.ctrl(qml.pow(base, z=2 ** (iters - i - 1)), control=aux_wire)
qml.ctrl(qml.pow(target, z=2 ** (iters - i - 1)), control=aux_wire)

def conditional_loop(j):
# closure: measurements, iters, i, aux_wire
meas = measurements[iters - i + j]

def cond_func():
qml.PhaseShift(-2.0 * np.pi / (2 ** (j + 2)), wires=aux_wire)

for ind, meas in enumerate(measurements):
qml.cond(meas, qml.PhaseShift)(-2.0 * np.pi / 2 ** (ind + 2), wires=aux_wire)
qml.cond(meas, cond_func)()

qml.for_loop(i)(conditional_loop)()

qml.Hadamard(wires=aux_wire)
measurements.insert(0, qml.measure(wires=aux_wire, reset=True))
m = qml.measure(wires=aux_wire, reset=True)
if qml.capture.enabled():
measurements = measurements.at[iters - i - 1].set(m)
else:
measurements[iters - i - 1] = m

return measurements, target

return measurements
return qml.for_loop(iters)(measurement_loop)(measurements, base)[0]
48 changes: 48 additions & 0 deletions tests/ops/functions/test_iterative_qpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,51 @@ def circuit_iterative():
return [qml.expval(op=i) for i in measurements]

assert np.allclose(circuit_qpe(), circuit_iterative())


@pytest.mark.slow
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.jax
def test_capture_execution(seed):
"""Test that iterative qpe can be captured and executed.

While this is a rather bad test:
* the captured jaxpr has too many classical instructions for
easy verification of its contents
* The captured jaxpr cannot be used with CollectOpsandMeas as it converts mcm integers to
measurement values, which are incompatible with the scatter operation used in
`measurements = measurements.at[iters - i - 1].set(m)`
* Evaluating jaxpr currently uses single-branch-statistics, which gives incorrect results for a
a single execution.


"""

qml.capture.enable()
import jax

def f(x):
qml.X(0)
return qml.iterative_qpe(qml.RZ(x, wires=[0]), aux_wire=1, iters=3)

x = jax.numpy.array(2.0)

jaxpr = jax.make_jaxpr(f)(1.5)

dev = qml.device("default.qubit", wires=3, seed=seed)

# hack for single-branch statistics
samples = qml.math.vstack([dev.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, x) for _ in range(5000)])
albi3ro marked this conversation as resolved.
Show resolved Hide resolved
probs_capture = qml.probs(wires=(0, 1, 2)).process_samples(
samples, wire_order=qml.wires.Wires((0, 1, 2))
)

qml.capture.disable()

@qml.qnode(dev)
def normal_qnode(x):
meas = f(x)
return qml.probs(op=meas)

probs_normal = normal_qnode(x)

assert qml.math.allclose(probs_capture, probs_normal, atol=0.02)
Loading