diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index edd92247fb4..8df2567e397 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -341,6 +341,9 @@ such as `shots`, `rng` and `prng_key`.
Other Improvements
+* `Wires` object usage across Pennylane source code has been tidied up.
+ [(#6689)](https://github.com/PennyLaneAI/pennylane/pull/6689)
+
* `qml.equal` now supports `PauliWord` and `PauliSentence` instances.
[(#6703)](https://github.com/PennyLaneAI/pennylane/pull/6703)
@@ -375,6 +378,7 @@ such as `shots`, `rng` and `prng_key`.
* PennyLane is compatible with `quimb 1.10.0`.
[(#6630)](https://github.com/PennyLaneAI/pennylane/pull/6630)
+ [(#6736)](https://github.com/PennyLaneAI/pennylane/pull/6736)
* Add developer focused `run` function to `qml.workflow` module.
[(#6657)](https://github.com/PennyLaneAI/pennylane/pull/6657)
diff --git a/pennylane/devices/default_tensor.py b/pennylane/devices/default_tensor.py
index 3f68a41b646..cb19f625d9c 100644
--- a/pennylane/devices/default_tensor.py
+++ b/pennylane/devices/default_tensor.py
@@ -226,8 +226,8 @@ class DefaultTensor(Device):
`quimb's tensor_contract documentation `_.
Default is ``"auto-hq"``.
local_simplify (str): The simplification sequence to apply to the tensor network for computing local expectation values.
- For a complete list of available simplification options, see the
- `quimb's full_simplify documentation `_.
+ At present, this argument can only be provided when the TN method is used. For a complete list of available simplification options,
+ see the `quimb's full_simplify documentation `_.
Default is ``"ADCRS"``.
@@ -400,8 +400,10 @@ def __init__(
self._max_bond_dim = kwargs.get("max_bond_dim", None)
self._cutoff = kwargs.get("cutoff", None)
- # options both for MPS and TN
+ # options for TN
self._local_simplify = kwargs.get("local_simplify", "ADCRS")
+
+ # options both for MPS and TN
self._contraction_optimizer = kwargs.get("contraction_optimizer", "auto-hq")
self._contract = None
@@ -810,14 +812,22 @@ def _local_expectation(self, matrix, wires) -> float:
# after the execution, we could avoid copying the circuit.
qc = self._quimb_circuit.copy()
- exp_val = qc.local_expectation(
- matrix,
- wires,
- dtype=self._c_dtype.__name__,
- optimize=self._contraction_optimizer,
- simplify_sequence=self._local_simplify,
- simplify_atol=0.0,
- )
+ if self.method == "mps":
+ exp_val = qc.local_expectation(
+ matrix,
+ wires,
+ dtype=self._c_dtype.__name__,
+ optimize=self._contraction_optimizer,
+ )
+ else:
+ exp_val = qc.local_expectation(
+ matrix,
+ wires,
+ dtype=self._c_dtype.__name__,
+ optimize=self._contraction_optimizer,
+ simplify_sequence=self._local_simplify,
+ simplify_atol=0.0,
+ )
return float(np.real(exp_val))
diff --git a/pennylane/measurements/classical_shadow.py b/pennylane/measurements/classical_shadow.py
index 7af5b7f523f..7e9dc5285c8 100644
--- a/pennylane/measurements/classical_shadow.py
+++ b/pennylane/measurements/classical_shadow.py
@@ -23,7 +23,7 @@
import pennylane as qml
from pennylane.operation import Operator
-from pennylane.wires import Wires
+from pennylane.wires import Wires, WiresLike
from .measurements import MeasurementShapeError, MeasurementTransform, Shadow, ShadowExpval
@@ -89,7 +89,7 @@ def circuit(x, obs):
return ShadowExpvalMP(H=H, seed=seed, k=k)
-def classical_shadow(wires, seed=None):
+def classical_shadow(wires: WiresLike, seed=None):
"""
The classical shadow measurement protocol.
@@ -227,7 +227,10 @@ class ClassicalShadowMP(MeasurementTransform):
"""
def __init__(
- self, wires: Optional[Wires] = None, seed: Optional[int] = None, id: Optional[str] = None
+ self,
+ wires: Optional[WiresLike] = None,
+ seed: Optional[int] = None,
+ id: Optional[str] = None,
):
self.seed = seed
super().__init__(wires=wires, id=id)
diff --git a/pennylane/ops/channel.py b/pennylane/ops/channel.py
index f28166c039c..837ee72a83a 100644
--- a/pennylane/ops/channel.py
+++ b/pennylane/ops/channel.py
@@ -20,6 +20,7 @@
from pennylane import math as np
from pennylane.operation import AnyWires, Channel
+from pennylane.wires import Wires, WiresLike
class AmplitudeDamping(Channel):
@@ -58,7 +59,8 @@ class AmplitudeDamping(Channel):
num_wires = 1
grad_method = "F"
- def __init__(self, gamma, wires, id=None):
+ def __init__(self, gamma, wires: WiresLike, id=None):
+ wires = Wires(wires)
super().__init__(gamma, wires=wires, id=id)
@staticmethod
@@ -563,7 +565,8 @@ class PauliError(Channel):
num_params = 2
"""int: Number of trainable parameters that the operator depends on."""
- def __init__(self, operators, p, wires=None, id=None):
+ def __init__(self, operators, p, wires: WiresLike, id=None):
+ wires = Wires(wires)
super().__init__(operators, p, wires=wires, id=id)
# check if the specified operators are legal
@@ -713,7 +716,8 @@ class QubitChannel(Channel):
num_wires = AnyWires
grad_method = None
- def __init__(self, K_list, wires=None, id=None):
+ def __init__(self, K_list, wires: WiresLike, id=None):
+ wires = Wires(wires)
super().__init__(*K_list, wires=wires, id=id)
# check all Kraus matrices are square matrices
@@ -744,7 +748,8 @@ def _flatten(self):
# pylint: disable=arguments-differ, unused-argument
@classmethod
- def _primitive_bind_call(cls, K_list, wires=None, id=None):
+ def _primitive_bind_call(cls, K_list, wires: WiresLike, id=None):
+ wires = Wires(wires)
return super()._primitive_bind_call(*K_list, wires=wires)
@staticmethod
diff --git a/pennylane/ops/identity.py b/pennylane/ops/identity.py
index dc6a401af25..5667e849aab 100644
--- a/pennylane/ops/identity.py
+++ b/pennylane/ops/identity.py
@@ -27,6 +27,7 @@
Operation,
SparseMatrixUndefinedError,
)
+from pennylane.wires import WiresLike
class Identity(CVObservable, Operation):
@@ -61,15 +62,16 @@ class Identity(CVObservable, Operation):
ev_order = 1
@classmethod
- def _primitive_bind_call(cls, wires=None, **kwargs): # pylint: disable=arguments-differ
- wires = [] if wires is None else wires
+ def _primitive_bind_call(
+ cls, wires: WiresLike = (), **kwargs
+ ): # pylint: disable=arguments-differ
return super()._primitive_bind_call(wires=wires, **kwargs)
def _flatten(self):
return tuple(), (self.wires, tuple())
- def __init__(self, wires=None, id=None):
- super().__init__(wires=[] if wires is None else wires, id=id)
+ def __init__(self, wires: WiresLike = (), id=None):
+ super().__init__(wires=wires, id=id)
self._hyperparameters = {"n_wires": len(self.wires)}
self._pauli_rep = qml.pauli.PauliSentence({qml.pauli.PauliWord({}): 1.0})
@@ -308,12 +310,13 @@ def circuit():
grad_method = None
@classmethod
- def _primitive_bind_call(cls, phi, wires=None, **kwargs): # pylint: disable=arguments-differ
- wires = [] if wires is None else wires
+ def _primitive_bind_call(
+ cls, phi, wires: WiresLike = (), **kwargs
+ ): # pylint: disable=arguments-differ
return super()._primitive_bind_call(phi, wires=wires, **kwargs)
- def __init__(self, phi, wires=None, id=None):
- super().__init__(phi, wires=[] if wires is None else wires, id=id)
+ def __init__(self, phi, wires: WiresLike = (), id=None):
+ super().__init__(phi, wires=wires, id=id)
@staticmethod
def compute_eigvals(phi, n_wires=1): # pylint: disable=arguments-differ
@@ -412,7 +415,9 @@ def compute_diagonalizing_gates(
return []
@staticmethod
- def compute_decomposition(phi, wires=None): # pylint:disable=arguments-differ,unused-argument
+ def compute_decomposition(
+ phi, wires: WiresLike = ()
+ ): # pylint:disable=arguments-differ,unused-argument
r"""Representation of the operator as a product of other operators (static method).
.. note::
diff --git a/pennylane/ops/meta.py b/pennylane/ops/meta.py
index 6d63d820609..39405982ec8 100644
--- a/pennylane/ops/meta.py
+++ b/pennylane/ops/meta.py
@@ -23,7 +23,7 @@
import pennylane as qml
from pennylane.operation import AnyWires, Operation
-from pennylane.wires import Wires # pylint: disable=unused-import
+from pennylane.wires import Wires, WiresLike
class Barrier(Operation):
@@ -46,7 +46,8 @@ class Barrier(Operation):
num_wires = AnyWires
par_domain = None
- def __init__(self, wires=Wires([]), only_visual=False, id=None):
+ def __init__(self, wires: WiresLike = (), only_visual=False, id=None):
+ wires = Wires(wires)
self.only_visual = only_visual
self.hyperparameters["only_visual"] = only_visual
super().__init__(wires=wires, id=id)
@@ -119,16 +120,12 @@ class WireCut(Operation):
num_wires = AnyWires
grad_method = None
- def __init__(self, *params, wires=None, id=None):
- if wires == []:
- raise ValueError(
- f"{self.__class__.__name__}: wrong number of wires. "
- f"At least one wire has to be given."
- )
- super().__init__(*params, wires=wires, id=id)
+ def __init__(self, wires: WiresLike = (), id=None):
+ wires = Wires(wires)
+ super().__init__(wires=wires, id=id)
@staticmethod
- def compute_decomposition(wires): # pylint: disable=unused-argument
+ def compute_decomposition(wires: WiresLike): # pylint: disable=unused-argument
r"""Representation of the operator as a product of other operators (static method).
Since this operator is a placeholder inside a circuit, it decomposes into an empty list.
diff --git a/pennylane/ops/op_math/controlled.py b/pennylane/ops/op_math/controlled.py
index 06279a04a39..d49209660c6 100644
--- a/pennylane/ops/op_math/controlled.py
+++ b/pennylane/ops/op_math/controlled.py
@@ -31,7 +31,7 @@
from pennylane.capture.capture_diff import create_non_interpreted_prim
from pennylane.compiler import compiler
from pennylane.operation import Operator
-from pennylane.wires import Wires
+from pennylane.wires import Wires, WiresLike
from .controlled_decompositions import ctrl_decomp_bisect, ctrl_decomp_zyz
from .symbolicop import SymbolicOp
@@ -175,7 +175,7 @@ def create_controlled_op(op, control, control_values=None, work_wires=None):
# Flatten nested controlled operations to a multi-controlled operation for better
# decomposition algorithms. This includes special cases like CRX, CRot, etc.
if isinstance(op, Controlled):
- work_wires = work_wires or []
+ work_wires = () if work_wires is None else work_wires
return ctrl(
op.base,
control=control + op.control_wires,
@@ -489,9 +489,16 @@ def _primitive_bind_call(
)
# pylint: disable=too-many-function-args
- def __init__(self, base, control_wires, control_values=None, work_wires=None, id=None):
+ def __init__(
+ self,
+ base,
+ control_wires: WiresLike,
+ control_values=None,
+ work_wires: WiresLike = None,
+ id=None,
+ ):
control_wires = Wires(control_wires)
- work_wires = Wires([]) if work_wires is None else Wires(work_wires)
+ work_wires = Wires(() if work_wires is None else work_wires)
if control_values is None:
control_values = [True] * len(control_wires)
diff --git a/pennylane/ops/op_math/controlled_ops.py b/pennylane/ops/op_math/controlled_ops.py
index 7f3e2258f64..5a4ecdaf8b1 100644
--- a/pennylane/ops/op_math/controlled_ops.py
+++ b/pennylane/ops/op_math/controlled_ops.py
@@ -1155,7 +1155,16 @@ def _primitive_bind_call(cls, wires, control_values=None, work_wires=None, id=No
)
# pylint: disable=too-many-arguments
- def __init__(self, control_wires=None, wires=None, control_values=None, work_wires=None):
+ def __init__(
+ self,
+ control_wires: WiresLike = (),
+ wires: WiresLike = (),
+ control_values=None,
+ work_wires: WiresLike = (),
+ ):
+ control_wires = Wires(() if control_wires is None else control_wires)
+ wires = Wires(() if wires is None else wires)
+ work_wires = Wires(() if work_wires is None else work_wires)
# First raise deprecation warnings regardless of the validity of other arguments
if isinstance(control_values, str):
@@ -1164,22 +1173,19 @@ def __init__(self, control_wires=None, wires=None, control_values=None, work_wir
"supported in future releases, Use a list of booleans or integers instead.",
qml.PennyLaneDeprecationWarning,
)
- if control_wires is not None:
+ if len(control_wires) > 0:
warnings.warn(
"The control_wires keyword for MultiControlledX is deprecated, and will "
"be removed soon. Use wires = (*control_wires, target_wire) instead.",
UserWarning,
)
- if wires is None:
+ if len(wires) == 0:
raise ValueError("Must specify the wires where the operation acts on")
- wires = wires if isinstance(wires, Wires) else Wires(wires)
-
- if control_wires is not None:
+ if len(control_wires) > 0:
if len(wires) != 1:
raise ValueError("MultiControlledX accepts a single target wire.")
- control_wires = Wires(control_wires)
else:
if len(wires) < 2:
raise ValueError(
@@ -1212,7 +1218,7 @@ def wires(self):
# pylint: disable=unused-argument, arguments-differ
@staticmethod
- def compute_matrix(control_wires, control_values=None, **kwargs):
+ def compute_matrix(control_wires: WiresLike, control_values=None, **kwargs):
r"""Representation of the operator as a canonical matrix in the computational basis (static method).
The canonical matrix is the textbook matrix representation that does not consider wires.
@@ -1255,7 +1261,9 @@ def matrix(self, wire_order=None):
# pylint: disable=unused-argument, arguments-differ
@staticmethod
- def compute_decomposition(wires=None, work_wires=None, control_values=None, **kwargs):
+ def compute_decomposition(
+ wires: WiresLike = None, work_wires: WiresLike = None, control_values=None, **kwargs
+ ):
r"""Representation of the operator as a product of other operators (static method).
.. math:: O = O_1 O_2 \dots O_n.
@@ -1282,6 +1290,7 @@ def compute_decomposition(wires=None, work_wires=None, control_values=None, **kw
Toffoli(wires=[0, 1, 'aux'])]
"""
+ wires = Wires(() if wires is None else wires)
if len(wires) < 2:
raise ValueError(f"Wrong number of wires. {len(wires)} given. Need at least 2.")
@@ -1357,7 +1366,7 @@ class CRX(ControlledOp):
name = "CRX"
parameter_frequencies = [(0.5, 1.0)]
- def __init__(self, phi, wires, id=None):
+ def __init__(self, phi, wires: WiresLike, id=None):
# We use type.__call__ instead of calling the class directly so that we don't bind the
# operator primitive when new program capture is enabled
base = type.__call__(qml.RX, phi, wires=wires[1:])
@@ -1374,7 +1383,7 @@ def _unflatten(cls, data, metadata):
return cls(*data, wires=metadata[0])
@classmethod
- def _primitive_bind_call(cls, phi, wires, id=None):
+ def _primitive_bind_call(cls, phi, wires: WiresLike, id=None):
return cls._primitive.bind(phi, *wires, n_wires=len(wires))
@staticmethod
@@ -1425,7 +1434,7 @@ def compute_matrix(theta): # pylint: disable=arguments-differ
return qml.math.stack([stack_last(row) for row in matrix], axis=-2)
@staticmethod
- def compute_decomposition(phi, wires): # pylint: disable=arguments-differ
+ def compute_decomposition(phi, wires: WiresLike): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators (static method). :
.. math:: O = O_1 O_2 \dots O_n.
diff --git a/pennylane/ops/qubit/non_parametric_ops.py b/pennylane/ops/qubit/non_parametric_ops.py
index a1e287f0933..e5f1328bf12 100644
--- a/pennylane/ops/qubit/non_parametric_ops.py
+++ b/pennylane/ops/qubit/non_parametric_ops.py
@@ -252,7 +252,7 @@ def pauli_rep(self):
)
return self._pauli_rep
- def __init__(self, wires: Optional[WiresLike] = None, id: Optional[str] = None):
+ def __init__(self, wires: WiresLike, id: Optional[str] = None):
super().__init__(wires=wires, id=id)
def label(
diff --git a/pennylane/ops/qubit/parametric_ops_multi_qubit.py b/pennylane/ops/qubit/parametric_ops_multi_qubit.py
index 1064315965c..e49bb9d3542 100644
--- a/pennylane/ops/qubit/parametric_ops_multi_qubit.py
+++ b/pennylane/ops/qubit/parametric_ops_multi_qubit.py
@@ -661,7 +661,7 @@ def compute_eigvals(*params: TensorLike, **hyperparams) -> TensorLike:
@staticmethod
def compute_decomposition(
- *params: TensorLike, wires: Optional[WiresLike] = None, **hyperparams
+ *params: TensorLike, wires: WiresLike, **hyperparams
) -> list["qml.operation.Operator"]:
r"""Representation of the operator as a product of other operators (static method).
diff --git a/pennylane/pauli/pauli_arithmetic.py b/pennylane/pauli/pauli_arithmetic.py
index e0c1a131a19..70fcee0f0de 100644
--- a/pennylane/pauli/pauli_arithmetic.py
+++ b/pennylane/pauli/pauli_arithmetic.py
@@ -23,7 +23,7 @@
from pennylane import math
from pennylane.ops import Identity, PauliX, PauliY, PauliZ, Prod, SProd, Sum
from pennylane.typing import TensorLike
-from pennylane.wires import Wires
+from pennylane.wires import Wires, WiresLike
I = "I"
X = "X"
@@ -504,7 +504,7 @@ def _get_csr_indices(self, wire_order):
current_size *= 2
return indices
- def operation(self, wire_order=None):
+ def operation(self, wire_order: WiresLike = ()):
"""Returns a native PennyLane :class:`~pennylane.operation.Operation` representing the PauliWord."""
if len(self) == 0:
return Identity(wires=wire_order)
@@ -1001,7 +1001,7 @@ def _sum_different_structure_pws(indices, data):
matrix.eliminate_zeros()
return matrix
- def operation(self, wire_order=None):
+ def operation(self, wire_order: WiresLike = ()):
"""Returns a native PennyLane :class:`~pennylane.operation.Operation` representing the PauliSentence."""
if len(self) == 0:
return qml.s_prod(0, Identity(wires=wire_order))
diff --git a/pennylane/templates/subroutines/grover.py b/pennylane/templates/subroutines/grover.py
index 11924dc828c..1d051fd3fe5 100644
--- a/pennylane/templates/subroutines/grover.py
+++ b/pennylane/templates/subroutines/grover.py
@@ -18,7 +18,7 @@
from pennylane.operation import AnyWires, Operation
from pennylane.ops import GlobalPhase, Hadamard, MultiControlledX, PauliZ
-from pennylane.wires import Wires
+from pennylane.wires import Wires, WiresLike
class GroverOperator(Operation):
@@ -109,13 +109,16 @@ def _flatten(self):
hyperparameters = (("work_wires", self.hyperparameters["work_wires"]),)
return tuple(), (self.wires, hyperparameters)
- def __init__(self, wires=None, work_wires=None, id=None):
- if (not hasattr(wires, "__len__")) or (len(wires) < 2):
+ def __init__(self, wires: WiresLike, work_wires: WiresLike = (), id=None):
+ wires = Wires(wires)
+ work_wires = Wires(() if work_wires is None else work_wires)
+
+ if len(wires) < 2:
raise ValueError("GroverOperator must have at least two wires provided.")
self._hyperparameters = {
"n_wires": len(wires),
- "work_wires": Wires(work_wires) if work_wires is not None else Wires([]),
+ "work_wires": work_wires,
}
super().__init__(wires=wires, id=id)
@@ -126,7 +129,7 @@ def num_params(self):
@staticmethod
def compute_decomposition(
- wires, work_wires, **kwargs
+ wires: WiresLike, work_wires: WiresLike, **kwargs
): # pylint: disable=arguments-differ,unused-argument
r"""Representation of the operator as a product of other operators.
diff --git a/pennylane/templates/subroutines/phase_adder.py b/pennylane/templates/subroutines/phase_adder.py
index 722dfad9787..9460f0e3feb 100644
--- a/pennylane/templates/subroutines/phase_adder.py
+++ b/pennylane/templates/subroutines/phase_adder.py
@@ -19,7 +19,7 @@
import pennylane as qml
from pennylane.operation import Operation
-from pennylane.wires import WiresLike
+from pennylane.wires import Wires, WiresLike
def _add_k_fourier(k, wires: WiresLike):
@@ -127,8 +127,8 @@ def __init__(
self, k, x_wires: WiresLike, mod=None, work_wire: WiresLike = (), id=None
): # pylint: disable=too-many-arguments
- work_wire = qml.wires.Wires(() if work_wire is None else work_wire)
- x_wires = qml.wires.Wires(x_wires)
+ work_wire = Wires(() if work_wire is None else work_wire)
+ x_wires = Wires(x_wires)
num_work_wires = len(work_wire)
@@ -152,15 +152,11 @@ def __init__(
"None of the wires in work_wire should be included in x_wires."
)
- all_wires = (
- qml.wires.Wires(x_wires) + qml.wires.Wires(work_wire)
- if work_wire
- else qml.wires.Wires(x_wires)
- )
+ all_wires = x_wires + work_wire
self.hyperparameters["k"] = k % mod
self.hyperparameters["mod"] = mod
- self.hyperparameters["work_wire"] = qml.wires.Wires(work_wire)
+ self.hyperparameters["work_wire"] = work_wire
self.hyperparameters["x_wires"] = x_wires
super().__init__(wires=all_wires, id=id)
diff --git a/pennylane/templates/subroutines/qft.py b/pennylane/templates/subroutines/qft.py
index 4cb0bd758df..7be360f05d1 100644
--- a/pennylane/templates/subroutines/qft.py
+++ b/pennylane/templates/subroutines/qft.py
@@ -22,6 +22,7 @@
import pennylane as qml
from pennylane.operation import AnyWires, Operation
+from pennylane.wires import Wires, WiresLike
class QFT(Operation):
@@ -131,8 +132,8 @@ def scFT_add(m, k, n_wires):
num_wires = AnyWires
grad_method = None
- def __init__(self, wires=None, id=None):
- wires = qml.wires.Wires(wires)
+ def __init__(self, wires: WiresLike, id=None):
+ wires = Wires(wires)
self.hyperparameters["n_wires"] = len(wires)
super().__init__(wires=wires, id=id)
@@ -143,13 +144,16 @@ def _flatten(self):
def num_params(self):
return 0
+ def decomposition(self):
+ return self.compute_decomposition(wires=self.wires)
+
@staticmethod
@functools.lru_cache()
def compute_matrix(n_wires): # pylint: disable=arguments-differ
return np.fft.ifft(np.eye(2**n_wires), norm="ortho")
@staticmethod
- def compute_decomposition(wires, n_wires): # pylint: disable=arguments-differ,unused-argument
+ def compute_decomposition(wires: WiresLike): # pylint: disable=arguments-differ,unused-argument
r"""Representation of the operator as a product of other operators (static method).
.. math:: O = O_1 O_2 \dots O_n.
@@ -159,24 +163,25 @@ def compute_decomposition(wires, n_wires): # pylint: disable=arguments-differ,u
Args:
wires (Iterable, Wires): wires that the operator acts on
- n_wires (int): number of wires or ``len(wires)``
Returns:
list[Operator]: decomposition of the operator
**Example:**
- >>> qml.QFT.compute_decomposition(wires=(0,1,2), n_wires=3)
+ >>> qml.QFT.compute_decomposition(wires=(0,1,2))
[H(0),
ControlledPhaseShift(1.5707963267948966, wires=Wires([1, 0])),
ControlledPhaseShift(0.7853981633974483, wires=Wires([2, 0])),
H(1),
ControlledPhaseShift(1.5707963267948966, wires=Wires([2, 1])),
H(2),
- H(4),
- SWAP(wires=[0, 4])]
+ SWAP(wires=[0, 2])]
"""
+ wires = Wires(wires)
+ n_wires = len(wires)
+
shifts = [2 * np.pi * 2**-i for i in range(2, n_wires + 1)]
shift_len = len(shifts)
diff --git a/tests/templates/test_subroutines/test_qft.py b/tests/templates/test_subroutines/test_qft.py
index a3554b1d079..1716476fddd 100644
--- a/tests/templates/test_subroutines/test_qft.py
+++ b/tests/templates/test_subroutines/test_qft.py
@@ -37,6 +37,24 @@ def test_QFT(self):
exp = QFT
assert np.allclose(res, exp)
+ @pytest.mark.parametrize("n_qubits", range(2, 6))
+ def test_QFT_compute_decomposition(self, n_qubits):
+ """Test if the QFT operation is correctly decomposed"""
+ decomp = qml.QFT.compute_decomposition(wires=range(n_qubits))
+
+ dev = qml.device("default.qubit", wires=n_qubits)
+
+ out_states = []
+ for state in np.eye(2**n_qubits):
+ ops = [qml.StatePrep(state, wires=range(n_qubits))] + decomp
+ qs = qml.tape.QuantumScript(ops, [qml.state()])
+ out_states.append(dev.execute(qs))
+
+ reconstructed_unitary = np.array(out_states).T
+ expected_unitary = qml.QFT(wires=range(n_qubits)).matrix()
+
+ assert np.allclose(reconstructed_unitary, expected_unitary)
+
@pytest.mark.parametrize("n_qubits", range(2, 6))
def test_QFT_decomposition(self, n_qubits):
"""Test if the QFT operation is correctly decomposed"""