From 7e3e3500964d71bb269ffb90c9f7eb410a6edaee Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 9 Dec 2024 17:40:57 -0500 Subject: [PATCH 01/10] Add CancelInversesInterpreter --- .../optimization/cancel_inverses.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 7a6e35e7c8e..6443682d75e 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -14,6 +14,7 @@ """Transform for cancelling adjacent inverse gates in quantum circuits.""" # pylint: disable=too-many-branches +from pennylane.capture import PlxprInterpreter from pennylane.ops.op_math import Adjoint from pennylane.ops.qubit.attributes import ( self_inverses, @@ -63,6 +64,80 @@ def _are_inverses(op1, op2): return False +class CancelInversesInterpreter(PlxprInterpreter): + """Plxpr Interpreter for applying the ``cancel_inverses`` transform to callables or jaxpr + when program capture is enabled. + """ + + def __init__(self): + super().__init__() + self.previous_ops = {} + + def cleanup(self) -> None: + """Perform any final steps after iterating through all equations.""" + self.previous_ops = {} + + def interpret_operation(self, op): + """Interpret a PennyLane operation instance. + + Args: + op (Operator): a pennylane operator instance + + Returns: + Any + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + See also: :meth:`~.interpret_operation_eqn`. + + """ + if len(op.wires) == 0: + return super().interpret_operation(op) + + prev_op = self.previous_ops.get(op.wires[0], None) + if prev_op is None: + for w in op.wires: + self.previous_ops[w] = op + return [] + + cancel = False + if _are_inverses(op, prev_op): + # Same wires, cancel + if op.wires == prev_op.wires: + cancel = True + # Full overlap over wires + elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires): + # symmetric op + full wire overlap; cancel + if op in symmetric_over_all_wires: + cancel = True + # symmetric over control wires, full overlap over control wires; cancel + elif op in symmetric_over_control_wires and ( + len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])) + == len(op.wires) - 1 + ): + cancel = True + # No or partial overlap over wires; can't cancel + + if cancel: + for w in op.wires: + self.previous_ops.pop(w) + return [] + + previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires) + for o in previous_ops_on_wires: + if o is not None: + for w in o.wires: + self.previous_ops.pop(w) + for w in op.wires: + self.previous_ops[w] = op + + res = [] + for o in previous_ops_on_wires: + res.append(super().interpret_operation(o)) + return res + + @transform def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]: """Quantum function transform to remove any operations that are applied next to their From 888c80ee8b9b47dd624db89f9eab51b1cccbb26d Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 11 Dec 2024 13:49:40 -0500 Subject: [PATCH 02/10] Move interpreter to capture module --- doc/releases/changelog-dev.md | 5 + pennylane/capture/transforms/__init__.py | 19 ++++ .../transforms/capture_cancel_inverses.py | 93 +++++++++++++++++++ .../optimization/cancel_inverses.py | 75 --------------- 4 files changed, 117 insertions(+), 75 deletions(-) create mode 100644 pennylane/capture/transforms/__init__.py create mode 100644 pennylane/capture/transforms/capture_cancel_inverses.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 2899afeee90..2ae8c82bb8c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -215,6 +215,11 @@ such as `shots`, `rng` and `prng_key`.

Capturing and representing hybrid programs

+* Functions and plxpr can now be natively transformed using the new `qml.capture.transforms.CancelInterpreter` + when program capture is enabled. This class cancels operators appearing consecutively that are adjoints of each + other, and follows the same API as `qml.transforms.cancel_inverses`. + [(#6692)](https://github.com/PennyLaneAI/pennylane/pull/6692) + * Execution with capture enabled now follows a new execution pipeline and natively passes the captured jaxpr to the device. Since it no longer falls back to the old pipeline, execution only works with a reduced feature set. diff --git a/pennylane/capture/transforms/__init__.py b/pennylane/capture/transforms/__init__.py new file mode 100644 index 00000000000..6e5093a36a7 --- /dev/null +++ b/pennylane/capture/transforms/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Plxpr compatible transforms +""" +from .capture_cancel_inverses import CancelInversesInterpreter + +__all__ = ("CancelInversesInterpreter",) diff --git a/pennylane/capture/transforms/capture_cancel_inverses.py b/pennylane/capture/transforms/capture_cancel_inverses.py new file mode 100644 index 00000000000..a4393e2dbde --- /dev/null +++ b/pennylane/capture/transforms/capture_cancel_inverses.py @@ -0,0 +1,93 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transform for cancelling adjacent inverse gates in quantum circuits.""" +# pylint: disable=protected-access +import pennylane as qml +from pennylane.ops.qubit.attributes import symmetric_over_all_wires, symmetric_over_control_wires +from pennylane.transforms.optimization.cancel_inverses import _are_inverses +from pennylane.wires import Wires + + +class CancelInversesInterpreter(qml.capture.PlxprInterpreter): + """Plxpr Interpreter for applying the ``cancel_inverses`` transform to callables or jaxpr + when program capture is enabled. + """ + + def __init__(self): + super().__init__() + self.previous_ops = {} + + def cleanup(self) -> None: + """Perform any final steps after iterating through all equations.""" + self.previous_ops = {} + + def interpret_operation(self, op): + """Interpret a PennyLane operation instance. + + Args: + op (Operator): a pennylane operator instance + + Returns: + Any + + This method is only called when the operator's output is a dropped variable, + so the output will not affect later equations in the circuit. + + See also: :meth:`~.interpret_operation_eqn`. + + """ + if len(op.wires) == 0: + return super().interpret_operation(op) + + prev_op = self.previous_ops.get(op.wires[0], None) + if prev_op is None: + for w in op.wires: + self.previous_ops[w] = op + return [] + + cancel = False + if _are_inverses(op, prev_op): + # Same wires, cancel + if op.wires == prev_op.wires: + cancel = True + # Full overlap over wires + elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires): + # symmetric op + full wire overlap; cancel + if op in symmetric_over_all_wires: + cancel = True + # symmetric over control wires, full overlap over control wires; cancel + elif op in symmetric_over_control_wires and ( + len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])) + == len(op.wires) - 1 + ): + cancel = True + # No or partial overlap over wires; can't cancel + + if cancel: + for w in op.wires: + self.previous_ops.pop(w) + return [] + + previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires) + for o in previous_ops_on_wires: + if o is not None: + for w in o.wires: + self.previous_ops.pop(w) + for w in op.wires: + self.previous_ops[w] = op + + res = [] + for o in previous_ops_on_wires: + res.append(super().interpret_operation(o)) + return res diff --git a/pennylane/transforms/optimization/cancel_inverses.py b/pennylane/transforms/optimization/cancel_inverses.py index 6443682d75e..7a6e35e7c8e 100644 --- a/pennylane/transforms/optimization/cancel_inverses.py +++ b/pennylane/transforms/optimization/cancel_inverses.py @@ -14,7 +14,6 @@ """Transform for cancelling adjacent inverse gates in quantum circuits.""" # pylint: disable=too-many-branches -from pennylane.capture import PlxprInterpreter from pennylane.ops.op_math import Adjoint from pennylane.ops.qubit.attributes import ( self_inverses, @@ -64,80 +63,6 @@ def _are_inverses(op1, op2): return False -class CancelInversesInterpreter(PlxprInterpreter): - """Plxpr Interpreter for applying the ``cancel_inverses`` transform to callables or jaxpr - when program capture is enabled. - """ - - def __init__(self): - super().__init__() - self.previous_ops = {} - - def cleanup(self) -> None: - """Perform any final steps after iterating through all equations.""" - self.previous_ops = {} - - def interpret_operation(self, op): - """Interpret a PennyLane operation instance. - - Args: - op (Operator): a pennylane operator instance - - Returns: - Any - - This method is only called when the operator's output is a dropped variable, - so the output will not affect later equations in the circuit. - - See also: :meth:`~.interpret_operation_eqn`. - - """ - if len(op.wires) == 0: - return super().interpret_operation(op) - - prev_op = self.previous_ops.get(op.wires[0], None) - if prev_op is None: - for w in op.wires: - self.previous_ops[w] = op - return [] - - cancel = False - if _are_inverses(op, prev_op): - # Same wires, cancel - if op.wires == prev_op.wires: - cancel = True - # Full overlap over wires - elif len(Wires.shared_wires([op.wires, prev_op.wires])) == len(op.wires): - # symmetric op + full wire overlap; cancel - if op in symmetric_over_all_wires: - cancel = True - # symmetric over control wires, full overlap over control wires; cancel - elif op in symmetric_over_control_wires and ( - len(Wires.shared_wires([op.wires[:-1], prev_op.wires[:-1]])) - == len(op.wires) - 1 - ): - cancel = True - # No or partial overlap over wires; can't cancel - - if cancel: - for w in op.wires: - self.previous_ops.pop(w) - return [] - - previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires) - for o in previous_ops_on_wires: - if o is not None: - for w in o.wires: - self.previous_ops.pop(w) - for w in op.wires: - self.previous_ops[w] = op - - res = [] - for o in previous_ops_on_wires: - res.append(super().interpret_operation(o)) - return res - - @transform def cancel_inverses(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]: """Quantum function transform to remove any operations that are applied next to their From 16d823c1f88a7e42d1305824203f92de5c5b263c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 12 Dec 2024 16:22:53 -0500 Subject: [PATCH 03/10] Add tests --- .../transforms/capture_cancel_inverses.py | 71 ++++++++ .../test_capture_cancel_inverses.py | 154 ++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 tests/capture/transforms/test_capture_cancel_inverses.py diff --git a/pennylane/capture/transforms/capture_cancel_inverses.py b/pennylane/capture/transforms/capture_cancel_inverses.py index a4393e2dbde..ab7e794ac8f 100644 --- a/pennylane/capture/transforms/capture_cancel_inverses.py +++ b/pennylane/capture/transforms/capture_cancel_inverses.py @@ -22,6 +22,11 @@ class CancelInversesInterpreter(qml.capture.PlxprInterpreter): """Plxpr Interpreter for applying the ``cancel_inverses`` transform to callables or jaxpr when program capture is enabled. + + .. note:: + + In the process of transforming plxpr, this interpreter may reorder operations that do + not share any wires. This will not impact the correctness of the circuit. """ def __init__(self): @@ -35,6 +40,9 @@ def cleanup(self) -> None: def interpret_operation(self, op): """Interpret a PennyLane operation instance. + This method cancels operations that are the adjoint of the previous + operation on the same wires, and otherwise, applies it. + Args: op (Operator): a pennylane operator instance @@ -47,6 +55,7 @@ def interpret_operation(self, op): See also: :meth:`~.interpret_operation_eqn`. """ + # pylint: disable=too-many-branches if len(op.wires) == 0: return super().interpret_operation(op) @@ -79,6 +88,10 @@ def interpret_operation(self, op): self.previous_ops.pop(w) return [] + # Putting the operations in a set to avoid applying the same op multiple times + # Using a set causes order to no longer be guaranteed, so the new order of the + # operations might differ from the original order. However, this only impacts + # operators without any shared wires, so correctness will not be impacted. previous_ops_on_wires = set(self.previous_ops.get(w) for w in op.wires) for o in previous_ops_on_wires: if o is not None: @@ -91,3 +104,61 @@ def interpret_operation(self, op): for o in previous_ops_on_wires: res.append(super().interpret_operation(o)) return res + + def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: + """Evaluate a jaxpr. + + Args: + jaxpr (jax.core.Jaxpr): the jaxpr to evaluate + consts (list[TensorLike]): the constant variables for the jaxpr + *args (tuple[TensorLike]): The arguments for the jaxpr. + + Returns: + list[TensorLike]: the results of the execution. + + """ + self._env = {} + self.setup() + + for arg, invar in zip(args, jaxpr.invars, strict=True): + self._env[invar] = arg + for const, constvar in zip(consts, jaxpr.constvars, strict=True): + self._env[constvar] = const + + for eqn in jaxpr.eqns: + + custom_handler = self._primitive_registrations.get(eqn.primitive, None) + if custom_handler: + invals = [self.read(invar) for invar in eqn.invars] + outvals = custom_handler(self, *invals, **eqn.params) + elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractOperator): + outvals = self.interpret_operation_eqn(eqn) + elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractMeasurement): + outvals = self.interpret_measurement_eqn(eqn) + else: + invals = [self.read(invar) for invar in eqn.invars] + outvals = eqn.primitive.bind(*invals, **eqn.params) + + if not eqn.primitive.multiple_results: + outvals = [outvals] + for outvar, outval in zip(eqn.outvars, outvals, strict=True): + self._env[outvar] = outval + + # The following is needed because any operations inside self.previous_ops have not yet + # been applied. At this point, we **know** that any operations that should be cancelled + # have been cancelled, and operations left inside self.previous_ops should be applied + ops_remaining = set(self.previous_ops.values()) + for op in ops_remaining: + super().interpret_operation(op) + + # Read the final result of the Jaxpr from the environment + outvals = [] + for var in jaxpr.outvars: + outval = self.read(var) + if isinstance(outval, qml.operation.Operator): + outvals.append(super().interpret_operation(outval)) + else: + outvals.append(outval) + self.cleanup() + self._env = {} + return outvals diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py new file mode 100644 index 00000000000..9921618296c --- /dev/null +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -0,0 +1,154 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ``DecomposeInterpreter`` class""" +from functools import partial + +# pylint:disable=wrong-import-position +import pytest + +import pennylane as qml + +jax = pytest.importorskip("jax") + +from pennylane.capture.primitives import ( + cond_prim, + for_loop_prim, + grad_prim, + jacobian_prim, + qnode_prim, + while_loop_prim, +) +from pennylane.capture.transforms import CancelInversesInterpreter + +pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] + + +@pytest.mark.parametrize("lazy_adjoint", [True, False]) +class TestCancelInversesInterpreter: + """Unit tests for the CancelInversesInterpreter for canceling adjacent inverse + operations in plxpr.""" + + def test_cancel_inverses_simple(self, lazy_adjoint): + """Test that inverse ops in a simple circuit are cancelled.""" + + @CancelInversesInterpreter() + def f(): + qml.X(0) + qml.X(0) + qml.S(1) + qml.adjoint(qml.S(1), lazy=lazy_adjoint) + qml.adjoint(qml.T(2), lazy=lazy_adjoint) + qml.Hadamard(1) # Applied + qml.T(2) + qml.Z(0) # Applied + qml.IsingXX(1.5, [2, 3]) # Applied + qml.IsingXX(2.5, [0, 1]) # Applied + qml.SWAP([2, 0]) + qml.SWAP([0, 2]) + qml.CNOT([2, 0]) # Applied + qml.Z(1) # Applied + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 6 + + # Each of the pairs of primitives being compared below have disjoint wires, so their + # order is not relevant to their correctness + expected_primitives_first_second = {qml.Hadamard._primitive, qml.PauliZ._primitive} + actual_primitives_first_second = {jaxpr.eqns[0].primitive, jaxpr.eqns[1].primitive} + assert actual_primitives_first_second == expected_primitives_first_second + + expected_primitives_third_fourth = {qml.IsingXX._primitive, qml.IsingXX._primitive} + actual_primitives_third_fourth = {jaxpr.eqns[2].primitive, jaxpr.eqns[3].primitive} + assert actual_primitives_third_fourth == expected_primitives_third_fourth + + expected_primitives_fifth_sixth = {qml.CNOT._primitive, qml.PauliZ._primitive} + actual_primitives_fifth_sixth = {jaxpr.eqns[4].primitive, jaxpr.eqns[5].primitive} + assert actual_primitives_fifth_sixth == expected_primitives_fifth_sixth + + def test_cancel_inverses_true_inverses(self, lazy_adjoint): + """Test that operations that are inverses with the same wires are cancelled.""" + + @CancelInversesInterpreter() + def f(): + qml.CRX(1.5, [0, 1]) + qml.adjoint(qml.CRX(1.5, [0, 1]), lazy=lazy_adjoint) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 0 + + def test_cancel_inverses_symmetric_wires(self, lazy_adjoint): # pylint: disable=unused-argument + """Test that operations that are inverses regardless of wire order are cancelled.""" + + @CancelInversesInterpreter() + def f(): + qml.CCZ([0, 1, 2]) + qml.CCZ([2, 0, 1]) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 0 + + def test_cancel_inverses_symmetric_control_wires( + self, lazy_adjoint + ): # pylint: disable=unused-argument + """Test that operations that are inverses regardless of control_wire order are cancelled.""" + + @CancelInversesInterpreter() + def f(): + qml.Toffoli([0, 1, 2]) + qml.Toffoli([1, 0, 2]) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 0 + + def test_cancel_inverese_nested_ops_on_same_wires(self, lazy_adjoint): + """Test that only the innermost adjacent adjoint ops are cancelled when multiple + cancellable operators are present.""" + + @CancelInversesInterpreter() + def f(): + qml.S(0) + qml.adjoint(qml.T(0), lazy=lazy_adjoint) + qml.T(0) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.S._primitive + assert jaxpr.eqns[1].primitive == qml.S._primitive + assert jaxpr.eqns[2].primitive == qml.ops.Adjoint._primitive + + def test_returned_op_is_not_cancelled(self, lazy_adjoint): + """Test that ops that are returned by the function being transformed are not cancelled.""" + + def test_ctrl_higher_order_primitive(self, lazy_adjoint): + """Test that ctrl higher order primitives are transformed correctly.""" + + def test_adjoint_higher_order_primitive(self, lazy_adjoint): + """Test that adjoint higher order primitives are transformed correctly.""" + + def test_cond_higher_order_primitive(self, lazy_adjoint): + """Test that cond higher order primitives are transformed correctly.""" + + def test_for_loop_higher_order_primitive(self, lazy_adjoint): + """Test that for_loop higher order primitives are transformed correctly.""" + + def test_while_loop_higher_order_primitive(self, lazy_adjoint): + """Test that while_loop higher order primitives are transformed correctly.""" + + def test_qnode_higher_order_primitive(self, lazy_adjoint): + """Test that qnode higher order primitives are transformed correctly.""" + + @pytest.mark.parametrize("grad_fn", [qml.grad, qml.jacobian]) + def test_grad_and_jac_higher_order_primitives(self, grad_fn, lazy_adjoint): + """Test that grad and jacobian higher order primitives are transformed correctly.""" From ffb6a6cd54f0a52981429e0fd20340360191bca0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 12 Dec 2024 17:31:21 -0500 Subject: [PATCH 04/10] Add tests --- .../transforms/capture_cancel_inverses.py | 30 +++++-- .../test_capture_cancel_inverses.py | 88 +++++++++++++++++++ 2 files changed, 112 insertions(+), 6 deletions(-) diff --git a/pennylane/capture/transforms/capture_cancel_inverses.py b/pennylane/capture/transforms/capture_cancel_inverses.py index ab7e794ac8f..90766317ddc 100644 --- a/pennylane/capture/transforms/capture_cancel_inverses.py +++ b/pennylane/capture/transforms/capture_cancel_inverses.py @@ -33,11 +33,11 @@ def __init__(self): super().__init__() self.previous_ops = {} - def cleanup(self) -> None: - """Perform any final steps after iterating through all equations.""" + def setup(self) -> None: + """Initialize the instance before interpreting equations.""" self.previous_ops = {} - def interpret_operation(self, op): + def interpret_operation(self, op: qml.operation.Operator): """Interpret a PennyLane operation instance. This method cancels operations that are the adjoint of the previous @@ -105,6 +105,18 @@ def interpret_operation(self, op): res.append(super().interpret_operation(o)) return res + def interpret_all_previous_ops(self) -> None: + """Interpret all ops in ``previous_ops``. This is done whenever any + operators that haven't been interpreted that are saved to be cancelled + no longer need to be saved.""" + ops_remaining = set(self.previous_ops.values()) + for op in ops_remaining: + super().interpret_operation(op) + + all_wires = tuple(self.previous_ops.keys()) + for w in all_wires: + self.previous_ops.pop(w) + def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: """Evaluate a jaxpr. @@ -129,13 +141,21 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: custom_handler = self._primitive_registrations.get(eqn.primitive, None) if custom_handler: + # Interpret any stored ops so that they are applied before the custom + # primitive is handled + self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractOperator): outvals = self.interpret_operation_eqn(eqn) elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractMeasurement): + self.interpret_all_previous_ops() outvals = self.interpret_measurement_eqn(eqn) else: + # Transform primitives don't have custom handlers, so we check for them here + # to purge the stored ops in self.previous_ops + if eqn.primitive.name.endswith("_transform"): + self.interpret_all_previous_ops invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) @@ -147,9 +167,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: # The following is needed because any operations inside self.previous_ops have not yet # been applied. At this point, we **know** that any operations that should be cancelled # have been cancelled, and operations left inside self.previous_ops should be applied - ops_remaining = set(self.previous_ops.values()) - for op in ops_remaining: - super().interpret_operation(op) + self.interpret_all_previous_ops() # Read the final result of the Jaxpr from the environment outvals = [] diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index 9921618296c..4826a197f4f 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -22,7 +22,9 @@ jax = pytest.importorskip("jax") from pennylane.capture.primitives import ( + adjoint_transform_prim, cond_prim, + ctrl_transform_prim, for_loop_prim, grad_prim, jacobian_prim, @@ -131,12 +133,70 @@ def f(): def test_returned_op_is_not_cancelled(self, lazy_adjoint): """Test that ops that are returned by the function being transformed are not cancelled.""" + @CancelInversesInterpreter() + def f(): + qml.PauliX(0) + return qml.PauliX(0) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 2 + assert jaxpr.eqns[0].primitive == qml.PauliX._primitive + assert jaxpr.eqns[1].primitive == qml.PauliX._primitive + assert jaxpr.jaxpr.outvars[0] == jaxpr.eqns[1].outvars[0] + def test_ctrl_higher_order_primitive(self, lazy_adjoint): """Test that ctrl higher order primitives are transformed correctly.""" + def ctrl_fn(y): + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(y, 0) + + @CancelInversesInterpreter() + def f(x): + qml.RX(x, 0) + qml.ctrl(ctrl_fn, [2, 3])(x) + qml.RY(x, 1) + + jaxpr = jax.make_jaxpr(f)(1.5) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == ctrl_transform_prim + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["jaxpr"] + assert len(inner_jaxpr.eqns) == 1 + assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive + def test_adjoint_higher_order_primitive(self, lazy_adjoint): """Test that adjoint higher order primitives are transformed correctly.""" + def adjoint_fn(y): + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(y, 0) + + @CancelInversesInterpreter() + def f(x): + qml.RX(x, 0) + qml.adjoint(adjoint_fn, lazy=lazy_adjoint)(x) + qml.RY(x, 1) + + jaxpr = jax.make_jaxpr(f)(1.5) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == adjoint_transform_prim + assert jaxpr.eqns[1].params["lazy"] == lazy_adjoint + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["jaxpr"] + assert len(inner_jaxpr.eqns) == 1 + assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive + def test_cond_higher_order_primitive(self, lazy_adjoint): """Test that cond higher order primitives are transformed correctly.""" @@ -148,6 +208,34 @@ def test_while_loop_higher_order_primitive(self, lazy_adjoint): def test_qnode_higher_order_primitive(self, lazy_adjoint): """Test that qnode higher order primitives are transformed correctly.""" + dev = qml.device("default.qubit", wires=4) + + @qml.qnode(dev) + def circuit(y): + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(y, 0) + return qml.expval(qml.PauliZ(0)) + + @CancelInversesInterpreter() + def f(x): + qml.RX(x, 0) + circuit(x) + qml.RY(x, 1) + + jaxpr = jax.make_jaxpr(f)(1.5) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == qnode_prim + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["qfunc_jaxpr"] + assert len(inner_jaxpr.eqns) == 3 + assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive + assert inner_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert inner_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive @pytest.mark.parametrize("grad_fn", [qml.grad, qml.jacobian]) def test_grad_and_jac_higher_order_primitives(self, grad_fn, lazy_adjoint): From 34cbc97e10482b391b551bfb7eece1c827d486fa Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 12 Dec 2024 18:02:34 -0500 Subject: [PATCH 05/10] Finish adding tests --- .../test_capture_cancel_inverses.py | 143 +++++++++++++++++- 1 file changed, 140 insertions(+), 3 deletions(-) diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index 4826a197f4f..9f977077b2c 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Unit tests for the ``DecomposeInterpreter`` class""" -from functools import partial -# pylint:disable=wrong-import-position +# pylint:disable=wrong-import-position,protected-access import pytest import pennylane as qml @@ -130,7 +129,7 @@ def f(): assert jaxpr.eqns[1].primitive == qml.S._primitive assert jaxpr.eqns[2].primitive == qml.ops.Adjoint._primitive - def test_returned_op_is_not_cancelled(self, lazy_adjoint): + def test_returned_op_is_not_cancelled(self, lazy_adjoint): # pylint: disable=unused-argument """Test that ops that are returned by the function being transformed are not cancelled.""" @CancelInversesInterpreter() @@ -200,12 +199,120 @@ def f(x): def test_cond_higher_order_primitive(self, lazy_adjoint): """Test that cond higher order primitives are transformed correctly.""" + @CancelInversesInterpreter() + def f(x): + qml.RX(x, 0) + + @qml.cond(x > 2) + def cond_fn(): + qml.Hadamard(0) + qml.Hadamard(0) + return qml.S(0) + + @cond_fn.else_if(x > 1) + def _(): + qml.S(0) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + return qml.T(0) + + @cond_fn.otherwise + def _(): + qml.adjoint(qml.T(0), lazy=lazy_adjoint) + qml.T(0) + return qml.Hadamard(0) + + return cond_fn() + + jaxpr = jax.make_jaxpr(f)(1.5) + # 2 primitives for true and elif branch conditions of the conditional + assert len(jaxpr.eqns) == 4 + assert jaxpr.eqns[2].primitive == qml.RX._primitive + assert jaxpr.eqns[3].primitive == cond_prim + + # true branch + branch = jaxpr.eqns[3].params["jaxpr_branches"][0] + assert len(branch.eqns) == 1 + assert branch.eqns[0].primitive == qml.S._primitive + assert branch.outvars[0] == branch.eqns[0].outvars[0] + + # elif branch + branch = jaxpr.eqns[3].params["jaxpr_branches"][1] + assert len(branch.eqns) == 1 + assert branch.eqns[0].primitive == qml.T._primitive + assert branch.outvars[0] == branch.eqns[0].outvars[0] + + # true branch + branch = jaxpr.eqns[3].params["jaxpr_branches"][2] + assert len(branch.eqns) == 1 + assert branch.eqns[0].primitive == qml.Hadamard._primitive + assert branch.outvars[0] == branch.eqns[0].outvars[0] + def test_for_loop_higher_order_primitive(self, lazy_adjoint): """Test that for_loop higher order primitives are transformed correctly.""" + @CancelInversesInterpreter() + def f(x, n): + qml.RX(x, 0) + + @qml.for_loop(n) + def loop_fn(i): # pylint: disable=unused-argument + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(x, 0) + + loop_fn() + qml.RY(x, 1) + + jaxpr = jax.make_jaxpr(f)(1.5, 4) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == for_loop_prim + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["jaxpr_body_fn"] + assert len(inner_jaxpr.eqns) == 1 + assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive + def test_while_loop_higher_order_primitive(self, lazy_adjoint): """Test that while_loop higher order primitives are transformed correctly.""" + @CancelInversesInterpreter() + def f(x, n): + qml.RX(x, 0) + + @qml.while_loop(lambda i: i < 2 * n) + def loop_fn(i): + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(x, 0) + return i + 1 + + loop_fn(x) + qml.RY(x, 1) + + jaxpr = jax.make_jaxpr(f)(1.5, 4) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == while_loop_prim + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["jaxpr_body_fn"] + assert len(inner_jaxpr.eqns) == 2 + # The i + 1 primitive and the RX may get reordered, but the outcome will not be impacted + assert any(eqn.primitive == qml.RX._primitive for eqn in inner_jaxpr.eqns) + + # Check that the output of the i + 1 is returned + if inner_jaxpr.eqns[0].primitive == qml.RX._primitive: + add_eqn = inner_jaxpr.eqns[1] + else: + add_eqn = inner_jaxpr.eqns[0] + assert add_eqn.primitive.name == "add" + assert inner_jaxpr.outvars[0] == add_eqn.outvars[0] + def test_qnode_higher_order_primitive(self, lazy_adjoint): """Test that qnode higher order primitives are transformed correctly.""" dev = qml.device("default.qubit", wires=4) @@ -240,3 +347,33 @@ def f(x): @pytest.mark.parametrize("grad_fn", [qml.grad, qml.jacobian]) def test_grad_and_jac_higher_order_primitives(self, grad_fn, lazy_adjoint): """Test that grad and jacobian higher order primitives are transformed correctly.""" + dev = qml.device("default.qubit", wires=4) + + @qml.qnode(dev) + def circuit(y): + qml.S(0) + qml.Hadamard(1) + qml.Hadamard(1) + qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.RX(y, 0) + return qml.expval(qml.PauliZ(0)) + + @CancelInversesInterpreter() + def f(x): + qml.RX(x, 0) + out = grad_fn(circuit)(x) + qml.RY(x, 1) + return out + + jaxpr = jax.make_jaxpr(f)(1.5) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == grad_prim if grad_fn == qml.grad else jacobian_prim + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["jaxpr"] + assert len(inner_jaxpr.eqns) == 1 + qfunc_jaxpr = inner_jaxpr.eqns[0].params["qfunc_jaxpr"] + assert qfunc_jaxpr.eqns[0].primitive == qml.RX._primitive + assert qfunc_jaxpr.eqns[1].primitive == qml.PauliZ._primitive + assert qfunc_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive From 9d2d8907fad4e1f851e07c1cbd6b87b57c813dcb Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 13 Dec 2024 13:32:31 -0500 Subject: [PATCH 06/10] Linting, adding coverage --- .../transforms/capture_cancel_inverses.py | 11 +- .../test_capture_cancel_inverses.py | 107 +++++++++++++----- 2 files changed, 84 insertions(+), 34 deletions(-) diff --git a/pennylane/capture/transforms/capture_cancel_inverses.py b/pennylane/capture/transforms/capture_cancel_inverses.py index 90766317ddc..a88699ef7bc 100644 --- a/pennylane/capture/transforms/capture_cancel_inverses.py +++ b/pennylane/capture/transforms/capture_cancel_inverses.py @@ -129,6 +129,7 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: list[TensorLike]: the results of the execution. """ + # pylint: disable=too-many-branches,attribute-defined-outside-init self._env = {} self.setup() @@ -146,16 +147,20 @@ def eval(self, jaxpr: "jax.core.Jaxpr", consts: list, *args) -> list: self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = custom_handler(self, *invals, **eqn.params) - elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractOperator): + elif len(eqn.outvars) > 0 and isinstance( + eqn.outvars[0].aval, qml.capture.AbstractOperator + ): outvals = self.interpret_operation_eqn(eqn) - elif isinstance(eqn.outvars[0].aval, qml.capture.AbstractMeasurement): + elif len(eqn.outvars) > 0 and isinstance( + eqn.outvars[0].aval, qml.capture.AbstractMeasurement + ): self.interpret_all_previous_ops() outvals = self.interpret_measurement_eqn(eqn) else: # Transform primitives don't have custom handlers, so we check for them here # to purge the stored ops in self.previous_ops if eqn.primitive.name.endswith("_transform"): - self.interpret_all_previous_ops + self.interpret_all_previous_ops() invals = [self.read(invar) for invar in eqn.invars] outvals = eqn.primitive.bind(*invals, **eqn.params) diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index 9f977077b2c..8a6e076a76b 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -35,12 +35,11 @@ pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")] -@pytest.mark.parametrize("lazy_adjoint", [True, False]) class TestCancelInversesInterpreter: """Unit tests for the CancelInversesInterpreter for canceling adjacent inverse operations in plxpr.""" - def test_cancel_inverses_simple(self, lazy_adjoint): + def test_cancel_inverses_simple(self): """Test that inverse ops in a simple circuit are cancelled.""" @CancelInversesInterpreter() @@ -48,8 +47,8 @@ def f(): qml.X(0) qml.X(0) qml.S(1) - qml.adjoint(qml.S(1), lazy=lazy_adjoint) - qml.adjoint(qml.T(2), lazy=lazy_adjoint) + qml.adjoint(qml.S(1)) + qml.adjoint(qml.T(2)) qml.Hadamard(1) # Applied qml.T(2) qml.Z(0) # Applied @@ -77,18 +76,18 @@ def f(): actual_primitives_fifth_sixth = {jaxpr.eqns[4].primitive, jaxpr.eqns[5].primitive} assert actual_primitives_fifth_sixth == expected_primitives_fifth_sixth - def test_cancel_inverses_true_inverses(self, lazy_adjoint): + def test_cancel_inverses_true_inverses(self): """Test that operations that are inverses with the same wires are cancelled.""" @CancelInversesInterpreter() def f(): qml.CRX(1.5, [0, 1]) - qml.adjoint(qml.CRX(1.5, [0, 1]), lazy=lazy_adjoint) + qml.adjoint(qml.CRX(1.5, [0, 1])) jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 0 - def test_cancel_inverses_symmetric_wires(self, lazy_adjoint): # pylint: disable=unused-argument + def test_cancel_inverses_symmetric_wires(self): """Test that operations that are inverses regardless of wire order are cancelled.""" @CancelInversesInterpreter() @@ -99,9 +98,7 @@ def f(): jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 0 - def test_cancel_inverses_symmetric_control_wires( - self, lazy_adjoint - ): # pylint: disable=unused-argument + def test_cancel_inverses_symmetric_control_wires(self): """Test that operations that are inverses regardless of control_wire order are cancelled.""" @CancelInversesInterpreter() @@ -112,16 +109,16 @@ def f(): jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 0 - def test_cancel_inverese_nested_ops_on_same_wires(self, lazy_adjoint): + def test_cancel_inverese_nested_ops_on_same_wires(self): """Test that only the innermost adjacent adjoint ops are cancelled when multiple cancellable operators are present.""" @CancelInversesInterpreter() def f(): qml.S(0) - qml.adjoint(qml.T(0), lazy=lazy_adjoint) + qml.adjoint(qml.T(0)) qml.T(0) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) jaxpr = jax.make_jaxpr(f)() assert len(jaxpr.eqns) == 3 @@ -129,7 +126,7 @@ def f(): assert jaxpr.eqns[1].primitive == qml.S._primitive assert jaxpr.eqns[2].primitive == qml.ops.Adjoint._primitive - def test_returned_op_is_not_cancelled(self, lazy_adjoint): # pylint: disable=unused-argument + def test_returned_op_is_not_cancelled(self): """Test that ops that are returned by the function being transformed are not cancelled.""" @CancelInversesInterpreter() @@ -143,14 +140,61 @@ def f(): assert jaxpr.eqns[1].primitive == qml.PauliX._primitive assert jaxpr.jaxpr.outvars[0] == jaxpr.eqns[1].outvars[0] - def test_ctrl_higher_order_primitive(self, lazy_adjoint): + def test_no_wire_ops_not_cancelled(self): + """Test that inverse operations with no wires do not get cancelled.""" + + @CancelInversesInterpreter() + def f(): + qml.Identity() + qml.Identity() + return qml.expval(qml.Z(0)) + + jaxpr = jax.make_jaxpr(f)() + assert len(jaxpr.eqns) == 4 + assert jaxpr.eqns[0].primitive == qml.Identity._primitive + assert jaxpr.eqns[1].primitive == qml.Identity._primitive + assert jaxpr.eqns[2].primitive == qml.PauliZ._primitive + assert jaxpr.eqns[3].primitive == qml.measurements.ExpectationMP._obs_primitive + + def test_transform_higher_order_primitive(self): + """Test that the inner_jaxpr of transform primitives is not transformed.""" + + @qml.transform + def dummy_transform(tape): + """Dummy transform""" + return [tape], lambda res: res[0] + + @CancelInversesInterpreter() + def f(x): + @dummy_transform + def g(): + qml.S(0) + qml.adjoint(qml.S(0)) + + qml.RX(x, 0) + g() + qml.RY(x, 0) + + jaxpr = jax.make_jaxpr(f)(1.5) + assert len(jaxpr.eqns) == 3 + assert jaxpr.eqns[0].primitive == qml.RX._primitive + assert jaxpr.eqns[1].primitive == dummy_transform._primitive + assert jaxpr.eqns[2].primitive == qml.RY._primitive + + inner_jaxpr = jaxpr.eqns[1].params["inner_jaxpr"] + assert len(inner_jaxpr.eqns) == 3 + assert inner_jaxpr.eqns[0].primitive == qml.S._primitive + assert inner_jaxpr.eqns[1].primitive == qml.S._primitive + assert inner_jaxpr.eqns[2].primitive == qml.ops.Adjoint._primitive + + def test_ctrl_higher_order_primitive(self): """Test that ctrl higher order primitives are transformed correctly.""" def ctrl_fn(y): qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(y, 0) @CancelInversesInterpreter() @@ -169,34 +213,35 @@ def f(x): assert len(inner_jaxpr.eqns) == 1 assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive - def test_adjoint_higher_order_primitive(self, lazy_adjoint): + @pytest.mark.parametrize("lazy", [True, False]) + def test_adjoint_higher_order_primitive(self, lazy): """Test that adjoint higher order primitives are transformed correctly.""" def adjoint_fn(y): qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(y, 0) @CancelInversesInterpreter() def f(x): qml.RX(x, 0) - qml.adjoint(adjoint_fn, lazy=lazy_adjoint)(x) + qml.adjoint(adjoint_fn, lazy=lazy)(x) qml.RY(x, 1) jaxpr = jax.make_jaxpr(f)(1.5) assert len(jaxpr.eqns) == 3 assert jaxpr.eqns[0].primitive == qml.RX._primitive assert jaxpr.eqns[1].primitive == adjoint_transform_prim - assert jaxpr.eqns[1].params["lazy"] == lazy_adjoint + assert jaxpr.eqns[1].params["lazy"] == lazy assert jaxpr.eqns[2].primitive == qml.RY._primitive inner_jaxpr = jaxpr.eqns[1].params["jaxpr"] assert len(inner_jaxpr.eqns) == 1 assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive - def test_cond_higher_order_primitive(self, lazy_adjoint): + def test_cond_higher_order_primitive(self): """Test that cond higher order primitives are transformed correctly.""" @CancelInversesInterpreter() @@ -212,12 +257,12 @@ def cond_fn(): @cond_fn.else_if(x > 1) def _(): qml.S(0) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) return qml.T(0) @cond_fn.otherwise def _(): - qml.adjoint(qml.T(0), lazy=lazy_adjoint) + qml.adjoint(qml.T(0)) qml.T(0) return qml.Hadamard(0) @@ -247,7 +292,7 @@ def _(): assert branch.eqns[0].primitive == qml.Hadamard._primitive assert branch.outvars[0] == branch.eqns[0].outvars[0] - def test_for_loop_higher_order_primitive(self, lazy_adjoint): + def test_for_loop_higher_order_primitive(self): """Test that for_loop higher order primitives are transformed correctly.""" @CancelInversesInterpreter() @@ -259,7 +304,7 @@ def loop_fn(i): # pylint: disable=unused-argument qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(x, 0) loop_fn() @@ -275,7 +320,7 @@ def loop_fn(i): # pylint: disable=unused-argument assert len(inner_jaxpr.eqns) == 1 assert inner_jaxpr.eqns[0].primitive == qml.RX._primitive - def test_while_loop_higher_order_primitive(self, lazy_adjoint): + def test_while_loop_higher_order_primitive(self): """Test that while_loop higher order primitives are transformed correctly.""" @CancelInversesInterpreter() @@ -287,7 +332,7 @@ def loop_fn(i): qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(x, 0) return i + 1 @@ -313,7 +358,7 @@ def loop_fn(i): assert add_eqn.primitive.name == "add" assert inner_jaxpr.outvars[0] == add_eqn.outvars[0] - def test_qnode_higher_order_primitive(self, lazy_adjoint): + def test_qnode_higher_order_primitive(self): """Test that qnode higher order primitives are transformed correctly.""" dev = qml.device("default.qubit", wires=4) @@ -322,7 +367,7 @@ def circuit(y): qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(y, 0) return qml.expval(qml.PauliZ(0)) @@ -345,7 +390,7 @@ def f(x): assert inner_jaxpr.eqns[2].primitive == qml.measurements.ExpectationMP._obs_primitive @pytest.mark.parametrize("grad_fn", [qml.grad, qml.jacobian]) - def test_grad_and_jac_higher_order_primitives(self, grad_fn, lazy_adjoint): + def test_grad_and_jac_higher_order_primitives(self, grad_fn): """Test that grad and jacobian higher order primitives are transformed correctly.""" dev = qml.device("default.qubit", wires=4) @@ -354,7 +399,7 @@ def circuit(y): qml.S(0) qml.Hadamard(1) qml.Hadamard(1) - qml.adjoint(qml.S(0), lazy=lazy_adjoint) + qml.adjoint(qml.S(0)) qml.RX(y, 0) return qml.expval(qml.PauliZ(0)) From 99de54215cf276688c90d1c83bb7d8f9f93d362b Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 13 Dec 2024 13:36:40 -0500 Subject: [PATCH 07/10] Fix init --- pennylane/capture/transforms/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pennylane/capture/transforms/__init__.py b/pennylane/capture/transforms/__init__.py index 651ff7da8e2..bc69da61763 100644 --- a/pennylane/capture/transforms/__init__.py +++ b/pennylane/capture/transforms/__init__.py @@ -15,11 +15,9 @@ Public/internal API for the pennylane.capture.transforms module. """ from .capture_cancel_inverses import CancelInversesInterpreter -from .capture_decompose import DecomposeInterpreter from .map_wires import MapWiresInterpreter __all__ = ( "CancelInversesInterpreter", - "DecomposeInterpreter", "MapWiresInterpreter", -) \ No newline at end of file +) From b71337c7f2f45e4cc214a1d5b81208dfb5575053 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 13 Dec 2024 17:49:28 -0500 Subject: [PATCH 08/10] Update tests/capture/transforms/test_capture_cancel_inverses.py Co-authored-by: Pietropaolo Frisoni --- tests/capture/transforms/test_capture_cancel_inverses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/capture/transforms/test_capture_cancel_inverses.py b/tests/capture/transforms/test_capture_cancel_inverses.py index 8a6e076a76b..9f18a8f9453 100644 --- a/tests/capture/transforms/test_capture_cancel_inverses.py +++ b/tests/capture/transforms/test_capture_cancel_inverses.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the ``DecomposeInterpreter`` class""" +"""Unit tests for the ``CancelInversesInterpreter`` class""" # pylint:disable=wrong-import-position,protected-access import pytest From 3e3514a3c8b08051f488edeee1e8468cf138f9fa Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 17 Dec 2024 10:51:02 -0500 Subject: [PATCH 09/10] [skip ci] Skip CI From 4e75f9531af3c17ba5975c4ecd87b63cf31cbbf0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 17 Dec 2024 14:48:33 -0500 Subject: [PATCH 10/10] Update docstring --- pennylane/capture/transforms/capture_cancel_inverses.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pennylane/capture/transforms/capture_cancel_inverses.py b/pennylane/capture/transforms/capture_cancel_inverses.py index a88699ef7bc..8a1f35ff8af 100644 --- a/pennylane/capture/transforms/capture_cancel_inverses.py +++ b/pennylane/capture/transforms/capture_cancel_inverses.py @@ -106,9 +106,8 @@ def interpret_operation(self, op: qml.operation.Operator): return res def interpret_all_previous_ops(self) -> None: - """Interpret all ops in ``previous_ops``. This is done whenever any - operators that haven't been interpreted that are saved to be cancelled - no longer need to be saved.""" + """Interpret all operators in ``previous_ops``. This is done when any previously + uninterpreted operators, saved for cancellation, no longer need to be stored.""" ops_remaining = set(self.previous_ops.values()) for op in ops_remaining: super().interpret_operation(op)