-
Notifications
You must be signed in to change notification settings - Fork 616
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create
expand_plxpr_transforms
function for unwrapping transforms n…
…atively in plxpr (#6722) [sc-80562] **Note to reviewers:** Please don't be put off by the number of lines 😅 . Half of the new lines are just moving the `DecomposeInterpreter`, `MapWiresInterpreter`, and `CancelInversesInterpreter` from the `capture.transforms` to the `transforms` module. **Context:** This PR adds a function `qml.capture.expand_plxpr_transforms` that can be used to actually apply transform primitives to plxpr. An example is shown in the comment below of how it "unravels" transforms. This PR also reorganized the transforms program capture. See before for details. **Description of the Change:** * Add a `qml.capture.expand_plxpr_transforms` function which returns a new function that applies all transforms that are present as primitives in the original function. * Added an `ExpandTransformsInterpreter`. This interpreter does not do anything special, but all PL transforms' primitives get automatically registered so that the interpreter has a custom handler for evaluating transform primitives. This custom implementation "unravels" transforms as they are getting evaluated. * Update the `plxpr_transform` argument of `qml.transform`. This argument should now be a function with the following signature: ```python def dummy_plxpr_transform(jaxpr: jax.core.Jaxpr, consts: list, targs: list, tkwargs: dict, *args) -> jax.core.ClosedJaxpr: ... ``` By doing so, we can essentially generalize how plxpr transforms can be implemented. We chose to use interpreters because of their ease of use, but as long as the `plxpr_transform` function follows the "jaxpr to jaxpr" signature, the actual implementation of the transform can be whatever we want it to be 😄 * Reorganized code: * Remove `capture.transforms` submodule. * Move all interpreters to the same files as their respective transforms. * These interpreters are placed inside a function so that we can work around whether or not jax is installed. This is consistent with how we handle the creation of certain classes/functions/primitives in the `qml.capture` submodule. **Benefits:** Plxpr transforms are integrated with the existing transforms API **Possible Drawbacks:** **Related GitHub Issues:** --------- Co-authored-by: Pietropaolo Frisoni <[email protected]>
- Loading branch information
1 parent
e1566e5
commit 5d11c5c
Showing
17 changed files
with
1,117 additions
and
461 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright 2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Helper function for expanding transforms with program capture | ||
""" | ||
from functools import wraps | ||
from typing import Callable | ||
|
||
from .base_interpreter import PlxprInterpreter | ||
|
||
|
||
class ExpandTransformsInterpreter(PlxprInterpreter): | ||
"""Interpreter for expanding transform primitives that are applied to plxpr. | ||
This interpreter does not do anything special by itself. Instead, it is used | ||
by the PennyLane transforms to expand transform primitives in plxpr by | ||
applying the respective transform to the inner plxpr. When a transform is created | ||
using :func:`~pennylane.transform`, a custom primitive interpretation rule for | ||
that transform is automatically registered for ``ExpandTransformsInterpreter``. | ||
""" | ||
|
||
|
||
def expand_plxpr_transforms(f: Callable) -> Callable: | ||
"""Function for applying transforms to plxpr. | ||
Currently, when program capture is enabled, transforms are used as higher-order primitives. | ||
These primitives are present in the program, but their respective transform is not applied | ||
when a transformed function is called. ``expand_plxpr_transforms`` further "transforms" the | ||
input function to apply any transform primitives that are present in the program being run. | ||
**Example** | ||
In the below example, we can see that the ``qml.transforms.cancel_inverses`` transform has been | ||
applied to a function. However, the resulting program representation leaves the | ||
``cancel_inverses`` transform as a primitive without actually transforming the program. | ||
.. code-block:: python | ||
qml.capture.enable() | ||
@qml.transforms.cancel_inverses | ||
def circuit(): | ||
qml.X(0) | ||
qml.S(1) | ||
qml.X(0) | ||
qml.adjoint(qml.S(1)) | ||
return qml.expval(qml.Z(1)) | ||
>>> qml.capture.make_plxpr(circuit)() | ||
{ lambda ; . let | ||
a:AbstractMeasurement(n_wires=None) = cancel_inverses_transform[ | ||
args_slice=slice(0, 0, None) | ||
consts_slice=slice(0, 0, None) | ||
inner_jaxpr={ lambda ; . let | ||
_:AbstractOperator() = PauliX[n_wires=1] 0 | ||
_:AbstractOperator() = S[n_wires=1] 1 | ||
_:AbstractOperator() = PauliX[n_wires=1] 0 | ||
b:AbstractOperator() = S[n_wires=1] 1 | ||
_:AbstractOperator() = Adjoint b | ||
c:AbstractOperator() = PauliZ[n_wires=1] 1 | ||
d:AbstractMeasurement(n_wires=None) = expval_obs c | ||
in (d,) } | ||
targs_slice=slice(0, None, None) | ||
tkwargs={} | ||
] | ||
in (a,) } | ||
To apply the transform, we can use ``expand_plxpr_transforms`` as follows: | ||
>>> transformed_circuit = qml.capture.expand_plxpr_transforms(circuit) | ||
>>> qml.capture.make_plxpr(transformed_circuit)() | ||
{ lambda ; . let | ||
a:AbstractOperator() = PauliZ[n_wires=1] 1 | ||
b:AbstractMeasurement(n_wires=None) = expval_obs a | ||
in (b,) } | ||
As seen, the transform primitive is no longer present, but it has been applied | ||
to the original program, indicated by the inverse operators being cancelled. | ||
Args: | ||
f (Callable): The callable to which any present transforms should be applied. | ||
Returns: | ||
Callable: Callable with transforms applied. | ||
""" | ||
|
||
@wraps(f) | ||
def wrapper(*args, **kwargs): | ||
transformed_f = ExpandTransformsInterpreter()(f) | ||
return transformed_f(*args, **kwargs) | ||
|
||
return wrapper |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.