diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py index a88f0eba301..9997d86f1b1 100644 --- a/pennylane/ops/op_math/controlled.py +++ b/pennylane/ops/op_math/controlled.py @@ -42,20 +42,16 @@ def ctrl( op: Operator, control: Any, control_values: Optional[Sequence[bool]] = None, - work_wires: Optional[Any] = (), + work_wires: Optional[Any] = None, ) -> Operator: ... - - @overload def ctrl( op: Callable, control: Any, control_values: Optional[Sequence[bool]] = None, - work_wires: Optional[Any] = (), + work_wires: Optional[Any] = None, ) -> Callable: ... - - -def ctrl(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()): +def ctrl(op, control: Any, control_values=None, work_wires=None): """Create a method that applies a controlled version of the provided op. :func:`~.qjit` compatible. @@ -152,7 +148,7 @@ def cond_fn(): return create_controlled_op(op, control, control_values=control_values, work_wires=work_wires) -def create_controlled_op(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()): +def create_controlled_op(op, control, control_values=None, work_wires: WiresLike = ()): """Default ``qml.ctrl`` implementation, allowing other implementations to call it when needed.""" control = Wires(control) @@ -205,7 +201,7 @@ def create_controlled_op(op, control: WiresLike, control_values=None, work_wires return _ctrl_transform(op, control, control_values, work_wires) -def _ctrl_transform(op, control: WiresLike, control_values, work_wires: WiresLike): +def _ctrl_transform(op, control, control_values, work_wires): @wraps(op) def wrapper(*args, **kwargs): qscript = qml.tape.make_qscript(op)(*args, **kwargs) @@ -265,9 +261,7 @@ def _(*_, **__): return ctrl_prim -def _capture_ctrl_transform( - qfunc: Callable, control: WiresLike, control_values, work_wires: WiresLike -) -> Callable: +def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable: """Capture compatible way of performing an ctrl transform.""" # note that this logic is tested in `tests/capture/test_nested_plxpr.py` import jax # pylint: disable=import-outside-toplevel @@ -347,7 +341,7 @@ def _try_wrap_in_custom_ctrl_op( def _handle_pauli_x_based_controlled_ops( - op, control: WiresLike, control_values, work_wires: WiresLike = () + op, control: WiresLike, control_values, work_wires: WiresLike = None ): """Handles PauliX-based controlled operations.""" @@ -492,12 +486,7 @@ def __new__(cls, base, *_, **__): # pylint: disable=arguments-differ @classmethod def _primitive_bind_call( - cls, - base, - control_wires: WiresLike, - control_values=None, - work_wires: WiresLike = (), - id=None, + cls, base, control_wires, control_values=None, work_wires=None, id=None ): control_wires = Wires(control_wires) return cls._primitive.bind( @@ -510,7 +499,7 @@ def __init__( base, control_wires: WiresLike, control_values=None, - work_wires: WiresLike = (), + work_wires: WiresLike = None, id=None, ): control_wires = Wires(control_wires) @@ -674,17 +663,17 @@ def _compute_matrix_from_base(self): return qmlmath.block_diag([left_pad, base_matrix, right_pad]) - def matrix(self, wire_order: WiresLike = None): + def matrix(self, wire_order=None): if self.compute_matrix is not Operator.compute_matrix: canonical_matrix = self.compute_matrix(*self.data) else: canonical_matrix = self._compute_matrix_from_base() - wire_order = self.wires if wire_order is None else wire_order + wire_order = wire_order or self.wires return qml.math.expand_matrix(canonical_matrix, wires=self.wires, wire_order=wire_order) # pylint: disable=arguments-differ - def sparse_matrix(self, wire_order: WiresLike = None, format="csr"): + def sparse_matrix(self, wire_order=None, format="csr"): if wire_order is not None: raise NotImplementedError("wire_order argument is not yet implemented.") @@ -936,14 +925,7 @@ def __new__(cls, *_, **__): return object.__new__(cls) # pylint: disable=too-many-function-args - def __init__( - self, - base, - control_wires: WiresLike, - control_values=None, - work_wires: WiresLike = None, - id=None, - ): + def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None): super().__init__(base, control_wires, control_values, work_wires, id) # check the grad_recipe validity if self.grad_recipe is None: @@ -993,7 +975,7 @@ def parameter_frequencies(self): if Controlled._primitive is not None: # pylint: disable=protected-access @Controlled._primitive.def_impl # pylint: disable=protected-access - def _(base, *control_wires, control_values=None, work_wires: WiresLike = None, id=None): + def _(base, *control_wires, control_values=None, work_wires=None, id=None): return type.__call__( Controlled, base,