Skip to content

Commit

Permalink
[Program capture] qml.capture.qnode_call works with closure variabl…
Browse files Browse the repository at this point in the history
…es and consts (#6052)

**Context:**

When trying to get nested qfunc controls to work properly, I finally
figured out what the `jaxpr.consts` actually is, and that we need it to
properly handle closure variables. Unfortunately, that also means we
need to go back and update the qnode primtive.

**Description of the Change:**

Turns the captured consts into positional arguments and add a `n_consts`
keyword argument to the primitive.

**Benefits:**

We can handle closure variables.

**Possible Drawbacks:**

**Related GitHub Issues:**
  • Loading branch information
albi3ro authored Jul 30, 2024
1 parent d74f3fd commit f9adf90
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@

<h3>Improvements 🛠</h3>

* During experimental program capture, the qnode can now use closure variables.
[(#6052)](https://github.com/PennyLaneAI/pennylane/pull/6052)

* `GlobalPhase` now supports parameter broadcasting.
[(#5923)](https://github.com/PennyLaneAI/pennylane/pull/5923)

Expand Down
18 changes: 12 additions & 6 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0):

for s in shots:
for m in measurements:
shape, dtype = m.abstract_eval(shots=s, num_device_wires=num_device_wires)
shape, dtype = m.aval.abstract_eval(shots=s, num_device_wires=num_device_wires)
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
return shapes

Expand All @@ -61,10 +61,14 @@ def _get_qnode_prim():
qnode_prim = jax.core.Primitive("qnode")
qnode_prim.multiple_results = True

# pylint: disable=too-many-arguments
@qnode_prim.def_impl
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
consts = args[:n_consts]
args = args[n_consts:]

def qfunc(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr.jaxpr, qfunc_jaxpr.consts, *inner_args)
return jax.core.eval_jaxpr(qfunc_jaxpr, consts, *inner_args)

with warnings.catch_warnings():
warnings.filterwarnings(
Expand All @@ -77,8 +81,8 @@ def qfunc(*inner_args):

# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr):
mps = qfunc_jaxpr.out_avals
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

return qnode_prim
Expand Down Expand Up @@ -177,11 +181,13 @@ def f(x):
qnode_prim = _get_qnode_prim()

res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*args,
shots=shots,
qnode=qnode,
device=qnode.device,
qnode_kwargs=qnode_kwargs,
qfunc_jaxpr=qfunc_jaxpr,
qfunc_jaxpr=qfunc_jaxpr.jaxpr,
n_consts=len(qfunc_jaxpr.consts),
)
return res[0] if len(res) == 1 else res
18 changes: 18 additions & 0 deletions tests/capture/test_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,21 @@ def circuit():
"postselect_mode": None,
}
assert jaxpr.eqns[0].params["qnode_kwargs"] == expected


def test_qnode_closure_variables():
"""Test that qnode can capture closure variables and consts."""

a = jax.numpy.array(2.0)

@qml.qnode(qml.device("default.qubit", wires=2))
def circuit(w):
qml.RX(a, w)
return qml.expval(qml.Z(0))

jaxpr = jax.make_jaxpr(circuit)(1)
assert len(jaxpr.eqns[0].invars) == 2 # one closure variable, one arg
assert jaxpr.eqns[0].params["n_consts"] == 1

out = jax.core.eval_jaxpr(jaxpr.jaxpr, [jax.numpy.array(0.5)], 0)
assert qml.math.allclose(out, jax.numpy.cos(0.5))

0 comments on commit f9adf90

Please sign in to comment.