Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructures preprocess to include building block, extensible transforms #4659

Merged
merged 25 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0a17255
refactor preprocess transforms
albi3ro Oct 6, 2023
9ba52eb
Apply suggestions from code review
albi3ro Oct 6, 2023
57ce47b
cleaning and tests
albi3ro Oct 10, 2023
18ffd25
lots of moving tests around and fixing them
albi3ro Oct 10, 2023
bf4c569
try and fix conflict
albi3ro Oct 10, 2023
9026cf6
fix commit problem
albi3ro Oct 10, 2023
c07a926
Merge branch 'master' into cleaning-preproces
albi3ro Oct 10, 2023
bfa0c2c
fix defer measurements
albi3ro Oct 10, 2023
5feaf6e
fixing tests
albi3ro Oct 11, 2023
562739e
fix adjoint metric tensor
albi3ro Oct 11, 2023
7c6a891
move static method out of class
albi3ro Oct 11, 2023
aad6b74
fix tests and documentation
albi3ro Oct 11, 2023
fbd4662
more test fixes
albi3ro Oct 11, 2023
c861001
tests and changelog
albi3ro Oct 11, 2023
e85e94d
Merge branch 'master' into cleaning-preproces
albi3ro Oct 11, 2023
b18c48c
fix supports_derivatives
albi3ro Oct 11, 2023
29d6ee6
Update pennylane/devices/preprocess.py
albi3ro Oct 12, 2023
9ebe948
improved docstrings
albi3ro Oct 12, 2023
6780f81
Merge branch 'cleaning-preproces' of https://github.com/PennyLaneAI/p…
albi3ro Oct 12, 2023
f032726
Apply suggestions from code review
albi3ro Oct 13, 2023
6bb5cc6
Update tests/devices/default_qubit/test_default_qubit_preprocessing.py
albi3ro Oct 13, 2023
22c6756
Merge branch 'master' into cleaning-preproces
albi3ro Oct 13, 2023
93692cc
Merge branch 'master' into cleaning-preproces
vincentmr Oct 16, 2023
4182387
Test python 3.9
vincentmr Oct 16, 2023
75d9699
Revert CI constraints.
vincentmr Oct 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@

<h3>Improvements 🛠</h3>

* `pennylane.devices.preprocess` now offers the transforms `decompose`, `validate_observables`, `validate_measurements`,
`validate_device_wires`, `validate_multiprocessing_workers`, `warn_about_trainable_observables`,
and `no_sampling` to assist in the construction of devices under the new `devices.Device` API.
[(#4659)](https://github.com/PennyLaneAI/pennylane/pull/4659)

* `pennylane.defer_measurements` will now exit early if the input does not contain mid circuit measurements.
[(#4659)](https://github.com/PennyLaneAI/pennylane/pull/4659)

* `default.qubit` now tracks the number of equivalent qpu executions and total shots
when the device is sampling. Note that `"simulations"` denotes the number of simulation passes, where as
`"executions"` denotes how many different computational bases need to be sampled in. Additionally, the
Expand Down
30 changes: 30 additions & 0 deletions pennylane/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@
Device
DefaultQubit

Preprocessing Transforms
------------------------

The ``preprocess`` module offers several transforms that can be used in constructing the :meth:`~.devices.Device.preprocess`
method for devices.

.. currentmodule:: pennylane.devices.preprocess
.. autosummary::
:toctree: api

decompose
validate_observables
validate_measurements
validate_device_wires
validate_multiprocessing_workers
warn_about_trainable_observables
no_sampling

Other transforms that may be relevant to device preprocessing include:

.. currentmodule:: pennylane
.. autosummary::
:toctree: api

defer_measurements
transforms.broadcast_expand
transforms.sum_expand
transforms.split_non_commuting
transforms.hamiltonian_expand

Qubit Simulation Tools
----------------------

Expand Down
180 changes: 164 additions & 16 deletions pennylane/devices/default_qubit.py
trbromley marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the next generation successor to default qubit
"""

from dataclasses import replace
from functools import partial
from numbers import Number
from typing import Union, Callable, Tuple, Optional, Sequence
Expand All @@ -28,15 +29,18 @@
from pennylane.transforms.core import TransformProgram

from . import Device
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from .qubit.simulate import simulate, get_final_state, measure_final_state
from .qubit.sampling import get_num_shots_and_executions
from .qubit.preprocess import (
preprocess,
validate_and_expand_adjoint,
from .preprocess import (
decompose,
validate_observables,
validate_measurements,
validate_multiprocessing_workers,
validate_device_wires,
warn_about_trainable_observables,
no_sampling,
)
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from .qubit.simulate import simulate, get_final_state, measure_final_state
from .qubit.sampling import get_num_shots_and_executions
from .qubit.adjoint_jacobian import adjoint_jacobian, adjoint_vjp, adjoint_jvp

Result_or_ResultBatch = Union[Result, ResultBatch]
Expand All @@ -46,6 +50,95 @@
PostprocessingFn = Callable[[ResultBatch], Result_or_ResultBatch]


observables = {
"PauliX",
"PauliY",
"PauliZ",
"Hadamard",
"Hermitian",
"Identity",
"Projector",
"SparseHamiltonian",
"Hamiltonian",
"Sum",
"SProd",
"Prod",
"Exp",
"Evolution",
}


def observable_stopping_condition(obs: qml.operation.Operator) -> bool:
"""Specifies whether or not an observable is accepted by DefaultQubit."""
return obs.name in observables


def stopping_condition(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator object is supported by the device."""
if op.name == "QFT" and len(op.wires) >= 6:
return False
if op.name == "GroverOperator" and len(op.wires) >= 13:
return False
if op.name == "Snapshot":
return True
if op.__class__.__name__ == "Pow" and qml.operation.is_trainable(op):
return False

return op.has_matrix
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved


def accepted_sample_measurement(m: qml.measurements.MeasurementProcess) -> bool:
"""Specifies whether or not a measurement is accepted when sampling."""
return isinstance(
m,
(
qml.measurements.SampleMeasurement,
qml.measurements.ClassicalShadowMP,
qml.measurements.ShadowExpvalMP,
),
)


def _add_adjoint_transforms(program: TransformProgram) -> None:
"""Private helper function for ``preprocess`` that adds the transforms specific
for adjoint differentiation.

Args:
program (TransformProgram): where we will add the adjoint differentiation transforms

Side Effects:
Adds transforms to the input program.

"""

def adjoint_ops(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator is supported by adjoint differentiation."""
return op.num_params == 0 or op.num_params == 1 and op.has_generator

def adjoint_observables(obs: qml.operation.Operator) -> bool:
"""Specifies whether or not an observable is compatible with adjoint differentiation on DefaultQubit."""
return obs.has_matrix

def accepted_adjoint_measurement(m: qml.measurements.MeasurementProcess) -> bool:
return isinstance(m, qml.measurements.ExpectationMP)

name = "adjoint + default.qubit"
program.add_transform(no_sampling, name=name)
program.add_transform(
decompose,
stopping_condition=adjoint_ops,
name=name,
)
program.add_transform(validate_observables, adjoint_observables, name=name)
program.add_transform(
validate_measurements,
analytic_measurements=accepted_adjoint_measurement,
name=name,
)
program.add_transform(qml.transforms.broadcast_expand)
program.add_transform(warn_about_trainable_observables)


class DefaultQubit(Device):
"""A PennyLane device written in Python and capable of backpropagation derivatives.

Expand Down Expand Up @@ -246,18 +339,28 @@ def supports_derivatives(
):
return True

if execution_config.gradient_method == "adjoint" and execution_config.use_device_gradient:
if (
execution_config.gradient_method == "adjoint"
and execution_config.use_device_gradient is not False
AmintorDusko marked this conversation as resolved.
Show resolved Hide resolved
):
if circuit is None:
return True

return isinstance(validate_and_expand_adjoint(circuit)[0][0], QuantumScript)
prog = TransformProgram()
_add_adjoint_transforms(prog)

try:
prog((circuit,))
except (qml.operation.DecompositionUndefinedError, qml.DeviceError):
return False
return True

return False

def preprocess(
self,
execution_config: ExecutionConfig = DefaultExecutionConfig,
) -> Tuple[QuantumTapeBatch, PostprocessingFn, ExecutionConfig]:
) -> Tuple[TransformProgram, ExecutionConfig]:
"""This function defines the device transform program to be applied and an updated device configuration.

Args:
Expand All @@ -276,19 +379,64 @@ def preprocess(
* Currently does not intrinsically support parameter broadcasting

"""
config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()
# Validate device wires
transform_program.add_transform(validate_device_wires, self)

transform_program.add_transform(qml.defer_measurements)
transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
decompose, stopping_condition=stopping_condition, name=self.name
)
transform_program.add_transform(
validate_measurements, sample_measurements=accepted_sample_measurement, name=self.name
)
transform_program.add_transform(
validate_observables, stopping_condition=observable_stopping_condition, name=self.name
)

# Validate multi processing
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
transform_program.add_transform(validate_multiprocessing_workers, max_workers, self)
max_workers = config.device_options.get("max_workers", self._max_workers)
if max_workers:
transform_program.add_transform(validate_multiprocessing_workers, max_workers, self)

if config.gradient_method == "backprop":
transform_program.add_transform(no_sampling, name="backprop + default.qubit")

if config.gradient_method == "adjoint":
_add_adjoint_transforms(transform_program)

# General preprocessing (Validate measurement, expand, adjoint expand, broadcast expand)
transform_program_preprocess, config = preprocess(execution_config=execution_config)
transform_program = transform_program + transform_program_preprocess
return transform_program, config

def _setup_execution_config(self, execution_config: ExecutionConfig) -> ExecutionConfig:
"""This is a private helper for ``preprocess`` that sets up the execution config.

Args:
execution_config (ExecutionConfig)

Returns:
ExecutionConfig: a preprocessed execution config

"""
updated_values = {}
if execution_config.gradient_method == "best":
updated_values["gradient_method"] = "backprop"
if execution_config.use_device_gradient is None:
updated_values["use_device_gradient"] = execution_config.gradient_method in {
"best",
"adjoint",
"backprop",
}
if execution_config.grad_on_execution is None:
updated_values["grad_on_execution"] = execution_config.gradient_method == "adjoint"
updated_values["device_options"] = dict(execution_config.device_options) # copy
if "max_workers" not in updated_values["device_options"]:
updated_values["device_options"]["max_workers"] = self._max_workers
if "rng" not in updated_values["device_options"]:
updated_values["device_options"]["rng"] = self._rng
if "prng_key" not in updated_values["device_options"]:
updated_values["device_options"]["prng_key"] = self._prng_key
return replace(execution_config, **updated_values)

def execute(
self,
circuits: QuantumTape_or_Batch,
Expand Down
Loading
Loading