diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1804494a8ae..e077b96306b 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -85,6 +85,9 @@
unprocessed diagonalizing gates.
[(#6290)](https://github.com/PennyLaneAI/pennylane/pull/6290)
+* A more sensible error message is raised from a `RecursionError` encountered when accessing properties and methods of a nested `CompositeOp` or `SProd`.
+ [(#6375)](https://github.com/PennyLaneAI/pennylane/pull/6375)
+
Capturing and representing hybrid programs
* `qml.wires.Wires` now accepts JAX arrays as input. Furthermore, a `FutureWarning` is no longer raised in `JAX 0.4.30+`
diff --git a/pennylane/ops/op_math/composite.py b/pennylane/ops/op_math/composite.py
index 4b0b75f905f..51dd94fb605 100644
--- a/pennylane/ops/op_math/composite.py
+++ b/pennylane/ops/op_math/composite.py
@@ -18,6 +18,7 @@
import abc
import copy
from collections.abc import Callable
+from functools import wraps
import pennylane as qml
from pennylane import math
@@ -27,6 +28,24 @@
# pylint: disable=too-many-instance-attributes
+def handle_recursion_error(func):
+ """Handles any recursion errors raised from too many levels of nesting."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except RecursionError as e:
+ raise RuntimeError(
+ "Maximum recursion depth reached! This is likely due to nesting too many levels "
+ "of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, "
+ "and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you "
+ "can periodically call qml.simplify on your operators."
+ ) from e
+
+ return wrapper
+
+
class CompositeOp(Operator):
"""A base class for operators that are composed of other operators.
@@ -70,6 +89,7 @@ def __init__(
self.queue()
self._batch_size = _UNSET_BATCH_SIZE
+ @handle_recursion_error
def _check_batching(self):
batch_sizes = {op.batch_size for op in self if op.batch_size is not None}
if len(batch_sizes) > 1:
@@ -84,6 +104,7 @@ def __repr__(self):
[f"({op})" if op.arithmetic_depth > 0 else f"{op}" for op in self]
)
+ @handle_recursion_error
def __copy__(self):
cls = self.__class__
copied_op = cls.__new__(cls)
@@ -113,6 +134,7 @@ def _op_symbol(self) -> str:
"""The symbol used when visualizing the composite operator"""
@property
+ @handle_recursion_error
def data(self):
"""Create data property"""
return tuple(d for op in self for d in op.data)
@@ -132,6 +154,7 @@ def num_wires(self):
return len(self.wires)
@property
+ @handle_recursion_error
def num_params(self):
return sum(op.num_params for op in self)
@@ -152,9 +175,11 @@ def is_hermitian(self):
# pylint: disable=arguments-renamed, invalid-overridden-method
@property
+ @handle_recursion_error
def has_matrix(self):
return all(op.has_matrix or isinstance(op, qml.ops.Hamiltonian) for op in self)
+ @handle_recursion_error
def eigvals(self):
"""Return the eigenvalues of the specified operator.
@@ -290,6 +315,7 @@ def diagonalizing_gates(self):
)
return diag_gates
+ @handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
r"""How the composite operator is represented in diagrams and drawings.
@@ -344,6 +370,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:
"""Sort composite operands by their wire indices."""
@property
+ @handle_recursion_error
def hash(self):
if self._hash is None:
self._hash = hash(
@@ -357,6 +384,7 @@ def basis(self):
return None
@property
+ @handle_recursion_error
def arithmetic_depth(self) -> int:
return 1 + max(op.arithmetic_depth for op in self)
@@ -365,6 +393,7 @@ def arithmetic_depth(self) -> int:
def _math_op(self) -> Callable:
"""The function used when combining the operands of the composite operator"""
+ @handle_recursion_error
def map_wires(self, wire_map: dict):
# pylint:disable=protected-access
cls = self.__class__
diff --git a/pennylane/ops/op_math/prod.py b/pennylane/ops/op_math/prod.py
index 29770c40c87..7190142f2d3 100644
--- a/pennylane/ops/op_math/prod.py
+++ b/pennylane/ops/op_math/prod.py
@@ -34,7 +34,7 @@
from pennylane.queuing import QueuingManager
from pennylane.typing import TensorLike
-from .composite import CompositeOp
+from .composite import CompositeOp, handle_recursion_error
MAX_NUM_WIRES_KRON_PRODUCT = 9
"""The maximum number of wires up to which using ``math.kron`` is faster than ``math.dot`` for
@@ -273,6 +273,7 @@ def decomposition(self):
return [qml.apply(op) for op in self[::-1]]
return list(self[::-1])
+ @handle_recursion_error
def matrix(self, wire_order=None):
"""Representation of the operator as a matrix in the computational basis."""
if self.pauli_rep:
@@ -309,6 +310,7 @@ def matrix(self, wire_order=None):
)
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)
+ @handle_recursion_error
def sparse_matrix(self, wire_order=None):
if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation
return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr")
@@ -326,11 +328,13 @@ def sparse_matrix(self, wire_order=None):
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)
@property
+ @handle_recursion_error
def has_sparse_matrix(self):
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)
# pylint: disable=protected-access
@property
+ @handle_recursion_error
def _queue_category(self):
"""Used for sorting objects into their respective lists in `QuantumTape` objects.
This property is a temporary solution that should not exist long-term and should not be
@@ -353,10 +357,6 @@ def has_adjoint(self):
def adjoint(self):
return Prod(*(qml.adjoint(factor) for factor in self[::-1]))
- @property
- def arithmetic_depth(self) -> int:
- return 1 + max(factor.arithmetic_depth for factor in self)
-
def _build_pauli_rep(self):
"""PauliSentence representation of the Product of operations."""
if all(operand_pauli_reps := [op.pauli_rep for op in self.operands]):
@@ -378,6 +378,7 @@ def _simplify_factors(self, factors: tuple[Operator]) -> tuple[complex, Operator
new_factors.remove_factors(wires=self.wires)
return new_factors.global_phase, new_factors.factors
+ @handle_recursion_error
def simplify(self) -> Union["Prod", Sum]:
r"""
Transforms any nested Prod instance into the form :math:`\sum c_i O_i` where
@@ -432,6 +433,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:
return op_list
+ @handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
diff --git a/pennylane/ops/op_math/sprod.py b/pennylane/ops/op_math/sprod.py
index 64fcdbdfd72..d73b29fce54 100644
--- a/pennylane/ops/op_math/sprod.py
+++ b/pennylane/ops/op_math/sprod.py
@@ -25,6 +25,7 @@
from pennylane.ops.op_math.sum import Sum
from pennylane.queuing import QueuingManager
+from .composite import handle_recursion_error
from .symbolicop import ScalarSymbolicOp
@@ -155,12 +156,14 @@ def __init__(
else:
self._pauli_rep = None
+ @handle_recursion_error
def __repr__(self):
"""Constructor-call-like representation."""
if isinstance(self.base, qml.ops.CompositeOp):
return f"{self.scalar} * ({self.base})"
return f"{self.scalar} * {self.base}"
+ @handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
"""The label produced for the SProd op."""
scalar_val = (
@@ -172,6 +175,7 @@ def label(self, decimals=None, base_label=None, cache=None):
return base_label or f"{scalar_val}*{self.base.label(decimals=decimals, cache=cache)}"
@property
+ @handle_recursion_error
def num_params(self):
"""Number of trainable parameters that the operator depends on.
Usually 1 + the number of trainable parameters for the base op.
@@ -181,6 +185,7 @@ def num_params(self):
"""
return 1 + self.base.num_params
+ @handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
@@ -200,6 +205,7 @@ def terms(self):
return [self.scalar], [self.base]
@property
+ @handle_recursion_error
def is_hermitian(self):
"""If the base operator is hermitian and the scalar is real,
then the scalar product operator is hermitian."""
@@ -207,10 +213,12 @@ def is_hermitian(self):
# pylint: disable=arguments-renamed,invalid-overridden-method
@property
+ @handle_recursion_error
def has_diagonalizing_gates(self):
"""Bool: Whether the Operator returns defined diagonalizing gates."""
return self.base.has_diagonalizing_gates
+ @handle_recursion_error
def diagonalizing_gates(self):
r"""Sequence of gates that diagonalize the operator in the computational basis.
@@ -230,6 +238,7 @@ def diagonalizing_gates(self):
"""
return self.base.diagonalizing_gates()
+ @handle_recursion_error
def eigvals(self):
r"""Return the eigenvalues of the specified operator.
@@ -244,6 +253,7 @@ def eigvals(self):
base_eigs = qml.math.convert_like(base_eigs, self.scalar)
return self.scalar * base_eigs
+ @handle_recursion_error
def sparse_matrix(self, wire_order=None):
"""Computes, by default, a `scipy.sparse.csr_matrix` representation of this Tensor.
@@ -264,15 +274,18 @@ def sparse_matrix(self, wire_order=None):
return mat
@property
+ @handle_recursion_error
def has_sparse_matrix(self):
return self.pauli_rep is not None or self.base.has_sparse_matrix
@property
+ @handle_recursion_error
def has_matrix(self):
"""Bool: Whether or not the Operator returns a defined matrix."""
return isinstance(self.base, qml.ops.Hamiltonian) or self.base.has_matrix
@staticmethod
+ @handle_recursion_error
def _matrix(scalar, mat):
return scalar * mat
@@ -303,6 +316,7 @@ def adjoint(self):
return SProd(scalar=qml.math.conjugate(self.scalar), base=qml.adjoint(self.base))
# pylint: disable=too-many-return-statements
+ @handle_recursion_error
def simplify(self) -> Operator:
"""Reduce the depth of nested operators to the minimum.
diff --git a/pennylane/ops/op_math/sum.py b/pennylane/ops/op_math/sum.py
index 40ed5f229c3..58590480f40 100644
--- a/pennylane/ops/op_math/sum.py
+++ b/pennylane/ops/op_math/sum.py
@@ -28,7 +28,7 @@
from pennylane.operation import Operator, convert_to_opmath
from pennylane.queuing import QueuingManager
-from .composite import CompositeOp
+from .composite import CompositeOp, handle_recursion_error
def sum(*summands, grouping_type=None, method="rlf", id=None, lazy=True):
@@ -238,6 +238,7 @@ def __init__(
self.compute_grouping(grouping_type=grouping_type, method=method)
@property
+ @handle_recursion_error
def hash(self):
# Since addition is always commutative, we do not need to sort
return hash(("Sum", hash(frozenset(Counter(self.operands).items()))))
@@ -277,11 +278,13 @@ def grouping_indices(self, value):
# make sure all tuples so can be hashable
self._grouping_indices = tuple(tuple(sublist) for sublist in value)
+ @handle_recursion_error
def __str__(self):
"""String representation of the Sum."""
ops = self.operands
return " + ".join(f"{str(op)}" if i == 0 else f"{str(op)}" for i, op in enumerate(ops))
+ @handle_recursion_error
def __repr__(self):
"""Terminal representation for Sum"""
# post-processing the flat str() representation
@@ -293,6 +296,7 @@ def __repr__(self):
return main_string
@property
+ @handle_recursion_error
def is_hermitian(self):
"""If all of the terms in the sum are hermitian, then the Sum is hermitian."""
if self.pauli_rep is not None:
@@ -304,10 +308,12 @@ def is_hermitian(self):
return all(s.is_hermitian for s in self)
+ @handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
decimals = None if (len(self.parameters) > 3) else decimals
return Operator.label(self, decimals=decimals, base_label=base_label or "𝓗", cache=cache)
+ @handle_recursion_error
def matrix(self, wire_order=None):
r"""Representation of the operator as a matrix in the computational basis.
@@ -344,9 +350,11 @@ def matrix(self, wire_order=None):
# pylint: disable=arguments-renamed, invalid-overridden-method
@property
+ @handle_recursion_error
def has_sparse_matrix(self) -> bool:
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)
+ @handle_recursion_error
def sparse_matrix(self, wire_order=None):
if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation
return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr")
@@ -417,6 +425,7 @@ def _simplify_summands(cls, summands: list[Operator]):
return new_summands
+ @handle_recursion_error
def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ
# try using pauli_rep:
if pr := self.pauli_rep:
@@ -428,6 +437,7 @@ def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ
return Sum(*new_summands) if len(new_summands) > 1 else new_summands[0]
return qml.s_prod(0, qml.Identity(self.wires))
+ @handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
diff --git a/pennylane/ops/op_math/symbolicop.py b/pennylane/ops/op_math/symbolicop.py
index 24f81e52e93..304684d06e5 100644
--- a/pennylane/ops/op_math/symbolicop.py
+++ b/pennylane/ops/op_math/symbolicop.py
@@ -23,6 +23,8 @@
from pennylane.operation import _UNSET_BATCH_SIZE, Operator
from pennylane.queuing import QueuingManager
+from .composite import handle_recursion_error
+
class SymbolicOp(Operator):
"""Developer-facing base class for single-operator symbolic operators.
@@ -53,6 +55,7 @@ def _primitive_bind_call(cls, *args, **kwargs):
return cls._primitive.bind(*args, **kwargs)
# pylint: disable=attribute-defined-outside-init
+ @handle_recursion_error
def __copy__(self):
# this method needs to be overwritten because the base must be copied too.
copied_op = object.__new__(type(self))
@@ -98,11 +101,13 @@ def num_params(self):
return self.base.num_params
@property
+ @handle_recursion_error
def wires(self):
return self.base.wires
# pylint:disable = missing-function-docstring
@property
+ @handle_recursion_error
def basis(self):
return self.base.basis
@@ -130,6 +135,7 @@ def queue(self, context=QueuingManager):
return self
@property
+ @handle_recursion_error
def arithmetic_depth(self) -> int:
return 1 + self.base.arithmetic_depth
@@ -142,6 +148,7 @@ def hash(self):
)
)
+ @handle_recursion_error
def map_wires(self, wire_map: dict):
new_op = copy(self)
new_op.hyperparameters["base"] = self.base.map_wires(wire_map=wire_map)
@@ -172,6 +179,7 @@ def __init__(self, base, scalar: float, id=None):
self._batch_size = _UNSET_BATCH_SIZE
@property
+ @handle_recursion_error
def batch_size(self):
if self._batch_size is _UNSET_BATCH_SIZE:
base_batch_size = self.base.batch_size
@@ -190,6 +198,7 @@ def batch_size(self):
return self._batch_size
@property
+ @handle_recursion_error
def data(self):
return (self.scalar, *self.base.data)
@@ -199,10 +208,12 @@ def data(self, new_data):
self.base.data = new_data[1:]
@property
+ @handle_recursion_error
def has_matrix(self):
return self.base.has_matrix or isinstance(self.base, qml.ops.Hamiltonian)
@property
+ @handle_recursion_error
def hash(self):
return hash(
(
@@ -225,6 +236,7 @@ def _matrix(scalar, mat):
mat (ndarray): non-broadcasted matrix
"""
+ @handle_recursion_error
def matrix(self, wire_order=None):
r"""Representation of the operator as a matrix in the computational basis.
diff --git a/tests/ops/op_math/test_composite.py b/tests/ops/op_math/test_composite.py
index 777665ffce1..b417e248978 100644
--- a/tests/ops/op_math/test_composite.py
+++ b/tests/ops/op_math/test_composite.py
@@ -14,6 +14,8 @@
"""
Unit tests for the composite operator class of qubit operations
"""
+import inspect
+
# pylint:disable=protected-access
from copy import copy
@@ -233,6 +235,68 @@ def test_tensor_and_hamiltonian_converted(self):
)
+@pytest.mark.parametrize("math_op", [qml.prod, qml.sum])
+def test_no_recursion_error_raised(math_op):
+ """Tests that no RecursionError is raised from any property of method of a nested op."""
+
+ op = qml.RX(np.random.uniform(0, 2 * np.pi), wires=1)
+ for _ in range(2000):
+ op = math_op(op, qml.RY(np.random.uniform(0, 2 * np.pi), wires=1))
+ _assert_method_and_property_no_recursion_error(op)
+
+
+def test_no_recursion_error_raised_sprod():
+ """Tests that no RecursionError is raised from any property of method of a nested SProd."""
+
+ op = qml.RX(np.random.uniform(0, 2 * np.pi), wires=1)
+ for _ in range(5000):
+ op = qml.s_prod(1, op)
+ _assert_method_and_property_no_recursion_error(op)
+
+
+def _assert_method_and_property_no_recursion_error(instance):
+ """Checks that all methods and properties do not raise a RecursionError when accessed."""
+
+ for name, attr in inspect.getmembers(instance.__class__):
+
+ if inspect.isfunction(attr) and _is_method_with_no_argument(attr):
+ _assert_method_no_recursion_error(instance, name)
+
+ if isinstance(attr, property):
+ _assert_property_no_recursion_error(instance, name)
+
+
+def _assert_method_no_recursion_error(instance, method_name):
+ """Checks that the method does not raise a RecursionError when called."""
+ try:
+ getattr(instance, method_name)()
+ except Exception as e: # pylint: disable=broad-except
+ assert not isinstance(e, RecursionError)
+ if isinstance(e, RuntimeError):
+ assert "This is likely due to nesting too many levels" in str(e)
+
+
+def _assert_property_no_recursion_error(instance, property_name):
+ """Checks that the property does not raise a RecursionError when accessed."""
+ try:
+ getattr(instance, property_name)
+ except Exception as e: # pylint: disable=broad-except
+ assert not isinstance(e, RecursionError)
+ if isinstance(e, RuntimeError):
+ assert "This is likely due to nesting too many levels" in str(e)
+
+
+def _is_method_with_no_argument(method):
+ """Checks if a method has no argument other than self."""
+ parameters = list(inspect.signature(method).parameters.values())
+ if not (parameters and parameters[0].name == "self"):
+ return False
+ for param in parameters[1:]:
+ if param.kind is not param.POSITIONAL_OR_KEYWORD or param.default == param.empty:
+ return False
+ return True
+
+
class TestMscMethods:
"""Test dunder and other visualizing methods."""