-
Notifications
You must be signed in to change notification settings - Fork 616
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a
combine_global_phases
transform (#6686)
**Description of the Change:** I implemented a `qml.transform.combine_global_phases` transform to combine all `qml.GlobalPhase` gates in a circuit into a single `qml.GlobalPhase` operation (applied at the end of the new circuit without specifying any wire) with the phase equal to the total algebraic sum of each original phase. **Benefits:** This transform can be useful for circuits that include a lot of `qml.GlobalPhase` gates, which can be introduced directly during circuit creation, decompositions that include `qml.GlobalPhase` gates, etc. **Related GitHub Issues:** #6644 --------- Co-authored-by: Astral Cai <[email protected]> Co-authored-by: Christina Lee <[email protected]> Co-authored-by: Mudit Pandey <[email protected]>
- Loading branch information
1 parent
a703355
commit b4f2c20
Showing
4 changed files
with
271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2018-2024 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
Provides a transform to combine all ``qml.GlobalPhase`` gates in a circuit into a single one applied at the end. | ||
""" | ||
|
||
import pennylane as qml | ||
from pennylane.tape import QuantumScript, QuantumScriptBatch | ||
from pennylane.transforms import transform | ||
from pennylane.typing import PostprocessingFn | ||
|
||
|
||
@transform | ||
def combine_global_phases(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]: | ||
"""Combine all ``qml.GlobalPhase`` gates into a single ``qml.GlobalPhase`` operation. | ||
This transform returns a new circuit where all ``qml.GlobalPhase`` gates in the original circuit (if exists) | ||
are removed, and a new ``qml.GlobalPhase`` is added at the end of the list of operations with its phase | ||
being a total global phase computed as the algebraic sum of all global phases in the original circuit. | ||
Args: | ||
tape (QNode or QuantumScript or Callable): the input circuit to be transformed. | ||
Returns: | ||
qnode (QNode) or quantum function (Callable) or tuple[List[QuantumScript], function]: | ||
the transformed circuit as described in :func:`qml.transform <pennylane.transform>`. | ||
**Example** | ||
Suppose we want to combine all the global phase gates in a given quantum circuit. | ||
The ``combine_global_phases`` transform can be used to do this as follows: | ||
.. code-block:: python3 | ||
dev = qml.device("default.qubit", wires=3) | ||
@qml.transforms.combine_global_phases | ||
@qml.qnode(dev) | ||
def circuit(): | ||
qml.GlobalPhase(0.3, wires=0) | ||
qml.PauliY(wires=0) | ||
qml.Hadamard(wires=1) | ||
qml.CNOT(wires=(1,2)) | ||
qml.GlobalPhase(0.46, wires=2) | ||
return qml.expval(qml.X(0) @ qml.Z(1)) | ||
To check the result, let's print out the circuit: | ||
>>> print(qml.draw(circuit)()) | ||
0: ──Y─────GlobalPhase(0.76)─┤ ╭<X@Z> | ||
1: ──H─╭●──GlobalPhase(0.76)─┤ ╰<X@Z> | ||
2: ────╰X──GlobalPhase(0.76)─┤ | ||
""" | ||
|
||
has_global_phase = False | ||
phi = 0 | ||
operations = [] | ||
for op in tape.operations: | ||
if isinstance(op, qml.GlobalPhase): | ||
has_global_phase = True | ||
phi += op.parameters[0] | ||
else: | ||
operations.append(op) | ||
|
||
if has_global_phase: | ||
with qml.QueuingManager.stop_recording(): | ||
operations.append(qml.GlobalPhase(phi=phi)) | ||
|
||
new_tape = tape.copy(operations=operations) | ||
|
||
def null_postprocessing(results): | ||
"""A postprocesing function returned by a transform that only converts the batch of results | ||
into a result for a single ``QuantumScript``. | ||
""" | ||
return results[0] | ||
|
||
return (new_tape,), null_postprocessing |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
""" | ||
Tests for the combine_global_phases transform. | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import pennylane as qml | ||
from pennylane.transforms import combine_global_phases | ||
|
||
|
||
def original_qfunc(phi1, phi2, return_state=False): | ||
qml.Hadamard(wires=1) | ||
qml.GlobalPhase(phi1, wires=[0, 1]) | ||
qml.PauliY(wires=0) | ||
qml.PauliX(wires=2) | ||
qml.CNOT(wires=[1, 2]) | ||
qml.GlobalPhase(phi2, wires=1) | ||
qml.CNOT(wires=[2, 0]) | ||
if return_state: | ||
return qml.state() | ||
return qml.expval(qml.Z(0) @ qml.X(1)) | ||
|
||
|
||
def expected_qfunc(phi1, phi2, return_state=False): | ||
qml.Hadamard(wires=1) | ||
qml.PauliY(wires=0) | ||
qml.PauliX(wires=2) | ||
qml.CNOT(wires=[1, 2]) | ||
qml.CNOT(wires=[2, 0]) | ||
qml.GlobalPhase(phi1 + phi2) | ||
if return_state: | ||
return qml.state() | ||
return qml.expval(qml.Z(0) @ qml.X(1)) | ||
|
||
|
||
def test_no_global_phase_gate(): | ||
"""Test that when the input ``QuantumScript`` has no ``qml.GlobalPhase`` gate, the returned output is exactly the same""" | ||
qscript = qml.tape.QuantumScript([qml.Hadamard(0), qml.RX(0, 0)]) | ||
|
||
expected_qscript = qml.tape.QuantumScript([qml.Hadamard(0), qml.RX(0, 0)]) | ||
(transformed_qscript,), _ = combine_global_phases(qscript) | ||
|
||
qml.assert_equal(expected_qscript, transformed_qscript) | ||
|
||
|
||
def test_single_global_phase_gate(): | ||
"""Test that when the input ``QuantumScript`` has a single ``qml.GlobalPhase`` gate, the returned output has an equivalent | ||
``qml.GlobalPhase`` operation appended at the end""" | ||
phi = 1.23 | ||
qscript = qml.tape.QuantumScript([qml.Hadamard(0), qml.GlobalPhase(phi, 0), qml.RX(0, 0)]) | ||
|
||
expected_qscript = qml.tape.QuantumScript([qml.Hadamard(0), qml.RX(0, 0), qml.GlobalPhase(phi)]) | ||
(transformed_qscript,), _ = combine_global_phases(qscript) | ||
|
||
qml.assert_equal(expected_qscript, transformed_qscript) | ||
|
||
|
||
def test_multiple_global_phase_gates(): | ||
"""Test that when the input ``QuantumScript`` has multiple ``qml.GlobalPhase`` gates, the returned output has an equivalent | ||
single ``qml.GlobalPhase`` operation appended at the end with a total phase being equal to the sum of each original global phase | ||
""" | ||
phi1 = 1.23 | ||
phi2 = 4.56 | ||
qscript = qml.tape.QuantumScript( | ||
[qml.GlobalPhase(phi1, 0), qml.Hadamard(0), qml.GlobalPhase(phi2, 0), qml.RX(0, 0)] | ||
) | ||
|
||
expected_qscript = qml.tape.QuantumScript( | ||
[qml.Hadamard(0), qml.RX(0, 0), qml.GlobalPhase(phi1 + phi2)] | ||
) | ||
(transformed_qscript,), _ = combine_global_phases(qscript) | ||
|
||
qml.assert_equal(expected_qscript, transformed_qscript) | ||
|
||
|
||
def test_combine_global_phases(): | ||
"""Test that the ``combine_global_phases`` function implements the expected transform on a | ||
QuantumScript and check the equivalence between statevectors before and after the transform.""" | ||
transformed_qfunc = combine_global_phases(original_qfunc) | ||
|
||
dev = qml.device("default.qubit", wires=3) | ||
original_qnode = qml.QNode(original_qfunc, device=dev) | ||
transformed_qnode = qml.QNode(transformed_qfunc, device=dev) | ||
|
||
phi1 = 1.23 | ||
phi2 = 4.56 | ||
expected_qscript = qml.tape.make_qscript(expected_qfunc)(phi1, phi2) | ||
transformed_qscript = qml.tape.make_qscript(transformed_qfunc)(phi1, phi2) | ||
|
||
original_state = original_qnode(phi1, phi2, return_state=True) | ||
transformed_state = transformed_qnode(phi1, phi2, return_state=True) | ||
|
||
# check the equivalence between expected and transformed quantum scripts | ||
qml.assert_equal(expected_qscript, transformed_qscript) | ||
|
||
# check the equivalence between statevectors before and after the transform | ||
assert np.allclose(original_state, transformed_state) | ||
|
||
|
||
@pytest.mark.autograd | ||
def test_differentiability_autograd(): | ||
"""Test that the output of the ``combine_global_phases`` transform is differentiable with autograd""" | ||
import pennylane.numpy as pnp | ||
|
||
dev = qml.device("default.qubit", wires=3) | ||
original_qnode = qml.QNode(original_qfunc, device=dev) | ||
transformed_qnode = combine_global_phases(original_qnode) | ||
|
||
phi1 = pnp.array(0.25) | ||
phi2 = pnp.array(-0.6) | ||
grad1, grad2 = qml.jacobian(transformed_qnode)(phi1, phi2) | ||
|
||
assert qml.math.isclose(grad1, 0.0) | ||
assert qml.math.isclose(grad2, 0.0) | ||
|
||
|
||
@pytest.mark.jax | ||
@pytest.mark.parametrize("use_jit", [False, True]) | ||
def test_differentiability_jax(use_jit): | ||
"""Test that the output of the ``combine_global_phases`` transform is differentiable with JAX""" | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
dev = qml.device("default.qubit", wires=3) | ||
original_qnode = qml.QNode(original_qfunc, device=dev) | ||
transformed_qnode = combine_global_phases(original_qnode) | ||
|
||
if use_jit: | ||
transformed_qnode = jax.jit(transformed_qnode) | ||
|
||
phi1 = jnp.array(0.25) | ||
phi2 = jnp.array(-0.6) | ||
grad1, grad2 = jax.jacobian(transformed_qnode, argnums=[0, 1])(phi1, phi2) | ||
|
||
assert qml.math.isclose(grad1, 0.0) | ||
assert qml.math.isclose(grad2, 0.0) | ||
|
||
|
||
@pytest.mark.torch | ||
def test_differentiability_torch(): | ||
"""Test that the output of the ``combine_global_phases`` transform is differentiable with Torch""" | ||
import torch | ||
from torch.autograd.functional import jacobian | ||
|
||
dev = qml.device("default.qubit", wires=3) | ||
original_qnode = qml.QNode(original_qfunc, device=dev) | ||
transformed_qnode = combine_global_phases(original_qnode) | ||
|
||
phi1 = torch.tensor(0.25) | ||
phi2 = torch.tensor(-0.6) | ||
grad1, grad2 = jacobian(transformed_qnode, (phi1, phi2)) | ||
|
||
zero = torch.tensor(0.0) | ||
assert qml.math.isclose(grad1, zero) | ||
assert qml.math.isclose(grad2, zero) | ||
|
||
|
||
@pytest.mark.tf | ||
def test_differentiability_tensorflow(): | ||
"""Test that the output of the ``combine_global_phases`` transform is differentiable with TensorFlow""" | ||
import tensorflow as tf | ||
|
||
dev = qml.device("default.qubit", wires=3) | ||
original_qnode = qml.QNode(original_qfunc, device=dev) | ||
|
||
phi1 = tf.Variable(0.25) | ||
phi2 = tf.Variable(-0.6) | ||
with tf.GradientTape() as tape: | ||
transformed_qnode = combine_global_phases(original_qnode)(phi1, phi2) | ||
grad1, grad2 = tape.jacobian(transformed_qnode, (phi1, phi2)) | ||
|
||
assert qml.math.isclose(grad1, 0.0) | ||
assert qml.math.isclose(grad2, 0.0) |