Skip to content

Commit

Permalink
Add construct_batch and transform_program functions to workflow m…
Browse files Browse the repository at this point in the history
…odule (#5084)

See #5058  for a prototype of the end state.

This PR adds two functions `qml.workflow.construct_batch` and
`qml.workflow.transform_program`.

`transform_program(qnode, level)` takes a qnode and a "level" and
returns a transform program. The level keyword argument will be used
through `construct_batch`, `draw`, `draw_mpl`, and `specs` and indicates
a selection from the full transform program.

`construct_batch(qnode, level)(*args, **kwargs)` takes a qnode and a
level and returns a callable with the same signature as the qnode. It
then applies the transform program corresponding to `level`.


Additional minor changes and helpers:

* `qml.transforms.core.expand_fn_transform`: In order to place
`qml.Device.expand_fn` into the transform program easily, I added a
quick function to convert from a tape->tape function to a transform.

* `TransformContainer.__repr__`: This just makes my life easier when
working with transforms.

* Default the `is_informative` keyword argument to `False` instead of
`None` when creating a transform. This way a transform created with
`TransformContainer` has the same default value of `is_informative`.
This was causing me headaches when comparing transforms in testing.

* Slicing into a `TransformProgram` with a slice object returns another
`TransformProgram` instead of a list. Ex `prog[0:4]`

* Defined `__contains__` for `TransformProgram`. This just made my life
easier when playing around and testing things out. example: `qml.compile
in my_program`.

* Added some documentation on defined dunder methods in
`TransformProgram`

---------

Co-authored-by: Thomas R. Bromley <[email protected]>
Co-authored-by: Matthew Silverman <[email protected]>
  • Loading branch information
3 people authored Jan 31, 2024
1 parent f51e695 commit b5f302c
Show file tree
Hide file tree
Showing 18 changed files with 1,245 additions and 132 deletions.
Binary file added doc/_static/transforms_order.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
251 changes: 251 additions & 0 deletions doc/_static/transforms_order.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@
* Raise a more informative error when calling `adjoint_jacobian` with trainable state-prep operations.
[(#5026)](https://github.com/PennyLaneAI/pennylane/pull/5026)

* Adds `qml.workflow.get_transform_program` and `qml.workflow.construct_batch` to inspect the transform program and batch of tapes
at different stages.
[(#5084)](https://github.com/PennyLaneAI/pennylane/pull/5084)

* `CRX`, `CRY`, `CRZ`, `CROT`, and `ControlledPhaseShift` (i.e. `CPhaseShift`) now inherit from `ControlledOp`, giving them additional properties such as `control_wire` and `control_values`. Calling `qml.ctrl` on `RX`, `RY`, `RZ`, `Rot`, and `PhaseShift` with a single control wire will return gates of types `CRX`, `CRY`, etc. as opposed to a general `Controlled` operator.
[(#5069)](https://github.com/PennyLaneAI/pennylane/pull/5069)

Expand Down
3 changes: 1 addition & 2 deletions pennylane/_qubit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import pennylane as qml
from pennylane import Device, DeviceError
from pennylane.workflow import set_shots
from pennylane.math import multiply as qmlmul
from pennylane.math import sum as qmlsum
from pennylane.measurements import (
Expand Down Expand Up @@ -1036,7 +1035,7 @@ def classical_shadow(self, obs, circuit):
n_snapshots = self.shots
seed = obs.seed

with set_shots(self, shots=1):
with qml.workflow.set_shots(self, shots=1):
# slow implementation but works for all devices
n_qubits = len(wires)
mapped_wires = np.array(self.map_wires(wires))
Expand Down
18 changes: 10 additions & 8 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ def adjoint_state_measurements(
)


def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
return op.num_params == 0 or (op.num_params == 1 and op.has_generator)


def adjoint_observables(obs: qml.operation.Operator) -> bool:
"""Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit."""
return obs.has_matrix


def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None:
"""Private helper function for ``preprocess`` that adds the transforms specific
for adjoint differentiation.
Expand All @@ -171,14 +181,6 @@ def _add_adjoint_transforms(program: TransformProgram, device_vjp=False) -> None
"""

def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
return op.num_params == 0 or op.num_params == 1 and op.has_generator

def adjoint_observables(obs: qml.operation.Operator) -> bool:
"""Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit."""
return obs.has_matrix

name = "adjoint + default.qubit"
program.add_transform(no_sampling, name=name)
program.add_transform(
Expand Down
3 changes: 3 additions & 0 deletions pennylane/templates/subroutines/permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ def circuit():
"""

def __repr__(self):
return f"Permute({self.hyperparameters['permutation']}, wires={self.wires.tolist()})"

num_wires = AnyWires
grad_method = None

Expand Down
19 changes: 9 additions & 10 deletions pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def transform(
quantum_transform,
expand_transform=None,
classical_cotransform=None,
is_informative=None,
is_informative=False,
final_transform=False,
):
"""Generalizes a function that transforms tapes to work with additional circuit-like objects such as a
Expand All @@ -45,14 +45,15 @@ def transform(
* The transform must have the following structure (type hinting is optional): ``my_quantum_transform(tape:
qml.tape.QuantumTape, ...) -> ( Sequence[qml.tape.QuantumTape], Callable)``
expand_transform (Callable): An optional expand transform is applied directly before the input
Keyword Args:
expand_transform=None (Optional[Callable]): An optional expand transform is applied directly before the input
quantum transform. It must be a function that satisfies the same requirements as
``quantum_transform``.
classical_cotransform (Callable): A classical co-transform is a function to post-process the classical
classical_cotransform=None (Optional[Callable]): A classical co-transform is a function to post-process the classical
jacobian and the quantum jacobian and has the signature: ``my_cotransform(qjac, cjac, tape) -> tensor_like``
is_informative (bool): Whether or not a transform is informative. If true the transform is queued at the end
is_informative=False (bool): Whether or not a transform is informative. If true the transform is queued at the end
of the transform program and the tapes or qnode aren't executed.
final_transform (bool): Whether or not the transform is terminal. If true the transform is queued at the end
final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end
of the transform program. ``is_informative`` supersedes ``final_transform``.
Returns:
Expand Down Expand Up @@ -177,15 +178,13 @@ def qnode_circuit(a):
)

# 3: CHeck the classical co-transform
if classical_cotransform is not None:
if not callable(classical_cotransform):
raise TransformError("The classical co-transform must be a valid Python function.")
if classical_cotransform is not None and not callable(classical_cotransform):
raise TransformError("The classical co-transform must be a valid Python function.")

dispatcher = TransformDispatcher(
return TransformDispatcher(
quantum_transform,
expand_transform=expand_transform,
classical_cotransform=classical_cotransform,
is_informative=is_informative,
final_transform=final_transform,
)
return dispatcher
3 changes: 3 additions & 0 deletions pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ def __init__(
self._is_informative = is_informative
self._final_transform = is_informative or final_transform

def __repr__(self):
return f"<{self._transform.__name__}({self._args}, {self._kwargs})>"

def __iter__(self):
return iter(
(
Expand Down
Loading

0 comments on commit b5f302c

Please sign in to comment.