Skip to content

Commit

Permalink
fix: further clean-up in controlled.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrijapau committed Dec 19, 2024
1 parent 5fe2e84 commit b6e306a
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,20 @@ def ctrl(
op: Operator,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = None,
work_wires: Optional[Any] = (),
) -> Operator: ...


@overload
def ctrl(
op: Callable,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = None,
work_wires: Optional[Any] = (),
) -> Callable: ...
def ctrl(op, control: Any, control_values=None, work_wires=None):


def ctrl(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()):
"""Create a method that applies a controlled version of the provided op.
:func:`~.qjit` compatible.
Expand Down Expand Up @@ -148,7 +152,7 @@ def cond_fn():
return create_controlled_op(op, control, control_values=control_values, work_wires=work_wires)


def create_controlled_op(op, control, control_values=None, work_wires: WiresLike = ()):
def create_controlled_op(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()):
"""Default ``qml.ctrl`` implementation, allowing other implementations to call it when needed."""

control = Wires(control)
Expand Down Expand Up @@ -201,7 +205,7 @@ def create_controlled_op(op, control, control_values=None, work_wires: WiresLike
return _ctrl_transform(op, control, control_values, work_wires)


def _ctrl_transform(op, control, control_values, work_wires):
def _ctrl_transform(op, control: WiresLike, control_values, work_wires: WiresLike):
@wraps(op)
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)
Expand Down Expand Up @@ -261,7 +265,9 @@ def _(*_, **__):
return ctrl_prim


def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable:
def _capture_ctrl_transform(
qfunc: Callable, control: WiresLike, control_values, work_wires: WiresLike
) -> Callable:
"""Capture compatible way of performing an ctrl transform."""
# note that this logic is tested in `tests/capture/test_nested_plxpr.py`
import jax # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -341,7 +347,7 @@ def _try_wrap_in_custom_ctrl_op(


def _handle_pauli_x_based_controlled_ops(
op, control: WiresLike, control_values, work_wires: WiresLike = None
op, control: WiresLike, control_values, work_wires: WiresLike = ()
):
"""Handles PauliX-based controlled operations."""

Expand Down Expand Up @@ -486,7 +492,12 @@ def __new__(cls, base, *_, **__):
# pylint: disable=arguments-differ

Check notice on line 492 in pennylane/ops/op_math/controlled.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled.py#L492

Too many positional arguments (6/5) (too-many-positional-arguments)
@classmethod
def _primitive_bind_call(
cls, base, control_wires, control_values=None, work_wires=None, id=None
cls,
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = (),
id=None,
):
control_wires = Wires(control_wires)
return cls._primitive.bind(
Expand All @@ -499,7 +510,7 @@ def __init__(
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = None,
work_wires: WiresLike = (),
id=None,
):
control_wires = Wires(control_wires)
Expand Down Expand Up @@ -663,17 +674,17 @@ def _compute_matrix_from_base(self):

return qmlmath.block_diag([left_pad, base_matrix, right_pad])

def matrix(self, wire_order=None):
def matrix(self, wire_order: WiresLike = None):
if self.compute_matrix is not Operator.compute_matrix:
canonical_matrix = self.compute_matrix(*self.data)
else:
canonical_matrix = self._compute_matrix_from_base()

wire_order = wire_order or self.wires
wire_order = self.wires if wire_order is None else wire_order
return qml.math.expand_matrix(canonical_matrix, wires=self.wires, wire_order=wire_order)

# pylint: disable=arguments-differ
def sparse_matrix(self, wire_order=None, format="csr"):
def sparse_matrix(self, wire_order: WiresLike = None, format="csr"):
if wire_order is not None:
raise NotImplementedError("wire_order argument is not yet implemented.")

Expand Down Expand Up @@ -925,7 +936,14 @@ def __new__(cls, *_, **__):
return object.__new__(cls)

# pylint: disable=too-many-function-args
def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None):
def __init__(

Check notice on line 939 in pennylane/ops/op_math/controlled.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled.py#L939

Too many positional arguments (6/5) (too-many-positional-arguments)
self,
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = None,
id=None,
):
super().__init__(base, control_wires, control_values, work_wires, id)
# check the grad_recipe validity
if self.grad_recipe is None:
Expand Down Expand Up @@ -975,7 +993,7 @@ def parameter_frequencies(self):
if Controlled._primitive is not None: # pylint: disable=protected-access

@Controlled._primitive.def_impl # pylint: disable=protected-access
def _(base, *control_wires, control_values=None, work_wires=None, id=None):
def _(base, *control_wires, control_values=None, work_wires: WiresLike = None, id=None):
return type.__call__(
Controlled,
base,
Expand Down

0 comments on commit b6e306a

Please sign in to comment.