From 0f05af5d8fcf93732baf6865281e51d865acce7a Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 2 Jan 2025 17:10:05 -0500 Subject: [PATCH 1/5] Optimize cancel_inverses --- doc/releases/changelog-dev.md | 2 + .../optimization/cancel_inverses.py | 96 +++++++++---------- .../test_capture_cancel_inverses.py | 18 ++-- .../test_optimization/test_cancel_inverses.py | 20 ++++ 4 files changed, 79 insertions(+), 57 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3680c3dd6ea..6570bc74d36 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -340,6 +340,8 @@ such as `shots`, `rng` and `prng_key`.

Other Improvements

+* `qml.transforms.cancel_inverses` is now better at handling cancellation of `Adjoint` operators. + * `qml.math.grad` and `qml.math.jacobian` added to differentiate a function with inputs of any interface in a jax-like manner. [(#6741)](https://github.com/PennyLaneAI/pennylane/pull/6741) diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 418e68941a6..083bc48593e 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -30,12 +30,11 @@ def _ops_equal(op1, op2): - """Checks if two operators are equal up to class, data, hyperparameters, and wires""" + """Checks if two operators are equal up to class, data, and hyperparameters""" return ( op1.__class__ is op2.__class__ and (op1.data == op2.data) and (op1.hyperparameters == op2.hyperparameters) - and (op1.wires == op2.wires) ) @@ -64,6 +63,45 @@ def _are_inverses(op1, op2): return False +def _can_cancel_ops(op1, op2): + """Checks if two operators can be cancelled + + Args: + op1 (~.Operator) + op2 (~.Operator) + + Returns: + Bool + """ + if _are_inverses(op1, op2): + # Make sure that if one of the ops is Adjoint, it is always op2 by swapping + # the ops if op1 is Adjoint + if isinstance(op1, Adjoint): + op1, op2 = op2, op1 + + # If the wires are the same, then we can safely cancel both + if op1.wires == op2.wires: + return True + # If wires are not equal, there are two things that can happen. + # 1. There is not full overlap in the wires; we cannot cancel + if len(Wires.shared_wires([op1.wires, op2.wires])) != len(op1.wires): + return False + + # 2. There is full overlap, but the wires are in a different order. + # If the wires are in a different order, gates that are "symmetric" + # over all wires (e.g., CZ), can be cancelled. + if op1 in symmetric_over_all_wires: + return True + # For other gates, as long as the control wires are the same, we can still + # cancel (e.g., the Toffoli gate). + if op1 in symmetric_over_control_wires: + # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates + if len(Wires.shared_wires([op1.wires[:-1], op2.wires[:-1]])) == len(op1.wires) - 1: + return True + + return False + + @lru_cache def _get_plxpr_cancel_inverses(): # pylint: disable=missing-function-docstring,too-many-statements try: @@ -123,24 +161,7 @@ def interpret_operation(self, op: Operator): self.previous_ops[w] = op return [] - cancel = False - if _are_inverses(op, prev_op): - # Same wires, cancel - if op.wires == prev_op.wires: - cancel = True - # Full overlap over wires - elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires): - # symmetric op + full wire overlap; cancel - if op in symmetric_over_all_wires: - cancel = True - # symmetric over control wires, full overlap over control wires; cancel - elif op in symmetric_over_control_wires and ( - len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])) - == len(op.wires) - 1 - ): - cancel = True - # No or partial overlap over wires; can't cancel - + cancel = _can_cancel_ops(op, prev_op) if cancel: for w in op.wires: self.previous_ops.pop(w) @@ -353,42 +374,15 @@ def qfunc(x, y, z): # Otherwise, get the next gate next_gate = list_copy[next_gate_idx] - # If either of the two flags is true, we can potentially cancel the gates - if _are_inverses(current_gate, next_gate): - # If the wires are the same, then we can safely remove both - if current_gate.wires == next_gate.wires: - list_copy.pop(next_gate_idx) - continue - # If wires are not equal, there are two things that can happen. - # 1. There is not full overlap in the wires; we cannot cancel - if len(Wires.shared_wires([current_gate.wires, next_gate.wires])) != len( - current_gate.wires - ): - operations.append(current_gate) - continue - - # 2. There is full overlap, but the wires are in a different order. - # If the wires are in a different order, gates that are "symmetric" - # over all wires (e.g., CZ), can be cancelled. - if current_gate in symmetric_over_all_wires: - list_copy.pop(next_gate_idx) - continue - # For other gates, as long as the control wires are the same, we can still - # cancel (e.g., the Toffoli gate). - if current_gate in symmetric_over_control_wires: - # TODO[David Wierichs]: This assumes single-qubit targets of controlled gates - if ( - len(Wires.shared_wires([current_gate.wires[:-1], next_gate.wires[:-1]])) - == len(current_gate.wires) - 1 - ): - list_copy.pop(next_gate_idx) - continue + # If operators are inverses, cancel + if _can_cancel_ops(current_gate, next_gate): + list_copy.pop(next_gate_idx) + continue # Apply gate any cases where # - there is no wire symmetry # - the control wire symmetry does not apply because the control wires are not the same # - neither of the flags are_self_inverses and are_inverses are true operations.append(current_gate) - continue new_tape = tape.copy(operations=operations) diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index ccc761b86f1..00e8b78da5e 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -90,15 +90,21 @@ def f(): jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 0 - def test_cancel_inverses_symmetric_wires(self): + @pytest.mark.parametrize("adjoint_first", [True, False]) + def test_cancel_inverses_symmetric_wires(self, adjoint_first): """Test that operations that are inverses regardless of wire order are cancelled.""" @CancelInversesInterpreter() - def f(): - qml.CCZ([0, 1, 2]) - qml.CCZ([2, 0, 1]) - - jaxpr = jax.make_jaxpr(f)() + def f(x): + if adjoint_first: + qml.adjoint(qml.MultiRZ(x, [2, 0, 1])) + qml.MultiRZ(x, [0, 1, 2]) + else: + qml.MultiRZ(x, [2, 0, 1]) + qml.adjoint(qml.MultiRZ(x, [0, 1, 2])) + + args = (1.5,) + jaxpr = jax.make_jaxpr(f)(*args) assert len(jaxpr.eqns) == 0 def test_cancel_inverses_symmetric_control_wires(self): diff --git a/tests/transforms/test_optimization/test_cancel_inverses.py b/tests/transforms/test_optimization/test_cancel_inverses.py index d571a2a1f43..2d874583f5c 100644 --- a/tests/transforms/test_optimization/test_cancel_inverses.py +++ b/tests/transforms/test_optimization/test_cancel_inverses.py @@ -194,6 +194,26 @@ def qfunc(): assert len(ops) == 0 + @pytest.mark.parametrize("adjoint_first", [True, False]) + def test_symmetric_over_all_wires(self, adjoint_first): + """Test that adjacent adjoint ops are cancelled due to wire symmetry.""" + + def qfunc(x): + if adjoint_first: + qml.adjoint(qml.MultiRZ(x, [2, 0, 1])) + qml.MultiRZ(x, [0, 1, 2]) + else: + qml.MultiRZ(x, [2, 0, 1]) + qml.adjoint(qml.MultiRZ(x, [0, 1, 2])) + + transformed_qfunc = cancel_inverses(qfunc) + + ops = qml.tape.make_qscript(transformed_qfunc)(1.5).operations + + names_expected = [] + wires_expected = [] + compare_operation_lists(ops, names_expected, wires_expected) + # Example QNode and device for interface testing dev = qml.device("default.qubit", wires=3) From b37a47808ee3aa592087e268127d7fdd16db22c4 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 3 Jan 2025 10:06:36 -0500 Subject: [PATCH 2/5] Update doc/releases/changelog-dev.md --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6570bc74d36..a10f18dd985 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -341,6 +341,7 @@ such as `shots`, `rng` and `prng_key`.

Other Improvements

* `qml.transforms.cancel_inverses` is now better at handling cancellation of `Adjoint` operators. + [(#6752)](https://github.com/PennyLaneAI/pennylane/pull/6752) * `qml.math.grad` and `qml.math.jacobian` added to differentiate a function with inputs of any interface in a jax-like manner. From 1155ac40298014c08ce8c0837c8f738bc474ce0f Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 3 Jan 2025 11:03:00 -0500 Subject: [PATCH 3/5] Remove tiny helper functions --- .../optimization/cancel_inverses.py | 52 +++++-------------- 1 file changed, 13 insertions(+), 39 deletions(-) diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 083bc48593e..df9904057ef 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -29,40 +29,6 @@ from .optimization_utils import find_next_gate -def _ops_equal(op1, op2): - """Checks if two operators are equal up to class, data, and hyperparameters""" - return ( - op1.__class__ is op2.__class__ - and (op1.data == op2.data) - and (op1.hyperparameters == op2.hyperparameters) - ) - - -def _are_inverses(op1, op2): - """Checks if two operators are inverses of each other - - Args: - op1 (~.Operator) - op2 (~.Operator) - - Returns: - Bool - """ - # op1 is self-inverse and the next gate is also op1 - if op1 in self_inverses and op1.name == op2.name: - return True - - # op1 is an `Adjoint` class and its base is equal to op2 - if isinstance(op1, Adjoint) and _ops_equal(op1.base, op2): - return True - - # op2 is an `Adjoint` class and its base is equal to op1 - if isinstance(op2, Adjoint) and _ops_equal(op2.base, op1): - return True - - return False - - def _can_cancel_ops(op1, op2): """Checks if two operators can be cancelled @@ -73,11 +39,19 @@ def _can_cancel_ops(op1, op2): Returns: Bool """ - if _are_inverses(op1, op2): - # Make sure that if one of the ops is Adjoint, it is always op2 by swapping - # the ops if op1 is Adjoint - if isinstance(op1, Adjoint): - op1, op2 = op2, op1 + # Make sure that if one of the ops is Adjoint, it is always op2 by swapping + # the ops if op1 is Adjoint + if isinstance(op1, Adjoint): + op1, op2 = op2, op1 + + are_self_inverses_without_wires = op1 in self_inverses and op1.name == op2.name + are_inverses_without_wires = ( + isinstance(op2, Adjoint) + and op1.__class__ == op2.base.__class__ + and op1.data == op2.base.data + and op1.hyperparameters == op2.base.hyperparameters + ) + if are_self_inverses_without_wires or are_inverses_without_wires: # If the wires are the same, then we can safely cancel both if op1.wires == op2.wires: From 2d91dddea27bdbaf7a6e7fb51fbf54d507eb0662 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 3 Jan 2025 13:51:56 -0500 Subject: [PATCH 4/5] Update doc/releases/changelog-dev.md Co-authored-by: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index a10f18dd985..010f57c77fe 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -340,7 +340,7 @@ such as `shots`, `rng` and `prng_key`.

Other Improvements

-* `qml.transforms.cancel_inverses` is now better at handling cancellation of `Adjoint` operators. +* `qml.transforms.cancel_inverses` now has improved handling of `Adjoint` operator cancellation. [(#6752)](https://github.com/PennyLaneAI/pennylane/pull/6752) * `qml.math.grad` and `qml.math.jacobian` added to differentiate a function with inputs of any From 1b9d43a61dc439461a12f65ba37b70a7d05c2ddf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 3 Jan 2025 13:54:23 -0500 Subject: [PATCH 5/5] Update docstring per review suggestion --- pennylane/transforms/optimization/cancel_inverses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index df9904057ef..fb8bb458557 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -270,7 +270,7 @@ def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, Postproces .. code-block:: python - @cancel_inverses + @qml.transforms.cancel_inverses @qml.qnode(device=dev) def circuit(x, y, z): qml.Hadamard(wires=0) @@ -286,7 +286,7 @@ def circuit(x, y, z): return qml.expval(qml.Z(0)) >>> circuit(0.1, 0.2, 0.3) - 0.999999999999999 + 1.0 .. details:: :title: Usage Details @@ -321,7 +321,7 @@ def qfunc(x, y, z): second qubit that should cancel. We can obtain a simplified circuit by running the ``cancel_inverses`` transform: - >>> optimized_qfunc = cancel_inverses(qfunc) + >>> optimized_qfunc = qml.transforms.cancel_inverses(qfunc) >>> optimized_qnode = qml.QNode(optimized_qfunc, dev) >>> print(qml.draw(optimized_qnode)(1, 2, 3)) 0: ──RZ(3.00)───────────╭●─┤