Skip to content

Commit

Permalink
[Compiler] Add Catalyst adjoint (#4725)
Browse files Browse the repository at this point in the history
**Description of the Change:**

This PR adds support for QJIT compatible qml.adjoint.

---------

Co-authored-by: Ali Asadi <[email protected]>
Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: David Ittah <[email protected]>
Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
5 people authored Nov 29, 2023
1 parent 669c86c commit b1f68fb
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@
[-4.20735506e-01, 4.20735506e-01]])
```

* `qml.adjoint` can be used with the `qml.qjit` decorator.
[(#4725)](https://github.com/PennyLaneAI/pennylane/pull/4725)

* `qml.ctrl` can be used with the `qml.qjit` decorator.
[(#4726)](https://github.com/PennyLaneAI/pennylane/pull/4726)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
.. autosummary::
:toctree: api
~adjoint
~ctrl
~grad
~jacobian
Expand Down
50 changes: 49 additions & 1 deletion pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
from pennylane.operation import Observable, Operation, Operator
from pennylane.queuing import QueuingManager
from pennylane.tape import make_qscript
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

from .symbolicop import SymbolicOp


# pylint: disable=no-member
def adjoint(fn, lazy=True):
"""Create the adjoint of an Operator or a function that applies the adjoint of the provided function.
:func:`~.qjit` compatible.
Args:
fn (function or :class:`~.operation.Operator`): A single operator or a quantum function that
Expand All @@ -36,6 +39,7 @@ def adjoint(fn, lazy=True):
Keyword Args:
lazy=True (bool): If the transform is behaving lazily, all operations are wrapped in a ``Adjoint`` class
and handled later. If ``lazy=False``, operation-specific adjoint decompositions are first attempted.
Setting ``lazy=False`` is not supported when used with :func:`~.qjit`.
Returns:
(function or :class:`~.operation.Operator`): If an Operator is provided, returns an Operator that is the adjoint.
Expand All @@ -44,7 +48,17 @@ def adjoint(fn, lazy=True):
.. note::
The adjoint and inverse are identical for unitary gates, but not in general. For example, quantum channels and observables may have different adjoint and inverse operators.
The adjoint and inverse are identical for unitary gates, but not in general. For example, quantum channels and
observables may have different adjoint and inverse operators.
.. note::
When used with :func:`~.qjit`, this function only supports the Catalyst compiler.
See :func:`catalyst.adjoint` for more details.
Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
page for an overview of the differences between Catalyst and PennyLane.
.. note::
Expand Down Expand Up @@ -101,6 +115,34 @@ def circuit(a):
>>> print(qml.draw(circuit)(0.2))
0: ──RX(0.20)──SX──SX†──RX(0.20)†─┤ <Z>
**Example with compiler**
The adjoint used in a compilation context can be applied on control flow.
.. code-block:: python
dev = qml.device("lightning.qubit", wires=1)
@qml.qjit
@qml.qnode(dev)
def workflow(theta, n, wires):
def func():
@qml.for_loop(0, n, 1)
def loop_fn(i):
qml.RX(theta, wires=wires)
loop_fn()
qml.adjoint(func)()
return qml.probs()
>>> workflow(jnp.pi/2, 3, 0)
[1.00000000e+00 7.39557099e-32]
.. warning::
The Catalyst adjoint function does not support performing the adjoint
of quantum functions that contain mid-circuit measurements.
.. details::
:title: Lazy Evaluation
Expand All @@ -117,6 +159,12 @@ def circuit(a):
Adjoint(S)(wires=[0])
"""
if active_jit := compiler.active_compiler():
if lazy is False:
raise CompileError("Setting lazy=False is not supported with qjit.")
available_eps = compiler.AvailableCompilers.names_entrypoints
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.adjoint(fn)
if isinstance(fn, Operator):
return Adjoint(fn) if lazy else _single_op_eager(fn, update_queue=True)
if not callable(fn):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=import-outside-toplevel
import pytest
import pennylane as qml
from pennylane.compiler.compiler import CompileError

from pennylane import numpy as np

Expand Down Expand Up @@ -209,6 +210,45 @@ def circuit(x: float):
result_header = "func.func private @circuit(%arg0: tensor<f64>) -> tensor<f64>"
assert result_header in mlir_str

def test_qjit_adjoint(self):
"""Test JIT compilation with adjoint support"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit
@qml.qnode(device=dev)
def workflow_cl(theta, wires):
def func():
qml.RX(theta, wires=wires)

qml.adjoint(func)()
return qml.probs()

@qml.qnode(device=dev)
def workflow_pl(theta, wires):
def func():
qml.RX(theta, wires=wires)

qml.adjoint(func)()
return qml.probs()

assert jnp.allclose(workflow_cl(0.1, [1]), workflow_pl(0.1, [1]))

def test_qjit_adjoint_lazy(self):
"""Test that Lazy kwarg is not supported."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit
@qml.qnode(device=dev)
def workflow(theta, wires):
def func():
qml.RX(theta, wires=wires)

qml.adjoint(func, lazy=False)()
return qml.probs()

with pytest.raises(CompileError, match="Setting lazy=False is not supported with qjit."):
workflow(0.1, [1])

def test_control(self):
"""Test that control works with qjit."""
dev = qml.device("lightning.qubit", wires=2)
Expand Down

0 comments on commit b1f68fb

Please sign in to comment.