diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index 9997d86f1b1..a88f0eba301 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -42,16 +42,20 @@ def ctrl( op: Operator, control: Any, control_values: Optional[Sequence[bool]] = None, - work_wires: Optional[Any] = None, + work_wires: Optional[Any] = (), ) -> Operator: ... + + @overload def ctrl( op: Callable, control: Any, control_values: Optional[Sequence[bool]] = None, - work_wires: Optional[Any] = None, + work_wires: Optional[Any] = (), ) -> Callable: ... -def ctrl(op, control: Any, control_values=None, work_wires=None): + + +def ctrl(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()): """Create a method that applies a controlled version of the provided op. :func:`~.qjit` compatible. @@ -148,7 +152,7 @@ def cond_fn(): return create_controlled_op(op, control, control_values=control_values, work_wires=work_wires) -def create_controlled_op(op, control, control_values=None, work_wires: WiresLike = ()): +def create_controlled_op(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()): """Default ``qml.ctrl`` implementation, allowing other implementations to call it when needed.""" control = Wires(control) @@ -201,7 +205,7 @@ def create_controlled_op(op, control, control_values=None, work_wires: WiresLike return _ctrl_transform(op, control, control_values, work_wires) -def _ctrl_transform(op, control, control_values, work_wires): +def _ctrl_transform(op, control: WiresLike, control_values, work_wires: WiresLike): @wraps(op) def wrapper(*args, **kwargs): qscript = qml.tape.make_qscript(op)(*args, **kwargs) @@ -261,7 +265,9 @@ def _(*_, **__): return ctrl_prim -def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable: +def _capture_ctrl_transform( + qfunc: Callable, control: WiresLike, control_values, work_wires: WiresLike +) -> Callable: """Capture compatible way of performing an ctrl transform.""" # note that this logic is tested in `tests/capture/test_nested_plxpr.py` import jax # pylint: disable=import-outside-toplevel @@ -341,7 +347,7 @@ def _try_wrap_in_custom_ctrl_op( def _handle_pauli_x_based_controlled_ops( - op, control: WiresLike, control_values, work_wires: WiresLike = None + op, control: WiresLike, control_values, work_wires: WiresLike = () ): """Handles PauliX-based controlled operations.""" @@ -486,7 +492,12 @@ def __new__(cls, base, *_, **__): # pylint: disable=arguments-differ @classmethod def _primitive_bind_call( - cls, base, control_wires, control_values=None, work_wires=None, id=None + cls, + base, + control_wires: WiresLike, + control_values=None, + work_wires: WiresLike = (), + id=None, ): control_wires = Wires(control_wires) return cls._primitive.bind( @@ -499,7 +510,7 @@ def __init__( base, control_wires: WiresLike, control_values=None, - work_wires: WiresLike = None, + work_wires: WiresLike = (), id=None, ): control_wires = Wires(control_wires) @@ -663,17 +674,17 @@ def _compute_matrix_from_base(self): return qmlmath.block_diag([left_pad, base_matrix, right_pad]) - def matrix(self, wire_order=None): + def matrix(self, wire_order: WiresLike = None): if self.compute_matrix is not Operator.compute_matrix: canonical_matrix = self.compute_matrix(*self.data) else: canonical_matrix = self._compute_matrix_from_base() - wire_order = wire_order or self.wires + wire_order = self.wires if wire_order is None else wire_order return qml.math.expand_matrix(canonical_matrix, wires=self.wires, wire_order=wire_order) # pylint: disable=arguments-differ - def sparse_matrix(self, wire_order=None, format="csr"): + def sparse_matrix(self, wire_order: WiresLike = None, format="csr"): if wire_order is not None: raise NotImplementedError("wire_order argument is not yet implemented.") @@ -925,7 +936,14 @@ def __new__(cls, *_, **__): return object.__new__(cls) # pylint: disable=too-many-function-args - def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None): + def __init__( + self, + base, + control_wires: WiresLike, + control_values=None, + work_wires: WiresLike = None, + id=None, + ): super().__init__(base, control_wires, control_values, work_wires, id) # check the grad_recipe validity if self.grad_recipe is None: @@ -975,7 +993,7 @@ def parameter_frequencies(self): if Controlled._primitive is not None: # pylint: disable=protected-access @Controlled._primitive.def_impl # pylint: disable=protected-access - def _(base, *control_wires, control_values=None, work_wires=None, id=None): + def _(base, *control_wires, control_values=None, work_wires: WiresLike = None, id=None): return type.__call__( Controlled, base,