Skip to content

Commit

Permalink
feat: Update ControlledQubitUnitary class
Browse files Browse the repository at this point in the history
  • Loading branch information
andrijapau committed Dec 13, 2024
1 parent a0ecd9a commit ba59255
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
51 changes: 47 additions & 4 deletions pennylane/ops/op_math/controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pennylane as qml
from pennylane.operation import AnyWires, Wires
from pennylane.ops.qubit.parametric_ops_single_qubit import stack_last
from pennylane.wires import WiresLike

from .controlled import ControlledOp
from .controlled_decompositions import decompose_mcx
Expand Down Expand Up @@ -119,22 +120,64 @@ def _unflatten(cls, data, metadata):
data[0], control_wires=metadata[0], control_values=metadata[1], work_wires=metadata[2]
)

# pylint: disable=arguments-differ, too-many-arguments, unused-argument
@classmethod
def _primitive_bind_call(

Check notice on line 125 in pennylane/ops/op_math/controlled_ops.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled_ops.py#L125

Too many positional arguments (7/5) (too-many-positional-arguments)
cls,
base,
control_wires: WiresLike,
wires: WiresLike = (),
control_values=None,
unitary_check=False,
work_wires: WiresLike = (),
):
wires = () if wires is None else wires
work_wires = () if work_wires is None else work_wires
wires = Wires(wires)
work_wires = Wires(work_wires)

if getattr(base, "wires", False) and len(wires) != 0:
warnings.warn(
"base operator already has wires; values specified through wires kwarg will be ignored."
)
wires = Wires(())

all_wires = control_wires + wires
return cls._primitive.bind(
base, all_wires, control_values=control_values, work_wires=work_wires
)

# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
base,
control_wires,
wires=None,
control_wires: WiresLike,
wires: WiresLike = (),
control_values=None,
unitary_check=False,
work_wires=None,
work_wires: WiresLike = (),
):
if getattr(base, "wires", False) and wires is not None:
wires = () if wires is None else wires
work_wires = () if work_wires is None else work_wires
wires = Wires(wires)
work_wires = Wires(work_wires)

if getattr(base, "wires", False) and len(wires) != 0:
warnings.warn(
"base operator already has wires; values specified through wires kwarg will be ignored."
)
wires = Wires(())

if isinstance(base, Iterable):
if len(wires) == 0:
if len(Wires(control_wires)) > 1:
num_base_wires = int(qml.math.log2(qml.math.shape(base)[-1]))
wires = Wires(control_wires)[-num_base_wires:]
control_wires = Wires(control_wires)[:-num_base_wires]
else:
raise TypeError(
"Must specify a set of wires. None is not a valid `wires` label."
)
# We use type.__call__ instead of calling the class directly so that we don't bind the
# operator primitive when new program capture is enabled
base = type.__call__(qml.QubitUnitary, base, wires=wires, unitary_check=unitary_check)
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/op_math/test_controlled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_wrong_shape(self):
with pytest.raises(ValueError, match=r"Input unitary must be of shape \(2, 2\)"):
qml.ControlledQubitUnitary(np.eye(4), control_wires=[0, 1], wires=2).matrix()

@pytest.mark.parametrize("target_wire", range(3))
@pytest.mark.parametrize("target_wire", list(range(3)))
def test_toffoli(self, target_wire):
"""Test if ControlledQubitUnitary acts like a Toffoli gate when the input unitary is a
single-qubit X. This test allows the target wire to be any of the three wires."""
Expand Down

0 comments on commit ba59255

Please sign in to comment.