diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3680c3dd6ea..010f57c77fe 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -340,6 +340,9 @@ such as `shots`, `rng` and `prng_key`.

Other Improvements

+* `qml.transforms.cancel_inverses` now has improved handling of `Adjoint` operator cancellation. + [(#6752)](https://github.com/PennyLaneAI/pennylane/pull/6752) + * `qml.math.grad` and `qml.math.jacobian` added to differentiate a function with inputs of any interface in a jax-like manner. [(#6741)](https://github.com/PennyLaneAI/pennylane/pull/6741) diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 418e68941a6..fb8bb458557 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -29,18 +29,8 @@ from .optimization_utils import find_next_gate -def _ops_equal(op1, op2): - """Checks if two operators are equal up to class, data, hyperparameters, and wires""" - return ( - op1.__class__ is op2.__class__ - and (op1.data == op2.data) - and (op1.hyperparameters == op2.hyperparameters) - and (op1.wires == op2.wires) - ) - - -def _are_inverses(op1, op2): - """Checks if two operators are inverses of each other +def _can_cancel_ops(op1, op2): + """Checks if two operators can be cancelled Args: op1 (~.Operator) @@ -49,17 +39,39 @@ def _are_inverses(op1, op2): Returns: Bool """ - # op1 is self-inverse and the next gate is also op1 - if op1 in self_inverses and op1.name == op2.name: - return True - - # op1 is an `Adjoint` class and its base is equal to op2 - if isinstance(op1, Adjoint) and _ops_equal(op1.base, op2): - return True - - # op2 is an `Adjoint` class and its base is equal to op1 - if isinstance(op2, Adjoint) and _ops_equal(op2.base, op1): - return True + # Make sure that if one of the ops is Adjoint, it is always op2 by swapping + # the ops if op1 is Adjoint + if isinstance(op1, Adjoint): + op1, op2 = op2, op1 + + are_self_inverses_without_wires = op1 in self_inverses and op1.name == op2.name + are_inverses_without_wires = ( + isinstance(op2, Adjoint) + and op1.__class__ == op2.base.__class__ + and op1.data == op2.base.data + and op1.hyperparameters == op2.base.hyperparameters + ) + if are_self_inverses_without_wires or are_inverses_without_wires: + + # If the wires are the same, then we can safely cancel both + if op1.wires == op2.wires: + return True + # If wires are not equal, there are two things that can happen. + # 1. There is not full overlap in the wires; we cannot cancel + if len(Wires.shared_wires([op1.wires, op2.wires])) != len(op1.wires): + return False + + # 2. There is full overlap, but the wires are in a different order. + # If the wires are in a different order, gates that are "symmetric" + # over all wires (e.g., CZ), can be cancelled. + if op1 in symmetric_over_all_wires: + return True + # For other gates, as long as the control wires are the same, we can still + # cancel (e.g., the Toffoli gate). + if op1 in symmetric_over_control_wires: + # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates + if len(Wires.shared_wires([op1.wires[:-1], op2.wires[:-1]])) == len(op1.wires) - 1: + return True return False @@ -123,24 +135,7 @@ def interpret_operation(self, op: Operator): self.previous_ops[w] = op return [] - cancel = False - if _are_inverses(op, prev_op): - # Same wires, cancel - if op.wires == prev_op.wires: - cancel = True - # Full overlap over wires - elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires): - # symmetric op + full wire overlap; cancel - if op in symmetric_over_all_wires: - cancel = True - # symmetric over control wires, full overlap over control wires; cancel - elif op in symmetric_over_control_wires and ( - len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])) - == len(op.wires) - 1 - ): - cancel = True - # No or partial overlap over wires; can't cancel - + cancel = _can_cancel_ops(op, prev_op) if cancel: for w in op.wires: self.previous_ops.pop(w) @@ -275,7 +270,7 @@ def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, Postproces .. code-block:: python - @cancel_inverses + @qml.transforms.cancel_inverses @qml.qnode(device=dev) def circuit(x, y, z): qml.Hadamard(wires=0) @@ -291,7 +286,7 @@ def circuit(x, y, z): return qml.expval(qml.Z(0)) >>> circuit(0.1, 0.2, 0.3) - 0.999999999999999 + 1.0 .. details:: :title: Usage Details @@ -326,7 +321,7 @@ def qfunc(x, y, z): second qubit that should cancel. We can obtain a simplified circuit by running the ``cancel_inverses`` transform: - >>> optimized_qfunc = cancel_inverses(qfunc) + >>> optimized_qfunc = qml.transforms.cancel_inverses(qfunc) >>> optimized_qnode = qml.QNode(optimized_qfunc, dev) >>> print(qml.draw(optimized_qnode)(1, 2, 3)) 0: ──RZ(3.00)───────────╭●─┤ @@ -353,42 +348,15 @@ def qfunc(x, y, z): # Otherwise, get the next gate next_gate = list_copy[next_gate_idx] - # If either of the two flags is true, we can potentially cancel the gates - if _are_inverses(current_gate, next_gate): - # If the wires are the same, then we can safely remove both - if current_gate.wires == next_gate.wires: - list_copy.pop(next_gate_idx) - continue - # If wires are not equal, there are two things that can happen. - # 1. There is not full overlap in the wires; we cannot cancel - if len(Wires.shared_wires([current_gate.wires, next_gate.wires])) != len( - current_gate.wires - ): - operations.append(current_gate) - continue - - # 2. There is full overlap, but the wires are in a different order. - # If the wires are in a different order, gates that are "symmetric" - # over all wires (e.g., CZ), can be cancelled. - if current_gate in symmetric_over_all_wires: - list_copy.pop(next_gate_idx) - continue - # For other gates, as long as the control wires are the same, we can still - # cancel (e.g., the Toffoli gate). - if current_gate in symmetric_over_control_wires: - # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates - if ( - len(Wires.shared_wires([current_gate.wires[:-1], next_gate.wires[:-1]])) - == len(current_gate.wires) - 1 - ): - list_copy.pop(next_gate_idx) - continue + # If operators are inverses, cancel + if _can_cancel_ops(current_gate, next_gate): + list_copy.pop(next_gate_idx) + continue # Apply gate any cases where # - there is no wire symmetry # - the control wire symmetry does not apply because the control wires are not the same # - neither of the flags are_self_inverses and are_inverses are true operations.append(current_gate) - continue new_tape = tape.copy(operations=operations) diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index ccc761b86f1..00e8b78da5e 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -90,15 +90,21 @@ def f(): jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 0 - def test_cancel_inverses_symmetric_wires(self): + @pytest.mark.parametrize("adjoint_first", [True, False]) + def test_cancel_inverses_symmetric_wires(self, adjoint_first): """Test that operations that are inverses regardless of wire order are cancelled.""" @CancelInversesInterpreter() - def f(): - qml.CCZ([0, 1, 2]) - qml.CCZ([2, 0, 1]) - - jaxpr = jax.make_jaxpr(f)() + def f(x): + if adjoint_first: + qml.adjoint(qml.MultiRZ(x, [2, 0, 1])) + qml.MultiRZ(x, [0, 1, 2]) + else: + qml.MultiRZ(x, [2, 0, 1]) + qml.adjoint(qml.MultiRZ(x, [0, 1, 2])) + + args = (1.5,) + jaxpr = jax.make_jaxpr(f)(*args) assert len(jaxpr.eqns) == 0 def test_cancel_inverses_symmetric_control_wires(self): diff --git a/tests/transforms/test_optimization/test_cancel_inverses.py b/tests/transforms/test_optimization/test_cancel_inverses.py index d571a2a1f43..2d874583f5c 100644 --- a/tests/transforms/test_optimization/test_cancel_inverses.py +++ b/tests/transforms/test_optimization/test_cancel_inverses.py @@ -194,6 +194,26 @@ def qfunc(): assert len(ops) == 0 + @pytest.mark.parametrize("adjoint_first", [True, False]) + def test_symmetric_over_all_wires(self, adjoint_first): + """Test that adjacent adjoint ops are cancelled due to wire symmetry.""" + + def qfunc(x): + if adjoint_first: + qml.adjoint(qml.MultiRZ(x, [2, 0, 1])) + qml.MultiRZ(x, [0, 1, 2]) + else: + qml.MultiRZ(x, [2, 0, 1]) + qml.adjoint(qml.MultiRZ(x, [0, 1, 2])) + + transformed_qfunc = cancel_inverses(qfunc) + + ops = qml.tape.make_qscript(transformed_qfunc)(1.5).operations + + names_expected = [] + wires_expected = [] + compare_operation_lists(ops, names_expected, wires_expected) + # Example QNode and device for interface testing dev = qml.device("default.qubit", wires=3)