diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index fd8797d084a..d22806f2b7a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -507,6 +507,9 @@ same information.

Bug fixes 🐛

+* `qml.ControlledQubitUnitary` has consistent behaviour with program capture enabled. + [(#6719)](https://github.com/PennyLaneAI/pennylane/pull/6719) + * The `Wires` object throws a `TypeError` if `wires=None`. [(#6713)](https://github.com/PennyLaneAI/pennylane/pull/6713) [(#6720)](https://github.com/PennyLaneAI/pennylane/pull/6720) diff --git a/pennylane/ops/op_math/controlled_ops.py b/pennylane/ops/op_math/controlled_ops.py index 93370bcb37d..7f3e2258f64 100644 --- a/pennylane/ops/op_math/controlled_ops.py +++ b/pennylane/ops/op_math/controlled_ops.py @@ -25,6 +25,7 @@ import pennylane as qml from pennylane.operation import AnyWires, Wires from pennylane.ops.qubit.parametric_ops_single_qubit import stack_last +from pennylane.wires import WiresLike from .controlled import ControlledOp from .controlled_decompositions import decompose_mcx @@ -119,22 +120,61 @@ def _unflatten(cls, data, metadata): data[0], control_wires=metadata[0], control_values=metadata[1], work_wires=metadata[2] ) + # pylint: disable=arguments-differ, too-many-arguments, unused-argument, too-many-positional-arguments + @classmethod + def _primitive_bind_call( + cls, + base, + control_wires: WiresLike, + wires: WiresLike = (), + control_values=None, + unitary_check=False, + work_wires: WiresLike = (), + ): + wires = Wires(() if wires is None else wires) + work_wires = Wires(() if work_wires is None else work_wires) + + if hasattr(base, "wires") and len(wires) != 0: + warnings.warn( + "base operator already has wires; values specified through wires kwarg will be ignored." + ) + wires = Wires(()) + + all_wires = control_wires + wires + return cls._primitive.bind( + base, control_wires=all_wires, control_values=control_values, work_wires=work_wires + ) + # pylint: disable=too-many-arguments,too-many-positional-arguments def __init__( self, base, - control_wires, - wires=None, + control_wires: WiresLike, + wires: WiresLike = (), control_values=None, unitary_check=False, - work_wires=None, + work_wires: WiresLike = (), ): - if getattr(base, "wires", False) and wires is not None: + wires = Wires(() if wires is None else wires) + work_wires = Wires(() if work_wires is None else work_wires) + control_wires = Wires(control_wires) + + if hasattr(base, "wires") and len(wires) != 0: warnings.warn( "base operator already has wires; values specified through wires kwarg will be ignored." ) + wires = Wires(()) if isinstance(base, Iterable): + if len(wires) == 0: + if len(control_wires) > 1: + num_base_wires = int(qml.math.log2(qml.math.shape(base)[-1])) + wires = control_wires[-num_base_wires:] + control_wires = control_wires[:-num_base_wires] + else: + raise TypeError( + "Must specify a set of wires. None is not a valid `wires` label." + ) # We use type.__call__ instead of calling the class directly so that we don't bind the # operator primitive when new program capture is enabled base = type.__call__(qml.QubitUnitary, base, wires=wires, unitary_check=unitary_check) diff --git a/tests/ops/op_math/test_controlled_ops.py b/tests/ops/op_math/test_controlled_ops.py index 7a002cccda4..4ae0268f678 100644 --- a/tests/ops/op_math/test_controlled_ops.py +++ b/tests/ops/op_math/test_controlled_ops.py @@ -49,9 +49,63 @@ X_broadcasted = np.array([X] * 3) +# pylint: disable=too-many-public-methods class TestControlledQubitUnitary: """Tests specific to the ControlledQubitUnitary operation""" + def test_wires_is_none(self): + """Test that an error is raised if the user provides no target wires for an iterable base operator""" + base_op = [[0, 1], [1, 0]] + + with pytest.raises(TypeError, match="Must specify a set of wires"): + qml.ControlledQubitUnitary(base_op, control_wires=1, wires=None) + + @pytest.mark.jax + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_wires_specified_twice_with_capture(self): + """Test that a UserWarning is raised for providing redundant wires with capture enabled""" + base = qml.QubitUnitary(X, wires=0) + with pytest.warns( + UserWarning, + match="base operator already has wires; values specified through wires kwarg will be ignored.", + ): + qml.ControlledQubitUnitary(base, control_wires=[1, 2], wires=3) + + @pytest.mark.jax + @pytest.mark.usefixtures("enable_disable_plxpr") + @pytest.mark.parametrize( + "control_wires, wires", + [(0, 1), ([0, 1], [2])], + ) + def test_consistency_with_capture(self, control_wires, wires): + """Test that the operator wires are as expected with capture enabled""" + base_op = [[0, 1], [1, 0]] + + op_kwarg = qml.ControlledQubitUnitary(base_op, control_wires=control_wires, wires=wires) + assert op_kwarg.base.wires == Wires(wires) + assert op_kwarg.control_wires == Wires(control_wires) + op = qml.ControlledQubitUnitary(base_op, control_wires, wires) + assert op.base.wires == Wires(wires) + assert op.control_wires == Wires(control_wires) + + @pytest.mark.jax + @pytest.mark.usefixtures("enable_disable_plxpr") + def test_pairwise_consistency_with_capture(self): + """Test that both combinations of control and target wires lead to the same operator""" + base_op = [[0, 1], [1, 0]] + + control_wires_1, wires_1 = [0, 1], [2] + op_1 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_1, wires=wires_1) + + assert op_1.base.wires == Wires(2) + assert op_1.control_wires == Wires([0, 1]) + + control_wires_2, wires_2 = [0, 1, 2], () + op_2 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_2, wires=wires_2) + + assert op_2.base.wires == Wires(2) + assert op_2.control_wires == Wires([0, 1]) + def test_initialization_from_matrix_and_operator(self): base_op = QubitUnitary(X, wires=1) @@ -88,7 +142,7 @@ def test_wrong_shape(self): with pytest.raises(ValueError, match=r"Input unitary must be of shape \(2, 2\)"): qml.ControlledQubitUnitary(np.eye(4), control_wires=[0, 1], wires=2).matrix() - @pytest.mark.parametrize("target_wire", range(3)) + @pytest.mark.parametrize("target_wire", list(range(3))) def test_toffoli(self, target_wire): """Test if ControlledQubitUnitary acts like a Toffoli gate when the input unitary is a single-qubit X. This test allows the target wire to be any of the three wires."""