diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1804494a8ae..e077b96306b 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -85,6 +85,9 @@ unprocessed diagonalizing gates. [(#6290)](https://github.com/PennyLaneAI/pennylane/pull/6290) +* A more sensible error message is raised from a `RecursionError` encountered when accessing properties and methods of a nested `CompositeOp` or `SProd`. + [(#6375)](https://github.com/PennyLaneAI/pennylane/pull/6375) +

Capturing and representing hybrid programs

* `qml.wires.Wires` now accepts JAX arrays as input. Furthermore, a `FutureWarning` is no longer raised in `JAX 0.4.30+` diff --git a/pennylane/ops/op_math/composite.py b/pennylane/ops/op_math/composite.py index 4b0b75f905f..51dd94fb605 100644 --- a/pennylane/ops/op_math/composite.py +++ b/pennylane/ops/op_math/composite.py @@ -18,6 +18,7 @@ import abc import copy from collections.abc import Callable +from functools import wraps import pennylane as qml from pennylane import math @@ -27,6 +28,24 @@ # pylint: disable=too-many-instance-attributes +def handle_recursion_error(func): + """Handles any recursion errors raised from too many levels of nesting.""" + + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except RecursionError as e: + raise RuntimeError( + "Maximum recursion depth reached! This is likely due to nesting too many levels " + "of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, " + "and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you " + "can periodically call qml.simplify on your operators." + ) from e + + return wrapper + + class CompositeOp(Operator): """A base class for operators that are composed of other operators. @@ -70,6 +89,7 @@ def __init__( self.queue() self._batch_size = _UNSET_BATCH_SIZE + @handle_recursion_error def _check_batching(self): batch_sizes = {op.batch_size for op in self if op.batch_size is not None} if len(batch_sizes) > 1: @@ -84,6 +104,7 @@ def __repr__(self): [f"({op})" if op.arithmetic_depth > 0 else f"{op}" for op in self] ) + @handle_recursion_error def __copy__(self): cls = self.__class__ copied_op = cls.__new__(cls) @@ -113,6 +134,7 @@ def _op_symbol(self) -> str: """The symbol used when visualizing the composite operator""" @property + @handle_recursion_error def data(self): """Create data property""" return tuple(d for op in self for d in op.data) @@ -132,6 +154,7 @@ def num_wires(self): return len(self.wires) @property + @handle_recursion_error def num_params(self): return sum(op.num_params for op in self) @@ -152,9 +175,11 @@ def is_hermitian(self): # pylint: disable=arguments-renamed, invalid-overridden-method @property + @handle_recursion_error def has_matrix(self): return all(op.has_matrix or isinstance(op, qml.ops.Hamiltonian) for op in self) + @handle_recursion_error def eigvals(self): """Return the eigenvalues of the specified operator. @@ -290,6 +315,7 @@ def diagonalizing_gates(self): ) return diag_gates + @handle_recursion_error def label(self, decimals=None, base_label=None, cache=None): r"""How the composite operator is represented in diagrams and drawings. @@ -344,6 +370,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]: """Sort composite operands by their wire indices.""" @property + @handle_recursion_error def hash(self): if self._hash is None: self._hash = hash( @@ -357,6 +384,7 @@ def basis(self): return None @property + @handle_recursion_error def arithmetic_depth(self) -> int: return 1 + max(op.arithmetic_depth for op in self) @@ -365,6 +393,7 @@ def arithmetic_depth(self) -> int: def _math_op(self) -> Callable: """The function used when combining the operands of the composite operator""" + @handle_recursion_error def map_wires(self, wire_map: dict): # pylint:disable=protected-access cls = self.__class__ diff --git a/pennylane/ops/op_math/prod.py b/pennylane/ops/op_math/prod.py index 29770c40c87..7190142f2d3 100644 --- a/pennylane/ops/op_math/prod.py +++ b/pennylane/ops/op_math/prod.py @@ -34,7 +34,7 @@ from pennylane.queuing import QueuingManager from pennylane.typing import TensorLike -from .composite import CompositeOp +from .composite import CompositeOp, handle_recursion_error MAX_NUM_WIRES_KRON_PRODUCT = 9 """The maximum number of wires up to which using ``math.kron`` is faster than ``math.dot`` for @@ -273,6 +273,7 @@ def decomposition(self): return [qml.apply(op) for op in self[::-1]] return list(self[::-1]) + @handle_recursion_error def matrix(self, wire_order=None): """Representation of the operator as a matrix in the computational basis.""" if self.pauli_rep: @@ -309,6 +310,7 @@ def matrix(self, wire_order=None): ) return math.expand_matrix(full_mat, self.wires, wire_order=wire_order) + @handle_recursion_error def sparse_matrix(self, wire_order=None): if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr") @@ -326,11 +328,13 @@ def sparse_matrix(self, wire_order=None): return math.expand_matrix(full_mat, self.wires, wire_order=wire_order) @property + @handle_recursion_error def has_sparse_matrix(self): return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self) # pylint: disable=protected-access @property + @handle_recursion_error def _queue_category(self): """Used for sorting objects into their respective lists in `QuantumTape` objects. This property is a temporary solution that should not exist long-term and should not be @@ -353,10 +357,6 @@ def has_adjoint(self): def adjoint(self): return Prod(*(qml.adjoint(factor) for factor in self[::-1])) - @property - def arithmetic_depth(self) -> int: - return 1 + max(factor.arithmetic_depth for factor in self) - def _build_pauli_rep(self): """PauliSentence representation of the Product of operations.""" if all(operand_pauli_reps := [op.pauli_rep for op in self.operands]): @@ -378,6 +378,7 @@ def _simplify_factors(self, factors: tuple[Operator]) -> tuple[complex, Operator new_factors.remove_factors(wires=self.wires) return new_factors.global_phase, new_factors.factors + @handle_recursion_error def simplify(self) -> Union["Prod", Sum]: r""" Transforms any nested Prod instance into the form :math:`\sum c_i O_i` where @@ -432,6 +433,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]: return op_list + @handle_recursion_error def terms(self): r"""Representation of the operator as a linear combination of other operators. diff --git a/pennylane/ops/op_math/sprod.py b/pennylane/ops/op_math/sprod.py index 64fcdbdfd72..d73b29fce54 100644 --- a/pennylane/ops/op_math/sprod.py +++ b/pennylane/ops/op_math/sprod.py @@ -25,6 +25,7 @@ from pennylane.ops.op_math.sum import Sum from pennylane.queuing import QueuingManager +from .composite import handle_recursion_error from .symbolicop import ScalarSymbolicOp @@ -155,12 +156,14 @@ def __init__( else: self._pauli_rep = None + @handle_recursion_error def __repr__(self): """Constructor-call-like representation.""" if isinstance(self.base, qml.ops.CompositeOp): return f"{self.scalar} * ({self.base})" return f"{self.scalar} * {self.base}" + @handle_recursion_error def label(self, decimals=None, base_label=None, cache=None): """The label produced for the SProd op.""" scalar_val = ( @@ -172,6 +175,7 @@ def label(self, decimals=None, base_label=None, cache=None): return base_label or f"{scalar_val}*{self.base.label(decimals=decimals, cache=cache)}" @property + @handle_recursion_error def num_params(self): """Number of trainable parameters that the operator depends on. Usually 1 + the number of trainable parameters for the base op. @@ -181,6 +185,7 @@ def num_params(self): """ return 1 + self.base.num_params + @handle_recursion_error def terms(self): r"""Representation of the operator as a linear combination of other operators. @@ -200,6 +205,7 @@ def terms(self): return [self.scalar], [self.base] @property + @handle_recursion_error def is_hermitian(self): """If the base operator is hermitian and the scalar is real, then the scalar product operator is hermitian.""" @@ -207,10 +213,12 @@ def is_hermitian(self): # pylint: disable=arguments-renamed,invalid-overridden-method @property + @handle_recursion_error def has_diagonalizing_gates(self): """Bool: Whether the Operator returns defined diagonalizing gates.""" return self.base.has_diagonalizing_gates + @handle_recursion_error def diagonalizing_gates(self): r"""Sequence of gates that diagonalize the operator in the computational basis. @@ -230,6 +238,7 @@ def diagonalizing_gates(self): """ return self.base.diagonalizing_gates() + @handle_recursion_error def eigvals(self): r"""Return the eigenvalues of the specified operator. @@ -244,6 +253,7 @@ def eigvals(self): base_eigs = qml.math.convert_like(base_eigs, self.scalar) return self.scalar * base_eigs + @handle_recursion_error def sparse_matrix(self, wire_order=None): """Computes, by default, a `scipy.sparse.csr_matrix` representation of this Tensor. @@ -264,15 +274,18 @@ def sparse_matrix(self, wire_order=None): return mat @property + @handle_recursion_error def has_sparse_matrix(self): return self.pauli_rep is not None or self.base.has_sparse_matrix @property + @handle_recursion_error def has_matrix(self): """Bool: Whether or not the Operator returns a defined matrix.""" return isinstance(self.base, qml.ops.Hamiltonian) or self.base.has_matrix @staticmethod + @handle_recursion_error def _matrix(scalar, mat): return scalar * mat @@ -303,6 +316,7 @@ def adjoint(self): return SProd(scalar=qml.math.conjugate(self.scalar), base=qml.adjoint(self.base)) # pylint: disable=too-many-return-statements + @handle_recursion_error def simplify(self) -> Operator: """Reduce the depth of nested operators to the minimum. diff --git a/pennylane/ops/op_math/sum.py b/pennylane/ops/op_math/sum.py index 40ed5f229c3..58590480f40 100644 --- a/pennylane/ops/op_math/sum.py +++ b/pennylane/ops/op_math/sum.py @@ -28,7 +28,7 @@ from pennylane.operation import Operator, convert_to_opmath from pennylane.queuing import QueuingManager -from .composite import CompositeOp +from .composite import CompositeOp, handle_recursion_error def sum(*summands, grouping_type=None, method="rlf", id=None, lazy=True): @@ -238,6 +238,7 @@ def __init__( self.compute_grouping(grouping_type=grouping_type, method=method) @property + @handle_recursion_error def hash(self): # Since addition is always commutative, we do not need to sort return hash(("Sum", hash(frozenset(Counter(self.operands).items())))) @@ -277,11 +278,13 @@ def grouping_indices(self, value): # make sure all tuples so can be hashable self._grouping_indices = tuple(tuple(sublist) for sublist in value) + @handle_recursion_error def __str__(self): """String representation of the Sum.""" ops = self.operands return " + ".join(f"{str(op)}" if i == 0 else f"{str(op)}" for i, op in enumerate(ops)) + @handle_recursion_error def __repr__(self): """Terminal representation for Sum""" # post-processing the flat str() representation @@ -293,6 +296,7 @@ def __repr__(self): return main_string @property + @handle_recursion_error def is_hermitian(self): """If all of the terms in the sum are hermitian, then the Sum is hermitian.""" if self.pauli_rep is not None: @@ -304,10 +308,12 @@ def is_hermitian(self): return all(s.is_hermitian for s in self) + @handle_recursion_error def label(self, decimals=None, base_label=None, cache=None): decimals = None if (len(self.parameters) > 3) else decimals return Operator.label(self, decimals=decimals, base_label=base_label or "𝓗", cache=cache) + @handle_recursion_error def matrix(self, wire_order=None): r"""Representation of the operator as a matrix in the computational basis. @@ -344,9 +350,11 @@ def matrix(self, wire_order=None): # pylint: disable=arguments-renamed, invalid-overridden-method @property + @handle_recursion_error def has_sparse_matrix(self) -> bool: return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self) + @handle_recursion_error def sparse_matrix(self, wire_order=None): if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr") @@ -417,6 +425,7 @@ def _simplify_summands(cls, summands: list[Operator]): return new_summands + @handle_recursion_error def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ # try using pauli_rep: if pr := self.pauli_rep: @@ -428,6 +437,7 @@ def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ return Sum(*new_summands) if len(new_summands) > 1 else new_summands[0] return qml.s_prod(0, qml.Identity(self.wires)) + @handle_recursion_error def terms(self): r"""Representation of the operator as a linear combination of other operators. diff --git a/pennylane/ops/op_math/symbolicop.py b/pennylane/ops/op_math/symbolicop.py index 24f81e52e93..304684d06e5 100644 --- a/pennylane/ops/op_math/symbolicop.py +++ b/pennylane/ops/op_math/symbolicop.py @@ -23,6 +23,8 @@ from pennylane.operation import _UNSET_BATCH_SIZE, Operator from pennylane.queuing import QueuingManager +from .composite import handle_recursion_error + class SymbolicOp(Operator): """Developer-facing base class for single-operator symbolic operators. @@ -53,6 +55,7 @@ def _primitive_bind_call(cls, *args, **kwargs): return cls._primitive.bind(*args, **kwargs) # pylint: disable=attribute-defined-outside-init + @handle_recursion_error def __copy__(self): # this method needs to be overwritten because the base must be copied too. copied_op = object.__new__(type(self)) @@ -98,11 +101,13 @@ def num_params(self): return self.base.num_params @property + @handle_recursion_error def wires(self): return self.base.wires # pylint:disable = missing-function-docstring @property + @handle_recursion_error def basis(self): return self.base.basis @@ -130,6 +135,7 @@ def queue(self, context=QueuingManager): return self @property + @handle_recursion_error def arithmetic_depth(self) -> int: return 1 + self.base.arithmetic_depth @@ -142,6 +148,7 @@ def hash(self): ) ) + @handle_recursion_error def map_wires(self, wire_map: dict): new_op = copy(self) new_op.hyperparameters["base"] = self.base.map_wires(wire_map=wire_map) @@ -172,6 +179,7 @@ def __init__(self, base, scalar: float, id=None): self._batch_size = _UNSET_BATCH_SIZE @property + @handle_recursion_error def batch_size(self): if self._batch_size is _UNSET_BATCH_SIZE: base_batch_size = self.base.batch_size @@ -190,6 +198,7 @@ def batch_size(self): return self._batch_size @property + @handle_recursion_error def data(self): return (self.scalar, *self.base.data) @@ -199,10 +208,12 @@ def data(self, new_data): self.base.data = new_data[1:] @property + @handle_recursion_error def has_matrix(self): return self.base.has_matrix or isinstance(self.base, qml.ops.Hamiltonian) @property + @handle_recursion_error def hash(self): return hash( ( @@ -225,6 +236,7 @@ def _matrix(scalar, mat): mat (ndarray): non-broadcasted matrix """ + @handle_recursion_error def matrix(self, wire_order=None): r"""Representation of the operator as a matrix in the computational basis. diff --git a/tests/ops/op_math/test_composite.py b/tests/ops/op_math/test_composite.py index 777665ffce1..b417e248978 100644 --- a/tests/ops/op_math/test_composite.py +++ b/tests/ops/op_math/test_composite.py @@ -14,6 +14,8 @@ """ Unit tests for the composite operator class of qubit operations """ +import inspect + # pylint:disable=protected-access from copy import copy @@ -233,6 +235,68 @@ def test_tensor_and_hamiltonian_converted(self): ) +@pytest.mark.parametrize("math_op", [qml.prod, qml.sum]) +def test_no_recursion_error_raised(math_op): + """Tests that no RecursionError is raised from any property of method of a nested op.""" + + op = qml.RX(np.random.uniform(0, 2 * np.pi), wires=1) + for _ in range(2000): + op = math_op(op, qml.RY(np.random.uniform(0, 2 * np.pi), wires=1)) + _assert_method_and_property_no_recursion_error(op) + + +def test_no_recursion_error_raised_sprod(): + """Tests that no RecursionError is raised from any property of method of a nested SProd.""" + + op = qml.RX(np.random.uniform(0, 2 * np.pi), wires=1) + for _ in range(5000): + op = qml.s_prod(1, op) + _assert_method_and_property_no_recursion_error(op) + + +def _assert_method_and_property_no_recursion_error(instance): + """Checks that all methods and properties do not raise a RecursionError when accessed.""" + + for name, attr in inspect.getmembers(instance.__class__): + + if inspect.isfunction(attr) and _is_method_with_no_argument(attr): + _assert_method_no_recursion_error(instance, name) + + if isinstance(attr, property): + _assert_property_no_recursion_error(instance, name) + + +def _assert_method_no_recursion_error(instance, method_name): + """Checks that the method does not raise a RecursionError when called.""" + try: + getattr(instance, method_name)() + except Exception as e: # pylint: disable=broad-except + assert not isinstance(e, RecursionError) + if isinstance(e, RuntimeError): + assert "This is likely due to nesting too many levels" in str(e) + + +def _assert_property_no_recursion_error(instance, property_name): + """Checks that the property does not raise a RecursionError when accessed.""" + try: + getattr(instance, property_name) + except Exception as e: # pylint: disable=broad-except + assert not isinstance(e, RecursionError) + if isinstance(e, RuntimeError): + assert "This is likely due to nesting too many levels" in str(e) + + +def _is_method_with_no_argument(method): + """Checks if a method has no argument other than self.""" + parameters = list(inspect.signature(method).parameters.values()) + if not (parameters and parameters[0].name == "self"): + return False + for param in parameters[1:]: + if param.kind is not param.POSITIONAL_OR_KEYWORD or param.default == param.empty: + return False + return True + + class TestMscMethods: """Test dunder and other visualizing methods."""