Skip to content

Commit

Permalink
Create expand_plxpr_transforms function for unwrapping transforms n…
Browse files Browse the repository at this point in the history
…atively in plxpr (#6722)

[sc-80562]

**Note to reviewers:** Please don't be put off by the number of lines 😅
. Half of the new lines are just moving the `DecomposeInterpreter`,
`MapWiresInterpreter`, and `CancelInversesInterpreter` from the
`capture.transforms` to the `transforms` module.

**Context:**
This PR adds a function `qml.capture.expand_plxpr_transforms` that can
be used to actually apply transform primitives to plxpr. An example is
shown in the comment below of how it "unravels" transforms. This PR also
reorganized the transforms program capture. See before for details.

**Description of the Change:**
* Add a `qml.capture.expand_plxpr_transforms` function which returns a
new function that applies all transforms that are present as primitives
in the original function.
* Added an `ExpandTransformsInterpreter`. This interpreter does not do
anything special, but all PL transforms' primitives get automatically
registered so that the interpreter has a custom handler for evaluating
transform primitives. This custom implementation "unravels" transforms
as they are getting evaluated.
* Update the `plxpr_transform` argument of `qml.transform`. This
argument should now be a function with the following signature:
  ```python
def dummy_plxpr_transform(jaxpr: jax.core.Jaxpr, consts: list, targs:
list, tkwargs: dict, *args) -> jax.core.ClosedJaxpr:
      ...
  ```
By doing so, we can essentially generalize how plxpr transforms can be
implemented. We chose to use interpreters because of their ease of use,
but as long as the `plxpr_transform` function follows the "jaxpr to
jaxpr" signature, the actual implementation of the transform can be
whatever we want it to be 😄
* Reorganized code:
  * Remove `capture.transforms` submodule.
* Move all interpreters to the same files as their respective
transforms.
* These interpreters are placed inside a function so that we can work
around whether or not jax is installed. This is consistent with how we
handle the creation of certain classes/functions/primitives in the
`qml.capture` submodule.

**Benefits:**
Plxpr transforms are integrated with the existing transforms API

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Pietropaolo Frisoni <[email protected]>
  • Loading branch information
mudit2812 and PietropaoloFrisoni authored Dec 19, 2024
1 parent e1566e5 commit 5d11c5c
Show file tree
Hide file tree
Showing 17 changed files with 1,117 additions and 461 deletions.
53 changes: 53 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,59 @@ such as `shots`, `rng` and `prng_key`.

<h4>Capturing and representing hybrid programs</h4>

* `qml.transform` now accepts a `plxpr_transform` argument. This argument must be a function that can transform plxpr.
Note that executing a transformed function will currently raise a `NotImplementedError`. To see more details, check
out the documentation of `qml.transform`.
[(#6633)](https://github.com/PennyLaneAI/pennylane/pull/6633)
[(#6722)](https://github.com/PennyLaneAI/pennylane/pull/6722)

* Users can now apply transforms with program capture enabled. Transformed functions cannot be executed by default. To apply
the transforms (and be able to execute the function), it must be decorated with the new `qml.capture.expand_plxpr_transforms`
function, which accepts a callable as input and returns a new function to which all present transforms have been applied.
[(#6722)](https://github.com/PennyLaneAI/pennylane/pull/6722)

```python
from functools import partial
import jax

qml.capture.enable()
wire_map = {0: 3, 1: 6, 2: 9}

@partial(qml.map_wires, wire_map=wire_map)
def circuit(x, y):
qml.RX(x, 0)
qml.CNOT([0, 1])
qml.CRY(y, [1, 2])
return qml.expval(qml.Z(2))
```
```pycon
>>> qml.capture.make_plxpr(circuit)(1.2, 3.4)
{ lambda ; a:f32[] b:f32[]. let
c:AbstractMeasurement(n_wires=None) = _map_wires_transform_transform[
args_slice=slice(0, 2, None)
consts_slice=slice(2, 2, None)
inner_jaxpr={ lambda ; d:f32[] e:f32[]. let
_:AbstractOperator() = RX[n_wires=1] d 0
_:AbstractOperator() = CNOT[n_wires=2] 0 1
_:AbstractOperator() = CRY[n_wires=2] e 1 2
f:AbstractOperator() = PauliZ[n_wires=1] 2
g:AbstractMeasurement(n_wires=None) = expval_obs f
in (g,) }
targs_slice=slice(2, None, None)
tkwargs={'wire_map': {0: 3, 1: 6, 2: 9}, 'queue': False}
] a b
in (c,) }
>>> transformed_circuit = qml.capture.expand_plxpr_transforms(circuit)
>>> jax.make_jaxpr(transformed_circuit)(1.2, 3.4)
{ lambda ; a:f32[] b:f32[]. let
_:AbstractOperator() = RX[n_wires=1] a 3
_:AbstractOperator() = CNOT[n_wires=2] 3 6
_:AbstractOperator() = CRY[n_wires=2] b 6 9
c:AbstractOperator() = PauliZ[n_wires=1] 9
d:AbstractMeasurement(n_wires=None) = expval_obs c
in (d,) }
```

* The `qml.iterative_qpe` function can now be compactly captured into jaxpr.
[(#6680)](https://github.com/PennyLaneAI/pennylane/pull/6680)

Expand Down
10 changes: 10 additions & 0 deletions pennylane/capture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
~create_measurement_obs_primitive
~create_measurement_wires_primitive
~create_measurement_mcm_primitive
~expand_plxpr_transforms
~run_autograph
~make_plxpr
~PlxprInterpreter
Expand Down Expand Up @@ -157,6 +158,8 @@ class MyCustomOp(qml.operation.Operator):
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
from typing import Callable

from .switches import disable, enable, enabled
from .capture_meta import CaptureMeta, ABCCaptureMeta
from .capture_operators import create_operator_primitive
Expand All @@ -175,6 +178,7 @@ def _(*args, **kwargs):
AbstractMeasurement: type
qnode_prim: "jax.core.Primitive"
PlxprInterpreter: type # pylint: disable=redefined-outer-name
expand_plxpr_transforms: Callable[[Callable], Callable] # pylint: disable=redefined-outer-name


# pylint: disable=import-outside-toplevel, redefined-outer-name
Expand All @@ -199,6 +203,11 @@ def __getattr__(key):

return PlxprInterpreter

if key == "expand_plxpr_transforms":
from .expand_transforms import expand_plxpr_transforms

return expand_plxpr_transforms

raise AttributeError(f"module 'pennylane.capture' has no attribute '{key}'")


Expand All @@ -212,6 +221,7 @@ def __getattr__(key):
"create_measurement_obs_primitive",
"create_measurement_wires_primitive",
"create_measurement_mcm_primitive",
"expand_plxpr_transforms",
"AbstractOperator",
"AbstractMeasurement",
"qnode_prim",
Expand Down
103 changes: 103 additions & 0 deletions pennylane/capture/expand_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper function for expanding transforms with program capture
"""
from functools import wraps
from typing import Callable

from .base_interpreter import PlxprInterpreter


class ExpandTransformsInterpreter(PlxprInterpreter):
"""Interpreter for expanding transform primitives that are applied to plxpr.
This interpreter does not do anything special by itself. Instead, it is used
by the PennyLane transforms to expand transform primitives in plxpr by
applying the respective transform to the inner plxpr. When a transform is created
using :func:`~pennylane.transform`, a custom primitive interpretation rule for
that transform is automatically registered for ``ExpandTransformsInterpreter``.
"""


def expand_plxpr_transforms(f: Callable) -> Callable:
"""Function for applying transforms to plxpr.
Currently, when program capture is enabled, transforms are used as higher-order primitives.
These primitives are present in the program, but their respective transform is not applied
when a transformed function is called. ``expand_plxpr_transforms`` further "transforms" the
input function to apply any transform primitives that are present in the program being run.
**Example**
In the below example, we can see that the ``qml.transforms.cancel_inverses`` transform has been
applied to a function. However, the resulting program representation leaves the
``cancel_inverses`` transform as a primitive without actually transforming the program.
.. code-block:: python
qml.capture.enable()
@qml.transforms.cancel_inverses
def circuit():
qml.X(0)
qml.S(1)
qml.X(0)
qml.adjoint(qml.S(1))
return qml.expval(qml.Z(1))
>>> qml.capture.make_plxpr(circuit)()
{ lambda ; . let
a:AbstractMeasurement(n_wires=None) = cancel_inverses_transform[
args_slice=slice(0, 0, None)
consts_slice=slice(0, 0, None)
inner_jaxpr={ lambda ; . let
_:AbstractOperator() = PauliX[n_wires=1] 0
_:AbstractOperator() = S[n_wires=1] 1
_:AbstractOperator() = PauliX[n_wires=1] 0
b:AbstractOperator() = S[n_wires=1] 1
_:AbstractOperator() = Adjoint b
c:AbstractOperator() = PauliZ[n_wires=1] 1
d:AbstractMeasurement(n_wires=None) = expval_obs c
in (d,) }
targs_slice=slice(0, None, None)
tkwargs={}
]
in (a,) }
To apply the transform, we can use ``expand_plxpr_transforms`` as follows:
>>> transformed_circuit = qml.capture.expand_plxpr_transforms(circuit)
>>> qml.capture.make_plxpr(transformed_circuit)()
{ lambda ; . let
a:AbstractOperator() = PauliZ[n_wires=1] 1
b:AbstractMeasurement(n_wires=None) = expval_obs a
in (b,) }
As seen, the transform primitive is no longer present, but it has been applied
to the original program, indicated by the inverse operators being cancelled.
Args:
f (Callable): The callable to which any present transforms should be applied.
Returns:
Callable: Callable with transforms applied.
"""

@wraps(f)
def wrapper(*args, **kwargs):
transformed_f = ExpandTransformsInterpreter()(f)
return transformed_f(*args, **kwargs)

return wrapper
25 changes: 0 additions & 25 deletions pennylane/capture/transforms/__init__.py

This file was deleted.

Loading

0 comments on commit 5d11c5c

Please sign in to comment.