diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index fd8797d084a..d22806f2b7a 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -507,6 +507,9 @@ same information.
Bug fixes 🐛
+* `qml.ControlledQubitUnitary` has consistent behaviour with program capture enabled.
+ [(#6719)](https://github.com/PennyLaneAI/pennylane/pull/6719)
+
* The `Wires` object throws a `TypeError` if `wires=None`.
[(#6713)](https://github.com/PennyLaneAI/pennylane/pull/6713)
[(#6720)](https://github.com/PennyLaneAI/pennylane/pull/6720)
diff --git a/pennylane/ops/op_math/controlled_ops.py b/pennylane/ops/op_math/controlled_ops.py
index 93370bcb37d..7f3e2258f64 100644
--- a/pennylane/ops/op_math/controlled_ops.py
+++ b/pennylane/ops/op_math/controlled_ops.py
@@ -25,6 +25,7 @@
import pennylane as qml
from pennylane.operation import AnyWires, Wires
from pennylane.ops.qubit.parametric_ops_single_qubit import stack_last
+from pennylane.wires import WiresLike
from .controlled import ControlledOp
from .controlled_decompositions import decompose_mcx
@@ -119,22 +120,61 @@ def _unflatten(cls, data, metadata):
data[0], control_wires=metadata[0], control_values=metadata[1], work_wires=metadata[2]
)
+ # pylint: disable=arguments-differ, too-many-arguments, unused-argument, too-many-positional-arguments
+ @classmethod
+ def _primitive_bind_call(
+ cls,
+ base,
+ control_wires: WiresLike,
+ wires: WiresLike = (),
+ control_values=None,
+ unitary_check=False,
+ work_wires: WiresLike = (),
+ ):
+ wires = Wires(() if wires is None else wires)
+ work_wires = Wires(() if work_wires is None else work_wires)
+
+ if hasattr(base, "wires") and len(wires) != 0:
+ warnings.warn(
+ "base operator already has wires; values specified through wires kwarg will be ignored."
+ )
+ wires = Wires(())
+
+ all_wires = control_wires + wires
+ return cls._primitive.bind(
+ base, control_wires=all_wires, control_values=control_values, work_wires=work_wires
+ )
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
base,
- control_wires,
- wires=None,
+ control_wires: WiresLike,
+ wires: WiresLike = (),
control_values=None,
unitary_check=False,
- work_wires=None,
+ work_wires: WiresLike = (),
):
- if getattr(base, "wires", False) and wires is not None:
+ wires = Wires(() if wires is None else wires)
+ work_wires = Wires(() if work_wires is None else work_wires)
+ control_wires = Wires(control_wires)
+
+ if hasattr(base, "wires") and len(wires) != 0:
warnings.warn(
"base operator already has wires; values specified through wires kwarg will be ignored."
)
+ wires = Wires(())
if isinstance(base, Iterable):
+ if len(wires) == 0:
+ if len(control_wires) > 1:
+ num_base_wires = int(qml.math.log2(qml.math.shape(base)[-1]))
+ wires = control_wires[-num_base_wires:]
+ control_wires = control_wires[:-num_base_wires]
+ else:
+ raise TypeError(
+ "Must specify a set of wires. None is not a valid `wires` label."
+ )
# We use type.__call__ instead of calling the class directly so that we don't bind the
# operator primitive when new program capture is enabled
base = type.__call__(qml.QubitUnitary, base, wires=wires, unitary_check=unitary_check)
diff --git a/tests/ops/op_math/test_controlled_ops.py b/tests/ops/op_math/test_controlled_ops.py
index 7a002cccda4..4ae0268f678 100644
--- a/tests/ops/op_math/test_controlled_ops.py
+++ b/tests/ops/op_math/test_controlled_ops.py
@@ -49,9 +49,63 @@
X_broadcasted = np.array([X] * 3)
+# pylint: disable=too-many-public-methods
class TestControlledQubitUnitary:
"""Tests specific to the ControlledQubitUnitary operation"""
+ def test_wires_is_none(self):
+ """Test that an error is raised if the user provides no target wires for an iterable base operator"""
+ base_op = [[0, 1], [1, 0]]
+
+ with pytest.raises(TypeError, match="Must specify a set of wires"):
+ qml.ControlledQubitUnitary(base_op, control_wires=1, wires=None)
+
+ @pytest.mark.jax
+ @pytest.mark.usefixtures("enable_disable_plxpr")
+ def test_wires_specified_twice_with_capture(self):
+ """Test that a UserWarning is raised for providing redundant wires with capture enabled"""
+ base = qml.QubitUnitary(X, wires=0)
+ with pytest.warns(
+ UserWarning,
+ match="base operator already has wires; values specified through wires kwarg will be ignored.",
+ ):
+ qml.ControlledQubitUnitary(base, control_wires=[1, 2], wires=3)
+
+ @pytest.mark.jax
+ @pytest.mark.usefixtures("enable_disable_plxpr")
+ @pytest.mark.parametrize(
+ "control_wires, wires",
+ [(0, 1), ([0, 1], [2])],
+ )
+ def test_consistency_with_capture(self, control_wires, wires):
+ """Test that the operator wires are as expected with capture enabled"""
+ base_op = [[0, 1], [1, 0]]
+
+ op_kwarg = qml.ControlledQubitUnitary(base_op, control_wires=control_wires, wires=wires)
+ assert op_kwarg.base.wires == Wires(wires)
+ assert op_kwarg.control_wires == Wires(control_wires)
+ op = qml.ControlledQubitUnitary(base_op, control_wires, wires)
+ assert op.base.wires == Wires(wires)
+ assert op.control_wires == Wires(control_wires)
+
+ @pytest.mark.jax
+ @pytest.mark.usefixtures("enable_disable_plxpr")
+ def test_pairwise_consistency_with_capture(self):
+ """Test that both combinations of control and target wires lead to the same operator"""
+ base_op = [[0, 1], [1, 0]]
+
+ control_wires_1, wires_1 = [0, 1], [2]
+ op_1 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_1, wires=wires_1)
+
+ assert op_1.base.wires == Wires(2)
+ assert op_1.control_wires == Wires([0, 1])
+
+ control_wires_2, wires_2 = [0, 1, 2], ()
+ op_2 = qml.ControlledQubitUnitary(base_op, control_wires=control_wires_2, wires=wires_2)
+
+ assert op_2.base.wires == Wires(2)
+ assert op_2.control_wires == Wires([0, 1])
+
def test_initialization_from_matrix_and_operator(self):
base_op = QubitUnitary(X, wires=1)
@@ -88,7 +142,7 @@ def test_wrong_shape(self):
with pytest.raises(ValueError, match=r"Input unitary must be of shape \(2, 2\)"):
qml.ControlledQubitUnitary(np.eye(4), control_wires=[0, 1], wires=2).matrix()
- @pytest.mark.parametrize("target_wire", range(3))
+ @pytest.mark.parametrize("target_wire", list(range(3)))
def test_toffoli(self, target_wire):
"""Test if ControlledQubitUnitary acts like a Toffoli gate when the input unitary is a
single-qubit X. This test allows the target wire to be any of the three wires."""