Upgrade fuse_rot_angles to preserve derivatives and global phases (#6031)

**Context:**
The current implementation of `fuse_rot_angles`, which is used by
`merge_rotations` and `single_qubit_fusion`, has the following issues:
1. It does not necessarily preserve global phases. As we move towards
global-phase aware standards, this becomes an issue where it wasn't one
before.
2. Its derivative is wrong, forming part (but not all) of the bug #5715.
In particular, the custom handling of special input values prevents the
calculation of correct derivatives, and at singular points
`fuse_rot_angles` returns wrong derivatives rather than NaN values, which
would be the mathematically well-motivated result.
3. A minor technical issue is that the implementation requires nested
conditionals, leading to a good bit of code and separate handling of
traced JAX code, making it more complex.
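
To illustrate why singular points are unavoidable, here is a minimal sketch in plain Python (not PennyLane code). It evaluates only the fused Y angle, using the expression from the new implementation in this commit, and compares one-sided finite differences at `theta1 = theta2 = 0`. The names `fused_theta` and `one_sided_slopes` are hypothetical, and the clamping before `sqrt` and `acos` is an added assumption to guard against rounding.

```python
import math


def fused_theta(theta1, theta2, omega1_plus_phi2):
    """Fused Y angle; hypothetical helper mirroring the expression in this commit."""
    c1, c2 = math.cos(theta1 / 2), math.cos(theta2 / 2)
    s1, s2 = math.sin(theta1 / 2), math.sin(theta2 / 2)
    mag_sq = (
        c1**2 * c2**2
        + s1**2 * s2**2
        - 2 * c1 * c2 * s1 * s2 * math.cos(omega1_plus_phi2)
    )
    # Clamping is an assumption of this sketch: it keeps rounding errors
    # from pushing the arguments outside the domains of sqrt and acos.
    mag = min(1.0, math.sqrt(max(0.0, mag_sq)))
    return 2 * math.acos(mag)


def one_sided_slopes(h=1e-3):
    """Forward and backward difference quotients in theta1 at the origin."""
    at_zero = fused_theta(0.0, 0.0, 0.0)
    forward = (fused_theta(h, 0.0, 0.0) - at_zero) / h
    backward = (at_zero - fused_theta(-h, 0.0, 0.0)) / h
    return forward, backward


if __name__ == "__main__":
    print(one_sided_slopes())
```

With `theta2 = 0` the fused angle reduces to `|theta1|` near the origin, so the two one-sided slopes disagree and no single derivative value would be correct there.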

**Description of the Change:**
The implementation of `fuse_rot_angles` is remade entirely. 

**Benefits:**
The remade code uses a comparatively simple mathematical expression to
compute the fused rotation angles that
1. preserves global phases
2. has the correct derivative everywhere except for well-understandable,
predictable singular points. These singular points make sense because
rotation fusion is not a smooth map everywhere. The predictability
allowed us to write a dedicated test that confirms our understanding of
the singularities, at least within a large set of special test points.
3. does not require conditionals beyond those that are implemented in
`qml.math.arctan2` anyway, and thus available in all ML interfaces,
including JAX with JIT-compatibility.
4. As a bonus, the new implementation supports broadcasting/batching
with an arbitrary number of leading dimensions if all angles in each set
are broadcast in the same way (because this is nice, easy to support,
and allows us to speed up tests a lot).

In summary, global phases are now preserved, Jacobians are never wrong
(at worst they are NaN), and they are NaN only where the previous
implementation returned NaN or a wrong value.
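
The global-phase claim can be checked without PennyLane. The following is a minimal sketch, assuming the documented convention `Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)` and that the gate with `angles_1` is applied first. `fuse_sketch` transcribes the scalar formula from the new implementation in this commit into plain Python; it is an illustration rather than the library function, and it omits batching and the shortcut branches. All function names here are hypothetical, and the clamping of `mag` is an added assumption.

```python
import cmath
import math


def rot_matrix(phi, theta, omega):
    """2x2 matrix of Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [
        [cmath.exp(-0.5j * (phi + omega)) * c, -cmath.exp(0.5j * (phi - omega)) * s],
        [cmath.exp(-0.5j * (phi - omega)) * s, cmath.exp(0.5j * (phi + omega)) * c],
    ]


def matmul2(a, b):
    """Product of two 2x2 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)] for i in range(2)]


def fuse_sketch(angles_1, angles_2):
    """Scalar transcription of the fusion formula (no batching, no shortcuts)."""
    phi1, theta1, omega1 = angles_1
    phi2, theta2, omega2 = angles_2
    c1, c2 = math.cos(theta1 / 2), math.cos(theta2 / 2)
    s1, s2 = math.sin(theta1 / 2), math.sin(theta2 / 2)

    mag_sq = c1**2 * c2**2 + s1**2 * s2**2 - 2 * c1 * c2 * s1 * s2 * math.cos(omega1 + phi2)
    # Clamping is an assumption of this sketch, guarding against rounding.
    mag = min(1.0, math.sqrt(max(0.0, mag_sq)))
    theta_f = 2 * math.acos(mag)

    alpha1, beta1 = (phi1 + omega1) / 2, (phi1 - omega1) / 2
    alpha2, beta2 = (phi2 + omega2) / 2, (phi2 - omega2) / 2

    alpha_f = -math.atan2(
        -c1 * c2 * math.sin(alpha1 + alpha2) - s1 * s2 * math.sin(beta2 - beta1),
        c1 * c2 * math.cos(alpha1 + alpha2) - s1 * s2 * math.cos(beta2 - beta1),
    )
    beta_f = -math.atan2(
        -c1 * s2 * math.sin(alpha1 + beta2) + s1 * c2 * math.sin(alpha2 - beta1),
        c1 * s2 * math.cos(alpha1 + beta2) + s1 * c2 * math.cos(alpha2 - beta1),
    )
    return (alpha_f + beta_f, theta_f, alpha_f - beta_f)


def max_deviation(angles_1, angles_2):
    """Largest entrywise difference between Rot(fused) and Rot(angles_2) @ Rot(angles_1)."""
    expected = matmul2(rot_matrix(*angles_2), rot_matrix(*angles_1))
    fused = rot_matrix(*fuse_sketch(angles_1, angles_2))
    return max(abs(fused[i][j] - expected[i][j]) for i in range(2) for j in range(2))


if __name__ == "__main__":
    print(max_deviation((0.3, 1.1, -0.7), (2.0, 0.4, 0.9)))
```

The comparison is entrywise on the matrices themselves, not up to a phase factor, so a small deviation indicates that the global phase is reproduced as well as the rotation.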

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
#5715 (not fixed entirely, just partially)
[sc-63642]

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
4 people authored Aug 14, 2024
1 parent 4cd95c3 commit 98c8ed3
Showing 5 changed files with 530 additions and 207 deletions.
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
@@ -43,6 +43,9 @@

<h3>Improvements 🛠</h3>

* `fuse_rot_angles` now respects the global phase of the combined rotations.
[(#6031)](https://github.com/PennyLaneAI/pennylane/pull/6031)

* `QNGOptimizer` now supports cost functions with multiple arguments, updating each argument independently.
[(#5926)](https://github.com/PennyLaneAI/pennylane/pull/5926)

@@ -286,6 +289,9 @@

<h3>Bug fixes 🐛</h3>

* `fuse_rot_angles` no longer returns wrong derivatives at singular points but returns NaN.
[(#6031)](https://github.com/PennyLaneAI/pennylane/pull/6031)

* `qml.GlobalPhase` and `qml.I` can now be captured when acting on no wires.
[(#6060)](https://github.com/PennyLaneAI/pennylane/pull/6060)

23 changes: 21 additions & 2 deletions pennylane/transforms/optimization/merge_rotations.py
@@ -70,10 +70,29 @@ def circuit(x, y, z):
>>> circuit(0.1, 0.2, 0.3)
0.9553364891256055
.. details::
:title: Details on merging ``Rot`` gates
:href: details-on-rot
When merging two :class:`~.pennylane.Rot` gates, there are a number of details to consider:
First, the output angles are not always defined uniquely, because Euler angles are not
unique for some rotations. ``merge_rotations`` makes a particular choice in
this case.
Second, ``merge_rotations`` is not differentiable everywhere when used on ``Rot``.
It has singularities for specific rotation angles where the derivative will be NaN.
Finally, this function can be numerically unstable near singular points.
It is therefore recommended to use it with 64-bit floating point precision angles.
For a mathematical derivation of the fusion of two ``Rot`` gates, see the documentation
of :func:`~.pennylane.transforms.single_qubit_fusion`.
.. details::
:title: Usage Details
You can also apply it on quantum function.
You can also apply ``merge_rotations`` to a quantum function.
.. code-block:: python
@@ -106,7 +125,7 @@ def qfunc(x, y, z):
2: ─╰X─────────H─╰●────────┤
It is also possible to explicitly specify which rotations ``merge_rotations`` should
be merged using the ``include_gates`` argument. For example, if in the above
merge using the ``include_gates`` argument. For example, if in the above
circuit we wanted only to merge the "RX" gates, we could do so as follows:
>>> optimized_qfunc = merge_rotations(include_gates=["RX"])(qfunc)
217 changes: 86 additions & 131 deletions pennylane/transforms/optimization/optimization_utils.py
@@ -13,13 +13,7 @@
# limitations under the License.
"""Utility functions for circuit optimization."""
# pylint: disable=too-many-return-statements,import-outside-toplevel
from functools import partial

import numpy as np

from pennylane.math import abs as math_abs
from pennylane.math import allclose, arccos, arctan2, cos, get_interface, is_abstract, sin, stack
from pennylane.math import sum as math_sum
import pennylane as qml
from pennylane.ops.identity import GlobalPhase
from pennylane.wires import Wires

@@ -47,146 +41,107 @@ def find_next_gate(wires, op_list):
return next_gate_idx


def _zyz_to_quat(angles):
"""Converts a set of Euler angles in ZYZ format to a quaternion."""
qw = cos(angles[1] / 2) * cos(0.5 * (angles[0] + angles[2]))
qx = -sin(angles[1] / 2) * sin(0.5 * (angles[0] - angles[2]))
qy = sin(angles[1] / 2) * cos(0.5 * (angles[0] - angles[2]))
qz = cos(angles[1] / 2) * sin(0.5 * (angles[0] + angles[2]))

return stack([qw, qx, qy, qz])


def _quaternion_product(q1, q2):
"""Compute the product of two quaternions, q = q1 * q2."""
qw = q1[0] * q2[0] - q1[1] * q2[1] - q1[2] * q2[2] - q1[3] * q2[3]
qx = q1[0] * q2[1] + q1[1] * q2[0] + q1[2] * q2[3] - q1[3] * q2[2]
qy = q1[0] * q2[2] - q1[1] * q2[3] + q1[2] * q2[0] + q1[3] * q2[1]
qz = q1[0] * q2[3] + q1[1] * q2[2] - q1[2] * q2[1] + q1[3] * q2[0]

return stack([qw, qx, qy, qz])


def _singular_quat_to_zyz(qw, qx, qy, qz, y_arg, abstract_jax=False):
"""Compute the ZYZ angles for the singular case of qx = qy = 0"""
# pylint: disable=too-many-arguments
z1_arg1 = 2 * (qx * qy + qz * qw)
z1_arg2 = 1 - 2 * (qx**2 + qz**2)

if abstract_jax:
from jax.lax import cond

return cond(
y_arg > 0,
lambda z1_arg1, z1_arg2: stack([arctan2(z1_arg1, z1_arg2), 0.0, 0.0]),
lambda z1_arg1, z1_arg2: stack([-arctan2(z1_arg1, z1_arg2), np.pi, 0.0]),
z1_arg1,
z1_arg2,
)

if y_arg > 0:
z1 = arctan2(z1_arg1, z1_arg2)
y = z2 = 0.0
else:
z1 = -arctan2(z1_arg1, z1_arg2)
y = np.pi
z2 = 0.0
return stack([z1, y, z2])


def _regular_quat_to_zyz(qw, qx, qy, qz, y_arg):
"""Compute the ZYZ angles for the regular case (qx != 0 or qy != 0)"""
z1_arg1 = 2 * (qy * qz - qw * qx)
z1_arg2 = 2 * (qx * qz + qw * qy)
z1 = arctan2(z1_arg1, z1_arg2)

y = arccos(y_arg)

z2_arg1 = 2 * (qy * qz + qw * qx)
z2_arg2 = 2 * (qw * qy - qx * qz)
z2 = arctan2(z2_arg1, z2_arg2)

return stack([z1, y, z2])


def _fuse(angles_1, angles_2, abstract_jax=False):
"""Perform fusion of two angle sets. Separated out so we can do JIT with conditionals."""
# Compute the product of the quaternions
qw, qx, qy, qz = _quaternion_product(_zyz_to_quat(angles_1), _zyz_to_quat(angles_2))

# Convert the product back into the angles fed to Rot
y_arg = 1 - 2 * (qx**2 + qy**2)

# Require special treatment of the case qx = qy = 0. Note that we have to check
# for "greater than" as well, because of imprecisions
if abstract_jax:
from jax.lax import cond

return cond(
math_abs(y_arg) >= 1,
partial(_singular_quat_to_zyz, abstract_jax=True),
_regular_quat_to_zyz,
qw,
qx,
qy,
qz,
y_arg,
)

# Require special treatment of the case qx = qy = 0
if abs(y_arg) >= 1: # Have to check for "greater than" as well, because of imprecisions
return _singular_quat_to_zyz(qw, qx, qy, qz, y_arg)
return _regular_quat_to_zyz(qw, qx, qy, qz, y_arg)


def _no_fuse(angles_1, angles_2):
"""Special case: do not perform fusion when both Y angles are zero:
Rot(a, 0, b) Rot(c, 0, d) = Rot(a + b + c + d, 0, 0)
The quaternion math itself will fail in this case without a conditional.
"""
return stack([angles_1[0] + angles_1[2] + angles_2[0] + angles_2[2], 0.0, 0.0])
def _try_no_fuse(angles_1, angles_2):
"""Try to combine rotation angles without trigonometric identities
if some angles in the input angles vanish."""
# This sum is only computed to obtain a dtype-coerced object that respects
# TensorFlow's coercion rules between Python/NumPy objects and TF objects.
_sum = angles_1 + angles_2
# moveaxis required for batched inputs
phi1, theta1, omega1 = qml.math.moveaxis(qml.math.cast_like(angles_1, _sum), -1, 0)
phi2, theta2, omega2 = qml.math.moveaxis(qml.math.cast_like(angles_2, _sum), -1, 0)

if qml.math.allclose(omega1 + phi2, 0.0):
return qml.math.stack([phi1, theta1 + theta2, omega2])
if qml.math.allclose(theta1, 0.0):
# No Y rotation in first Rot
if qml.math.allclose(theta2, 0.0):
# Z rotations only
zero = qml.math.zeros_like(phi1) + qml.math.zeros_like(phi2)
return qml.math.stack([phi1 + omega1 + phi2 + omega2, zero, zero])
return qml.math.stack([phi1 + omega1 + phi2, theta2, omega2])
if qml.math.allclose(theta2, 0.0):
# No Y rotation in second Rot
return qml.math.stack([phi1, theta1, omega1 + phi2 + omega2])
return None


def fuse_rot_angles(angles_1, angles_2):
"""Computed the set of rotation angles that is obtained when composing
two ``qml.Rot`` operations.
r"""Compute the set of rotation angles that is equivalent to performing
two successive ``qml.Rot`` operations.
The ``qml.Rot`` operation represents the most general single-qubit operation.
Two such operations can be fused into a new operation, however the angular dependence
is non-trivial.
Args:
angles_1 (float): A set of three angles for the first ``qml.Rot`` operation.
angles_2 (float): A set of three angles for the second ``qml.Rot`` operation.
angles_1 (tensor_like): A set of three angles for the first ``qml.Rot`` operation.
angles_2 (tensor_like): A set of three angles for the second ``qml.Rot`` operation.
Returns:
array[float]: Rotation angles for a single ``qml.Rot`` operation that
tensor_like: Rotation angles for a single ``qml.Rot`` operation that
implements the same operation as the two sets of input angles.
"""
# Check if we are tracing; if so, use the special conditionals
if is_abstract(angles_1) or is_abstract(angles_2):
interface = get_interface(angles_1, angles_2)
This function supports broadcasting/batching as long as the two inputs are standard
broadcast-compatible.
.. note::
# TODO: implement something similar for torch and tensorflow interfaces
# If the interface is JAX, use jax.lax.cond so that we can jit even with conditionals
if interface == "jax":
from jax.lax import cond
The output angles are not always defined uniquely because Euler angles are not
unique for some rotations. ``fuse_rot_angles`` makes a particular
choice in this case.
return cond(
allclose(angles_1[1], 0.0) * allclose(angles_2[1], 0.0),
_no_fuse,
partial(_fuse, abstract_jax=True),
angles_1,
angles_2,
)
.. warning::
# For other interfaces where we would not be jitting or tracing, we can simply check
# if we are dealing with the special case of Rot(a, 0, b) Rot(c, 0, d).
if allclose(angles_1[1], 0.0) and allclose(angles_2[1], 0.0):
return _no_fuse(angles_1, angles_2)
This function is not differentiable everywhere. It has singularities for specific
input values where the derivative will be NaN.
return _fuse(angles_1, angles_2)
.. warning::
This function is numerically unstable at singular points. It is recommended to use
it with 64-bit floating point precision.
See the documentation of :func:`~.pennylane.transforms.single_qubit_fusion` for a
mathematical derivation of this function.
"""
angles_1 = qml.math.asarray(angles_1)
angles_2 = qml.math.asarray(angles_2)

if not (
qml.math.is_abstract(angles_1)
or qml.math.is_abstract(angles_2)
or qml.math.requires_grad(angles_1)
or qml.math.requires_grad(angles_2)
):
fused_angles = _try_no_fuse(angles_1, angles_2)
if fused_angles is not None:
return fused_angles

# moveaxis required for batched inputs
angles_1 = qml.math.moveaxis(angles_1, -1, 0)
angles_2 = qml.math.moveaxis(angles_2, -1, 0)
phi1, theta1, omega1 = angles_1[0], angles_1[1], angles_1[2]
phi2, theta2, omega2 = angles_2[0], angles_2[1], angles_2[2]
c1, c2 = qml.math.cos(theta1 / 2), qml.math.cos(theta2 / 2)
s1, s2 = qml.math.sin(theta1 / 2), qml.math.sin(theta2 / 2)

mag = qml.math.sqrt(
c1**2 * c2**2 + s1**2 * s2**2 - 2 * c1 * c2 * s1 * s2 * qml.math.cos(omega1 + phi2)
)
theta_f = 2 * qml.math.arccos(mag)

alpha1, beta1 = (phi1 + omega1) / 2, (phi1 - omega1) / 2
alpha2, beta2 = (phi2 + omega2) / 2, (phi2 - omega2) / 2

alpha_arg1 = -c1 * c2 * qml.math.sin(alpha1 + alpha2) - s1 * s2 * qml.math.sin(beta2 - beta1)
alpha_arg2 = c1 * c2 * qml.math.cos(alpha1 + alpha2) - s1 * s2 * qml.math.cos(beta2 - beta1)
alpha_f = -1 * qml.math.arctan2(alpha_arg1, alpha_arg2)

beta_arg1 = -c1 * s2 * qml.math.sin(alpha1 + beta2) + s1 * c2 * qml.math.sin(alpha2 - beta1)
beta_arg2 = c1 * s2 * qml.math.cos(alpha1 + beta2) + s1 * c2 * qml.math.cos(alpha2 - beta1)
beta_f = -1 * qml.math.arctan2(beta_arg1, beta_arg2)

return qml.math.stack([alpha_f + beta_f, theta_f, alpha_f - beta_f], axis=-1)


def _fuse_global_phases(operations):
@@ -206,5 +161,5 @@ def _fuse_global_phases(operations):
else:
fused_ops.append(op)

fused_ops.append(GlobalPhase(math_sum(op.data[0] for op in global_ops)))
fused_ops.append(GlobalPhase(sum(op.data[0] for op in global_ops)))
return fused_ops