From 45558334cf3c3ba8b40bc0e5c8b0bfc09300a54a Mon Sep 17 00:00:00 2001 From: lillian542 <38584660+lillian542@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:16:25 -0500 Subject: [PATCH] Fix bug in autograph implementation for ctrl and adjoint (#6685) **Context:** The checks in the existing implementation of autograph have handling for conversion if a _function_ is passed to `qml.ctrl` or `qml.adjoint`. Unfortunately, these checks were breaking the case where a fully initialized operator was passed to `qml.ctrl` or `qml.adjoint`, because they expect the check to always receive a callable. **Description of the Change:** We have autograph track the `_capture_adjoint_transform` and `_capture_ctrl_transform` functions instead, since at that point the function really does always receive a callable. [sc-80037] --------- Co-authored-by: andrijapau Co-authored-by: Christina Lee Co-authored-by: Mudit Pandey Co-authored-by: Pietropaolo Frisoni Co-authored-by: Yushao Chen (Jerry) --- doc/development/autograph.rst | 1 + doc/releases/changelog-dev.md | 1 + pennylane/capture/autograph/ag_primitives.py | 9 ++++-- pennylane/capture/make_plxpr.py | 5 +++ tests/capture/autograph/test_autograph.py | 33 ++++++++++++++++++-- 5 files changed, 45 insertions(+), 4 deletions(-) diff --git a/doc/development/autograph.rst b/doc/development/autograph.rst index ddbf8e5d7a0..9fc376aff0f 100644 --- a/doc/development/autograph.rst +++ b/doc/development/autograph.rst @@ -470,6 +470,7 @@ AutoGraph: >>> eval_jaxpr(plxpr.jaxpr, plxpr.consts, 0.1) [Array(9., dtype=float64, weak_type=True)] + Indexing within a loop ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2ba70266f8c..1e90b2392b1 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -59,6 +59,7 @@ [(#6413)](https://github.com/PennyLaneAI/pennylane/pull/6413) [(#6426)](https://github.com/PennyLaneAI/pennylane/pull/6426) [(#6645)](https://github.com/PennyLaneAI/pennylane/pull/6645) + [(#6685)](https://github.com/PennyLaneAI/pennylane/pull/6685) * New `qml.GQSP` template has been added to perform Generalized Quantum Signal Processing (GQSP). The functionality `qml.poly_to_angles` has been also extended to support GQSP. diff --git a/pennylane/capture/autograph/ag_primitives.py b/pennylane/capture/autograph/ag_primitives.py index 215231d6442..a5b00266c0b 100644 --- a/pennylane/capture/autograph/ag_primitives.py +++ b/pennylane/capture/autograph/ag_primitives.py @@ -380,10 +380,15 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None): (ag_config, "CONVERSION_RULES", module_allowlist), (ag_py_builtins, "BUILTIN_FUNCTIONS_MAP", py_builtins_map), ): + # Using qml.ops.op_math.adjoint points to the adjoint function + # and importing this at the top of the file creates circular imports + # pylint: disable=import-outside-toplevel, protected-access + from pennylane.ops.op_math.adjoint import _capture_adjoint_transform + # HOTFIX: pass through calls of known PennyLane wrapper functions if fn in ( - qml.adjoint, - qml.ctrl, + _capture_adjoint_transform, + qml.ops.op_math.controlled._capture_ctrl_transform, qml.grad, qml.jacobian, qml.vjp, diff --git a/pennylane/capture/make_plxpr.py b/pennylane/capture/make_plxpr.py index f531c03a1a6..23aeae92a7a 100644 --- a/pennylane/capture/make_plxpr.py +++ b/pennylane/capture/make_plxpr.py @@ -98,6 +98,11 @@ def circ(x): module in TensorFlow (`official documentation `_ ). + .. note:: + + There are some limitations and sharp bits regarding AutoGraph; to better understand + supported behaviour and limitations, see https://docs.pennylane.ai/en/stable/development/autograph.html + On its own, capture of standard Python control flow is not supported: .. code-block:: python diff --git a/tests/capture/autograph/test_autograph.py b/tests/capture/autograph/test_autograph.py index 6e6f9116209..1665eb3fec0 100644 --- a/tests/capture/autograph/test_autograph.py +++ b/tests/capture/autograph/test_autograph.py @@ -258,7 +258,33 @@ def fn(x: float): assert check_cache(inner1.func) assert check_cache(inner2.func) - @pytest.mark.xfail(raises=NotImplementedError) + def test_adjoint_op(self): + """Test that the adjoint of an operator successfully passes through autograph""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circ(): + qml.adjoint(qml.X(0)) + return qml.expval(qml.Z(0)) + + plxpr = qml.capture.make_plxpr(circ, autograph=True)() + assert jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts)[0] == -1 + + def test_ctrl_op(self): + """Test that the adjoint of an operator successfully passes through autograph without raising an error""" + + @qml.qnode(qml.device("default.qubit", wires=2)) + def circ(): + qml.X(1) + qml.ctrl(qml.X(0), 1) + return qml.expval(qml.Z(0)) + + plxpr = qml.capture.make_plxpr(circ, autograph=True)() + assert jax.core.eval_jaxpr(plxpr.jaxpr, plxpr.consts)[0] == -1 + + @pytest.mark.xfail( + raises=NotImplementedError, + reason="adjoint_transform_prim not implemented on DefaultQubitInterpreter", + ) def test_adjoint_wrapper(self): """Test conversion is happening successfully on functions wrapped with 'adjoint'.""" @@ -279,7 +305,10 @@ def circ(x: float): assert check_cache(circ.func) assert check_cache(inner) - @pytest.mark.xfail(raises=NotImplementedError) + @pytest.mark.xfail( + raises=NotImplementedError, + reason="ctrl_transform_prim not implemented on DefaultQubitInterpreter", + ) def test_ctrl_wrapper(self): """Test conversion is happening successfully on functions wrapped with 'ctrl'."""