Skip to content

Commit

Permalink
Revert "fix: further clean-up in controlled.py"
Browse files Browse the repository at this point in the history
This reverts commit b6e306a.
  • Loading branch information
andrijapau committed Dec 19, 2024
1 parent b6e306a commit fe8d4a5
Showing 1 changed file with 14 additions and 32 deletions.
46 changes: 14 additions & 32 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,16 @@ def ctrl(
op: Operator,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = (),
work_wires: Optional[Any] = None,
) -> Operator: ...


@overload
def ctrl(
op: Callable,
control: Any,
control_values: Optional[Sequence[bool]] = None,
work_wires: Optional[Any] = (),
work_wires: Optional[Any] = None,
) -> Callable: ...


def ctrl(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()):
def ctrl(op, control: Any, control_values=None, work_wires=None):
"""Create a method that applies a controlled version of the provided op.
:func:`~.qjit` compatible.
Expand Down Expand Up @@ -152,7 +148,7 @@ def cond_fn():
return create_controlled_op(op, control, control_values=control_values, work_wires=work_wires)


def create_controlled_op(op, control: WiresLike, control_values=None, work_wires: WiresLike = ()):
def create_controlled_op(op, control, control_values=None, work_wires: WiresLike = ()):
"""Default ``qml.ctrl`` implementation, allowing other implementations to call it when needed."""

control = Wires(control)
Expand Down Expand Up @@ -205,7 +201,7 @@ def create_controlled_op(op, control: WiresLike, control_values=None, work_wires
return _ctrl_transform(op, control, control_values, work_wires)


def _ctrl_transform(op, control: WiresLike, control_values, work_wires: WiresLike):
def _ctrl_transform(op, control, control_values, work_wires):
@wraps(op)
def wrapper(*args, **kwargs):
qscript = qml.tape.make_qscript(op)(*args, **kwargs)
Expand Down Expand Up @@ -265,9 +261,7 @@ def _(*_, **__):
return ctrl_prim


def _capture_ctrl_transform(
qfunc: Callable, control: WiresLike, control_values, work_wires: WiresLike
) -> Callable:
def _capture_ctrl_transform(qfunc: Callable, control, control_values, work_wires) -> Callable:
"""Capture compatible way of performing an ctrl transform."""
# note that this logic is tested in `tests/capture/test_nested_plxpr.py`
import jax # pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -347,7 +341,7 @@ def _try_wrap_in_custom_ctrl_op(


def _handle_pauli_x_based_controlled_ops(
op, control: WiresLike, control_values, work_wires: WiresLike = ()
op, control: WiresLike, control_values, work_wires: WiresLike = None
):
"""Handles PauliX-based controlled operations."""

Expand Down Expand Up @@ -492,12 +486,7 @@ def __new__(cls, base, *_, **__):
# pylint: disable=arguments-differ
@classmethod
def _primitive_bind_call(
cls,
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = (),
id=None,
cls, base, control_wires, control_values=None, work_wires=None, id=None
):
control_wires = Wires(control_wires)
return cls._primitive.bind(

Check notice on line 492 in pennylane/ops/op_math/controlled.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled.py#L492

Too many positional arguments (6/5) (too-many-positional-arguments)
Expand All @@ -510,7 +499,7 @@ def __init__(
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = (),
work_wires: WiresLike = None,
id=None,
):
control_wires = Wires(control_wires)
Expand Down Expand Up @@ -674,17 +663,17 @@ def _compute_matrix_from_base(self):

return qmlmath.block_diag([left_pad, base_matrix, right_pad])

def matrix(self, wire_order: WiresLike = None):
def matrix(self, wire_order=None):
if self.compute_matrix is not Operator.compute_matrix:
canonical_matrix = self.compute_matrix(*self.data)
else:
canonical_matrix = self._compute_matrix_from_base()

wire_order = self.wires if wire_order is None else wire_order
wire_order = wire_order or self.wires
return qml.math.expand_matrix(canonical_matrix, wires=self.wires, wire_order=wire_order)

# pylint: disable=arguments-differ
def sparse_matrix(self, wire_order: WiresLike = None, format="csr"):
def sparse_matrix(self, wire_order=None, format="csr"):
if wire_order is not None:
raise NotImplementedError("wire_order argument is not yet implemented.")

Expand Down Expand Up @@ -936,14 +925,7 @@ def __new__(cls, *_, **__):
return object.__new__(cls)

# pylint: disable=too-many-function-args
def __init__(
self,
base,
control_wires: WiresLike,
control_values=None,
work_wires: WiresLike = None,
id=None,
):
def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None):

Check notice on line 928 in pennylane/ops/op_math/controlled.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/ops/op_math/controlled.py#L928

Too many positional arguments (6/5) (too-many-positional-arguments)
super().__init__(base, control_wires, control_values, work_wires, id)
# check the grad_recipe validity
if self.grad_recipe is None:
Expand Down Expand Up @@ -993,7 +975,7 @@ def parameter_frequencies(self):
if Controlled._primitive is not None: # pylint: disable=protected-access

@Controlled._primitive.def_impl # pylint: disable=protected-access
def _(base, *control_wires, control_values=None, work_wires: WiresLike = None, id=None):
def _(base, *control_wires, control_values=None, work_wires=None, id=None):
return type.__call__(
Controlled,
base,
Expand Down

0 comments on commit fe8d4a5

Please sign in to comment.