Skip to content

Commit

Permalink
Fix bug in autograph implementation for ctrl and adjoint (#6685)
Browse files Browse the repository at this point in the history
**Context:**

The checks in the existing implementation of autograph have handling for
conversion if a _function_ is passed to `qml.ctrl` or `qml.adjoint`.

Unfortunately, these checks were breaking the case where a fully
initialized operator was passed to `qml.ctrl` or `qml.adjoint`, because
they expect the check to always receive a callable.

**Description of the Change:**
We have autograph track the `_capture_adjoint_transform` and
`_capture_ctrl_transform` functions instead, since at that point the
function really does always receive a callable.

[sc-80037]

---------

Co-authored-by: andrijapau <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
Co-authored-by: Yushao Chen (Jerry) <[email protected]>
  • Loading branch information
6 people authored Dec 16, 2024
1 parent cffa08a commit 4555833
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 4 deletions.
1 change: 1 addition & 0 deletions doc/development/autograph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ AutoGraph:
>>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.1)
[Array(9., dtype=float64, weak_type=True)]


Indexing within a loop
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
[(#6413)](https://github.com/PennyLaneAI/pennylane/pull/6413)
[(#6426)](https://github.com/PennyLaneAI/pennylane/pull/6426)
[(#6645)](https://github.com/PennyLaneAI/pennylane/pull/6645)
[(#6685)](https://github.com/PennyLaneAI/pennylane/pull/6685)

* New `qml.GQSP` template has been added to perform Generalized Quantum Signal Processing (GQSP).
The functionality `qml.poly_to_angles` has been also extended to support GQSP.
Expand Down
9 changes: 7 additions & 2 deletions pennylane/capture/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,15 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
(ag_config, "CONVERSION_RULES", module_allowlist),
(ag_py_builtins, "BUILTIN_FUNCTIONS_MAP", py_builtins_map),
):
# Using qml.ops.op_math.adjoint points to the adjoint function
# and importing this at the top of the file creates circular imports
# pylint: disable=import-outside-toplevel, protected-access
from pennylane.ops.op_math.adjoint import _capture_adjoint_transform

# HOTFIX: pass through calls of known PennyLane wrapper functions
if fn in (
qml.adjoint,
qml.ctrl,
_capture_adjoint_transform,
qml.ops.op_math.controlled._capture_ctrl_transform,
qml.grad,
qml.jacobian,
qml.vjp,
Expand Down
5 changes: 5 additions & 0 deletions pennylane/capture/make_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def circ(x):
module in TensorFlow (`official documentation <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md>`_
).
.. note::
There are some limitations and sharp bits regarding AutoGraph; to better understand
supported behaviour and limitations, see https://docs.pennylane.ai/en/stable/development/autograph.html
On its own, capture of standard Python control flow is not supported:
.. code-block:: python
Expand Down
33 changes: 31 additions & 2 deletions tests/capture/autograph/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,33 @@ def fn(x: float):
assert check_cache(inner1.func)
assert check_cache(inner2.func)

@pytest.mark.xfail(raises=NotImplementedError)
def test_adjoint_op(self):
"""Test that the adjoint of an operator successfully passes through autograph"""

@qml.qnode(qml.device("default.qubit", wires=2))
def circ():
qml.adjoint(qml.X(0))
return qml.expval(qml.Z(0))

plxpr = qml.capture.make_plxpr(circ, autograph=True)()
assert jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts)[0] == -1

def test_ctrl_op(self):
"""Test that the adjoint of an operator successfully passes through autograph without raising an error"""

@qml.qnode(qml.device("default.qubit", wires=2))
def circ():
qml.X(1)
qml.ctrl(qml.X(0), 1)
return qml.expval(qml.Z(0))

plxpr = qml.capture.make_plxpr(circ, autograph=True)()
assert jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts)[0] == -1

@pytest.mark.xfail(
raises=NotImplementedError,
reason="adjoint_transform_prim not implemented on DefaultQubitInterpreter",
)
def test_adjoint_wrapper(self):
"""Test conversion is happening successfully on functions wrapped with 'adjoint'."""

Expand All @@ -279,7 +305,10 @@ def circ(x: float):
assert check_cache(circ.func)
assert check_cache(inner)

@pytest.mark.xfail(raises=NotImplementedError)
@pytest.mark.xfail(
raises=NotImplementedError,
reason="ctrl_transform_prim not implemented on DefaultQubitInterpreter",
)
def test_ctrl_wrapper(self):
"""Test conversion is happening successfully on functions wrapped with 'ctrl'."""

Expand Down

0 comments on commit 4555833

Please sign in to comment.