Skip to content

Commit

Permalink
Default preprocessing transforms handles diagonalization of measureme…
Browse files Browse the repository at this point in the history
…nts (#6653)

The original intention was to include it in
#6632 but we decided to
exclude it for now due to the following complications (and potential
solutions):

- The diagonalize_measurements transform assumes that
split_non_commuting has been applied (wouldn't make sense otherwise).
What if the device requires diagonalization of some observables but
supports commuting measurements?
- We consider this enough of an edge case that a device like that should
just implement its own preprocessing transform program.
- The diagonalize_measurements transform raises an error when it sees
anything that is not one of PauliX, PauliY, PauliZ, Hadamard, or a
linear combination of the four. For example, if a device supports
Hermitian but not Hadamard, it would expect that Hermitian is allowed
but Hadamard is diagonalized.
- We might consider changing diagonalize_measurements to simply leave
unrecognized observables as is instead of raising an error. The
unsupported observables that are not diagonalized will remain in the
circuit past the diagonalize_measurements transform but caught later in
the transform program by validate_observables.
- Diagonalization produces additional gates that the device may not
support.
- decompose should be applied after diagonalization. Hopefully with
restructured decompositions, diagonalizing gates can always be mapped to
the device native gate set.

[sc-79422]

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
astralcai and albi3ro authored Dec 5, 2024
1 parent 349dc75 commit 646ea39
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 42 deletions.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
are added to the device API for devices that provides a TOML configuration file and thus have
a `capabilities` property.
[(#6632)](https://github.com/PennyLaneAI/pennylane/pull/6632)
[(#6653)](https://github.com/PennyLaneAI/pennylane/pull/6653)

<h4>New `labs` module `dla` for handling dynamical Lie algebras (DLAs)</h4>

Expand Down Expand Up @@ -154,6 +155,11 @@ featuring a `simulate` function for simulating mixed states in analytic mode.
* Added functions and dunder methods to add and multiply Resources objects in series and in parallel.
[(#6567)](https://github.com/PennyLaneAI/pennylane/pull/6567)

* The `diagonalize_measurements` transform no longer raises an error for unknown observables. Instead,
they are left undiagonalized, with the expectation that observable validation will catch any undiagonalized
observables that are also unsupported by the device.
[(#6653)](https://github.com/PennyLaneAI/pennylane/pull/6653)

<h4>Capturing and representing hybrid programs</h4>

* PennyLane transforms can now be captured as primitives with experimental program capture enabled.
Expand Down
42 changes: 34 additions & 8 deletions pennylane/devices/device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,12 +514,30 @@ def execute_fn(tapes):
capabilities_analytic = self.capabilities.filter(finite_shots=False)
capabilities_shots = self.capabilities.filter(finite_shots=True)

program.add_transform(
decompose,
stopping_condition=capabilities_analytic.supports_operation,
stopping_condition_shots=capabilities_shots.supports_operation,
name=self.name,
)
needs_diagonalization = False
base_obs = {"PauliZ": qml.Z, "PauliX": qml.X, "PauliY": qml.Y, "Hadamard": qml.H}
if (
not all(obs in self.capabilities.observables for obs in base_obs)
# This check is to confirm that `split_non_commuting` has been applied, since
# `diagonalize_measurements` does not work with non-commuting measurements. If
# a device is flexible enough to support non-commuting observables but for some
# reason does not support all of `PauliZ`, `PauliX`, `PauliY`, and `Hadamard`,
# we consider it enough of an edge case that the device should just implement
# its own preprocessing transform.
and not self.capabilities.non_commuting_observables
):
needs_diagonalization = True
else:
# If the circuit does not need diagonalization, we decompose the circuit before
# potentially applying `split_non_commuting` that produces multiple tapes with
# duplicated operations. Otherwise, `decompose` has to be applied last because
# `diagonalize_measurements` may add additional gates that are not supported.
program.add_transform(
decompose,
stopping_condition=capabilities_analytic.supports_operation,
stopping_condition_shots=capabilities_shots.supports_operation,
name=self.name,
)

if not self.capabilities.overlapping_observables:
program.add_transform(qml.transforms.split_non_commuting, grouping_strategy="wires")
Expand All @@ -528,8 +546,16 @@ def execute_fn(tapes):
elif not self.capabilities.supports_observable("Sum"):
program.add_transform(qml.transforms.split_to_single_terms)

# TODO: diagonalization should be part of the default transform program, but we decided
# not to include it in this PR due to complications. See sc-79422
if needs_diagonalization:
obs_names = base_obs.keys() & self.capabilities.observables.keys()
obs = {base_obs[obs] for obs in obs_names}
program.add_transform(qml.transforms.diagonalize_measurements, supported_base_obs=obs)
program.add_transform(
decompose,
stopping_condition=lambda o: capabilities_analytic.supports_operation(o.name),
stopping_condition_shots=lambda o: capabilities_shots.supports_operation(o.name),
name=self.name,
)

program.add_transform(qml.transforms.broadcast_expand)

Expand Down
70 changes: 44 additions & 26 deletions pennylane/transforms/diagonalize_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from pennylane.transforms.core import transform

# pylint: disable=protected-access
# pylint: disable=protected-access,unused-argument

_default_supported_obs = (qml.Z, qml.Identity)

Expand Down Expand Up @@ -56,10 +56,17 @@ def diagonalize_measurements(tape, supported_base_obs=_default_supported_obs, to
qnode (QNode) or tuple[List[QuantumScript], function]: The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
.. note::
This transform will raise an error if it encounters non-commuting terms. To avoid non-commuting terms in
circuit measurements, the :func:`split_non_commuting <pennylane.transforms.split_non_commuting>` transform
can be applied.
An error will be raised if non-commuting terms are encountered. To avoid non-commuting
terms in circuit measurements, the :func:`split_non_commuting <pennylane.transforms.split_non_commuting>`
transform can be applied.
This transform will diagonalize what it can, i.e., ``qml.X``, ``qml.Y``, ``qml.Z``,
``qml.Hadamard``, ``qml.Identity``, or a linear combination of them. Any unrecognized
observable will not raise an error, deferring to the device's validation for supported
measurements later on. Lastly, if ``diagonalize_measurements`` produces additional gates
that the device does not support, the :func:`~pennylane.devices.preprocess.decompose`
transform should be applied to ensure that the additional gates are decomposed to those
that the device supports.
**Examples:**
Expand Down Expand Up @@ -100,8 +107,21 @@ def circuit(x):
.. details::
:title: Usage Details
The transform diagonalizes observables from the local Pauli basis only, i.e. it diagonalizes X, Y, Z,
and Hadamard.
The transform diagonalizes observables from the local Pauli basis only, i.e. it diagonalizes
X, Y, Z, and Hadamard. Any other observable will be unaffected:
.. code-block:: python3
measurements = [
qml.expval(qml.X(0) + qml.Hermitian([[1, 0], [0, 1]], wires=[1]))
]
tape = qml.tape.QuantumScript(measurements=measurements)
tapes, processsing_fn = diagnalize_measurements(tape)
>>> tapes[0].operations
[H(0)]
>>> tapes[0].measurements
[expval(Z(0) + Hermitian(array([[1, 0], [0, 1]]), wires=[1]))]
The transform can also diagonalize only a subset of these operators. By default, the only
supported base observable is Z. What if a backend device can handle
Expand Down Expand Up @@ -280,9 +300,7 @@ def _change_obs_to_Z(observable):

@_change_obs_to_Z.register
def _change_symbolic_op(observable: SymbolicOp):
diagonalizing_gates, [new_base] = diagonalize_qwc_pauli_words(
[observable.base],
)
diagonalizing_gates, [new_base] = diagonalize_qwc_pauli_words([observable.base])

params, hyperparams = observable.parameters, observable.hyperparameters
hyperparams = copy(hyperparams)
Expand All @@ -297,9 +315,7 @@ def _change_symbolic_op(observable: SymbolicOp):
def _change_linear_combination(observable: LinearCombination):
coeffs, obs = observable.terms()

diagonalizing_gates, new_operands = diagonalize_qwc_pauli_words(
obs,
)
diagonalizing_gates, new_operands = diagonalize_qwc_pauli_words(obs)

new_observable = LinearCombination(coeffs, new_operands)

Expand All @@ -308,9 +324,7 @@ def _change_linear_combination(observable: LinearCombination):

@_change_obs_to_Z.register
def _change_composite_op(observable: CompositeOp):
diagonalizing_gates, new_operands = diagonalize_qwc_pauli_words(
observable.operands,
)
diagonalizing_gates, new_operands = diagonalize_qwc_pauli_words(observable.operands)

new_observable = observable.__class__(*new_operands)

Expand Down Expand Up @@ -385,7 +399,7 @@ def _diagonalize_observable(
_visited_obs = (set(), set())

if not isinstance(observable, (qml.X, qml.Y, qml.Z, qml.Hadamard, qml.Identity)):
return _diagonalize_compound_observable(
return _diagonalize_non_basic_observable(
observable, _visited_obs, supported_base_obs=supported_base_obs
)

Expand Down Expand Up @@ -419,19 +433,23 @@ def _get_obs_and_gates(obs_list, _visited_obs, supported_base_obs=_default_suppo


@singledispatch
def _diagonalize_compound_observable(
def _diagonalize_non_basic_observable(
observable, _visited_obs, supported_base_obs=_default_supported_obs
):
"""Takes an observable consisting of multiple other observables, and changes all
"""Takes an observable other than X, Y, Z, H, and I, and diagonalize it.
For composite observables consisting of multiple other observables, it changes all
unsupported obs to the measurement basis. Applies diagonalizing gates if changing
the basis of an observable whose diagonalizing gates have not already been applied."""
the basis of an observable whose diagonalizing gates have not already been applied.
For other observables, simply skips and returns the observable as is.
raise NotImplementedError(
f"Unable to convert observable of type {type(observable)} to the measurement basis"
)
"""
_visited_obs[0].add(observable)
_visited_obs[1].add(observable.wires[0])
return [], observable, _visited_obs


@_diagonalize_compound_observable.register
@_diagonalize_non_basic_observable.register
def _diagonalize_symbolic_op(
observable: SymbolicOp, _visited_obs, supported_base_obs=_default_supported_obs
):
Expand All @@ -448,7 +466,7 @@ def _diagonalize_symbolic_op(
return diagonalizing_gates, new_observable, _visited_obs


@_diagonalize_compound_observable.register
@_diagonalize_non_basic_observable.register
def _diagonalize_linear_combination(
observable: LinearCombination, _visited_obs, supported_base_obs=_default_supported_obs
):
Expand All @@ -464,7 +482,7 @@ def _diagonalize_linear_combination(
return diagonalizing_gates, new_observable, _visited_obs


@_diagonalize_compound_observable.register
@_diagonalize_non_basic_observable.register
def _diagonalize_composite_op(
observable: CompositeOp, _visited_obs, supported_base_obs=_default_supported_obs
):
Expand Down
58 changes: 56 additions & 2 deletions tests/devices/test_device_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,10 @@ def execute(
with pytest.raises(qml.DeviceError, match=r"Measurement var\(Z\(0\)\) not accepted"):
_, __ = program((invalid_tape,))

invalid_tape = QuantumScript([], [qml.expval(qml.PauliX(0))], shots=shots)
with pytest.raises(qml.DeviceError, match=r"Observable X\(0\) not supported"):
invalid_tape = QuantumScript(
[], [qml.expval(qml.Hermitian([[1.0, 0], [0, 1.0]], 0))], shots=shots
)
with pytest.raises(qml.DeviceError, match=r"Observable Hermitian"):
_, __ = program((invalid_tape,))

shots_only_meas_tape = QuantumScript([], [qml.counts()], shots=shots)
Expand Down Expand Up @@ -555,6 +557,58 @@ def execute(self, circuits, execution_config=DefaultExecutionConfig):
assert qml.transforms.split_to_single_terms not in program
assert qml.transforms.split_to_single_terms not in program

@pytest.mark.usefixtures("create_temporary_toml_file")
@pytest.mark.parametrize("create_temporary_toml_file", [EXAMPLE_TOML_FILE], indirect=True)
@pytest.mark.parametrize("non_commuting_obs", [True, False])
@pytest.mark.parametrize("all_obs_support", [True, False])
def test_diagonalize_measurements(self, request, non_commuting_obs, all_obs_support):
"""Tests that the diagonalize_measurements transform is applied correctly."""

class CustomDevice(Device):

config_filepath = request.node.toml_file

def __init__(self):
super().__init__()
self.capabilities.non_commuting_observables = non_commuting_obs
if all_obs_support:
self.capabilities.observables.update(
{
"PauliX": OperatorProperties(),
"PauliY": OperatorProperties(),
"PauliZ": OperatorProperties(),
"Hadamard": OperatorProperties(),
}
)
else:
self.capabilities.observables.update(
{
"PauliZ": OperatorProperties(),
"PauliX": OperatorProperties(),
"PauliY": OperatorProperties(),
"Hermitian": OperatorProperties(),
}
)

def execute(self, circuits, execution_config=DefaultExecutionConfig):
return (0,)

dev = CustomDevice()
program = dev.preprocess_transforms()
if non_commuting_obs is True:
assert qml.transforms.diagonalize_measurements not in program
elif all_obs_support is True:
assert qml.transforms.diagonalize_measurements not in program
else:
assert qml.transforms.diagonalize_measurements in program
for transform_container in program:
if transform_container._transform is qml.transforms.diagonalize_measurements:
assert transform_container._kwargs["supported_base_obs"] == {
"PauliZ",
"PauliX",
"PauliY",
}


class TestMinimalDevice:
"""Tests for a device with only a minimal execute provided."""
Expand Down
4 changes: 2 additions & 2 deletions tests/measurements/test_probs.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ def circuit(phi):
@pytest.mark.parametrize("shots", [None, 1111, [1111, 1111]])
@pytest.mark.parametrize("phi", [0.0, np.pi / 3, np.pi])
def test_observable_is_measurement_value_list(
self, shots, phi, tol, tol_stochastic
self, shots, phi, tol, tol_stochastic, seed
): # pylint: disable=too-many-arguments
"""Test that probs for mid-circuit measurement values
are correct for a measurement value list."""
dev = qml.device("default.qubit")
dev = qml.device("default.qubit", seed=seed)

@qml.qnode(dev)
def circuit(phi):
Expand Down
10 changes: 6 additions & 4 deletions tests/transforms/test_diagonalize_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def test_non_commuting_measurements_with_supported_obs(self, obs):
with pytest.raises(ValueError, match="Expected measurements on the same wire to commute"):
_ = _diagonalize_observable(obs, supported_base_obs=device_supported_obs)

def test_diagonalizing_unknown_observable_raises_error(self):
"""Test that an unknown observable raises an error when diagonalizing"""
def test_diagonalizing_unknown_observable(self):
"""Test that an unknown observable is left undiagonalized"""

# pylint: disable=too-few-public-methods
class MyObs(qml.operation.Observable):
Expand All @@ -225,8 +225,10 @@ class MyObs(qml.operation.Observable):
def name(self):
return f"MyObservable[{self.wires}]"

with pytest.raises(NotImplementedError, match="Unable to convert observable"):
_ = _diagonalize_observable(MyObs(wires=[2]))
initial_tape = qml.tape.QuantumScript([], [qml.expval(MyObs(wires=[2]))])
tapes, _ = diagonalize_measurements([initial_tape])
assert tapes[0].operations == []
assert tapes[0].measurements == [ExpectationMP(MyObs(wires=[2]))]

@pytest.mark.parametrize(
"obs, input_visited_obs, switch_basis, expected_res",
Expand Down

0 comments on commit 646ea39

Please sign in to comment.