diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3680c3dd6ea..010f57c77fe 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -340,6 +340,9 @@ such as `shots`, `rng` and `prng_key`.
Other Improvements
+* `qml.transforms.cancel_inverses` now has improved handling of `Adjoint` operator cancellation.
+ [(#6752)](https://github.com/PennyLaneAI/pennylane/pull/6752)
+
* `qml.math.grad` and `qml.math.jacobian` added to differentiate a function with inputs of any
interface in a jax-like manner.
[(#6741)](https://github.com/PennyLaneAI/pennylane/pull/6741)
diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py
index 418e68941a6..fb8bb458557 100644
--- a/pennylane/transforms/optimization/cancel_inverses.py
+++ b/pennylane/transforms/optimization/cancel_inverses.py
@@ -29,18 +29,8 @@
from .optimization_utils import find_next_gate
-def _ops_equal(op1, op2):
- """Checks if two operators are equal up to class, data, hyperparameters, and wires"""
- return (
- op1.__class__ is op2.__class__
- and (op1.data == op2.data)
- and (op1.hyperparameters == op2.hyperparameters)
- and (op1.wires == op2.wires)
- )
-
-
-def _are_inverses(op1, op2):
- """Checks if two operators are inverses of each other
+def _can_cancel_ops(op1, op2):
+ """Checks if two operators can be cancelled
Args:
op1 (~.Operator)
@@ -49,17 +39,39 @@ def _are_inverses(op1, op2):
Returns:
Bool
"""
- # op1 is self-inverse and the next gate is also op1
- if op1 in self_inverses and op1.name == op2.name:
- return True
-
- # op1 is an `Adjoint` class and its base is equal to op2
- if isinstance(op1, Adjoint) and _ops_equal(op1.base, op2):
- return True
-
- # op2 is an `Adjoint` class and its base is equal to op1
- if isinstance(op2, Adjoint) and _ops_equal(op2.base, op1):
- return True
+ # Make sure that if one of the ops is Adjoint, it is always op2 by swapping
+ # the ops if op1 is Adjoint
+ if isinstance(op1, Adjoint):
+ op1, op2 = op2, op1
+
+ are_self_inverses_without_wires = op1 in self_inverses and op1.name == op2.name
+ are_inverses_without_wires = (
+ isinstance(op2, Adjoint)
+ and op1.__class__ == op2.base.__class__
+ and op1.data == op2.base.data
+ and op1.hyperparameters == op2.base.hyperparameters
+ )
+ if are_self_inverses_without_wires or are_inverses_without_wires:
+
+ # If the wires are the same, then we can safely cancel both
+ if op1.wires == op2.wires:
+ return True
+ # If wires are not equal, there are two things that can happen.
+ # 1. There is not full overlap in the wires; we cannot cancel
+ if len(Wires.shared_wires([op1.wires, op2.wires])) != len(op1.wires):
+ return False
+
+ # 2. There is full overlap, but the wires are in a different order.
+ # If the wires are in a different order, gates that are "symmetric"
+ # over all wires (e.g., CZ), can be cancelled.
+ if op1 in symmetric_over_all_wires:
+ return True
+ # For other gates, as long as the control wires are the same, we can still
+ # cancel (e.g., the Toffoli gate).
+ if op1 in symmetric_over_control_wires:
+ # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates
+ if len(Wires.shared_wires([op1.wires[:-1], op2.wires[:-1]])) == len(op1.wires) - 1:
+ return True
return False
@@ -123,24 +135,7 @@ def interpret_operation(self, op: Operator):
self.previous_ops[w] = op
return []
- cancel = False
- if _are_inverses(op, prev_op):
- # Same wires, cancel
- if op.wires == prev_op.wires:
- cancel = True
- # Full overlap over wires
- elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires):
- # symmetric op + full wire overlap; cancel
- if op in symmetric_over_all_wires:
- cancel = True
- # symmetric over control wires, full overlap over control wires; cancel
- elif op in symmetric_over_control_wires and (
- len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]]))
- == len(op.wires) - 1
- ):
- cancel = True
- # No or partial overlap over wires; can't cancel
-
+ cancel = _can_cancel_ops(op, prev_op)
if cancel:
for w in op.wires:
self.previous_ops.pop(w)
@@ -275,7 +270,7 @@ def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, Postproces
.. code-block:: python
- @cancel_inverses
+ @qml.transforms.cancel_inverses
@qml.qnode(device=dev)
def circuit(x, y, z):
qml.Hadamard(wires=0)
@@ -291,7 +286,7 @@ def circuit(x, y, z):
return qml.expval(qml.Z(0))
>>> circuit(0.1, 0.2, 0.3)
- 0.999999999999999
+ 1.0
.. details::
:title: Usage Details
@@ -326,7 +321,7 @@ def qfunc(x, y, z):
second qubit that should cancel. We can obtain a simplified circuit by running
the ``cancel_inverses`` transform:
- >>> optimized_qfunc = cancel_inverses(qfunc)
+ >>> optimized_qfunc = qml.transforms.cancel_inverses(qfunc)
>>> optimized_qnode = qml.QNode(optimized_qfunc, dev)
>>> print(qml.draw(optimized_qnode)(1, 2, 3))
0: ──RZ(3.00)───────────╭●─┤
@@ -353,42 +348,15 @@ def qfunc(x, y, z):
# Otherwise, get the next gate
next_gate = list_copy[next_gate_idx]
- # If either of the two flags is true, we can potentially cancel the gates
- if _are_inverses(current_gate, next_gate):
- # If the wires are the same, then we can safely remove both
- if current_gate.wires == next_gate.wires:
- list_copy.pop(next_gate_idx)
- continue
- # If wires are not equal, there are two things that can happen.
- # 1. There is not full overlap in the wires; we cannot cancel
- if len(Wires.shared_wires([current_gate.wires, next_gate.wires])) != len(
- current_gate.wires
- ):
- operations.append(current_gate)
- continue
-
- # 2. There is full overlap, but the wires are in a different order.
- # If the wires are in a different order, gates that are "symmetric"
- # over all wires (e.g., CZ), can be cancelled.
- if current_gate in symmetric_over_all_wires:
- list_copy.pop(next_gate_idx)
- continue
- # For other gates, as long as the control wires are the same, we can still
- # cancel (e.g., the Toffoli gate).
- if current_gate in symmetric_over_control_wires:
- # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates
- if (
- len(Wires.shared_wires([current_gate.wires[:-1], next_gate.wires[:-1]]))
- == len(current_gate.wires) - 1
- ):
- list_copy.pop(next_gate_idx)
- continue
+ # If operators are inverses, cancel
+ if _can_cancel_ops(current_gate, next_gate):
+ list_copy.pop(next_gate_idx)
+ continue
# Apply gate any cases where
# - there is no wire symmetry
# - the control wire symmetry does not apply because the control wires are not the same
# - neither of the flags are_self_inverses and are_inverses are true
operations.append(current_gate)
- continue
new_tape = tape.copy(operations=operations)
diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py
index ccc761b86f1..00e8b78da5e 100644
--- a/tests/capture/transforms/test_capture_cancel_inverses.py
+++ b/tests/capture/transforms/test_capture_cancel_inverses.py
@@ -90,15 +90,21 @@ def f():
jaxpr = jax.make_jaxpr(f)()
assert len(jaxpr.eqns) == 0
- def test_cancel_inverses_symmetric_wires(self):
+ @pytest.mark.parametrize("adjoint_first", [True, False])
+ def test_cancel_inverses_symmetric_wires(self, adjoint_first):
"""Test that operations that are inverses regardless of wire order are cancelled."""
@CancelInversesInterpreter()
- def f():
- qml.CCZ([0, 1, 2])
- qml.CCZ([2, 0, 1])
-
- jaxpr = jax.make_jaxpr(f)()
+ def f(x):
+ if adjoint_first:
+ qml.adjoint(qml.MultiRZ(x, [2, 0, 1]))
+ qml.MultiRZ(x, [0, 1, 2])
+ else:
+ qml.MultiRZ(x, [2, 0, 1])
+ qml.adjoint(qml.MultiRZ(x, [0, 1, 2]))
+
+ args = (1.5,)
+ jaxpr = jax.make_jaxpr(f)(*args)
assert len(jaxpr.eqns) == 0
def test_cancel_inverses_symmetric_control_wires(self):
diff --git a/tests/transforms/test_optimization/test_cancel_inverses.py b/tests/transforms/test_optimization/test_cancel_inverses.py
index d571a2a1f43..2d874583f5c 100644
--- a/tests/transforms/test_optimization/test_cancel_inverses.py
+++ b/tests/transforms/test_optimization/test_cancel_inverses.py
@@ -194,6 +194,26 @@ def qfunc():
assert len(ops) == 0
+ @pytest.mark.parametrize("adjoint_first", [True, False])
+ def test_symmetric_over_all_wires(self, adjoint_first):
+ """Test that adjacent adjoint ops are cancelled due to wire symmetry."""
+
+ def qfunc(x):
+ if adjoint_first:
+ qml.adjoint(qml.MultiRZ(x, [2, 0, 1]))
+ qml.MultiRZ(x, [0, 1, 2])
+ else:
+ qml.MultiRZ(x, [2, 0, 1])
+ qml.adjoint(qml.MultiRZ(x, [0, 1, 2]))
+
+ transformed_qfunc = cancel_inverses(qfunc)
+
+ ops = qml.tape.make_qscript(transformed_qfunc)(1.5).operations
+
+ names_expected = []
+ wires_expected = []
+ compare_operation_lists(ops, names_expected, wires_expected)
+
# Example QNode and device for interface testing
dev = qml.device("default.qubit", wires=3)