Skip to content

Commit

Permalink
Handle recursion errors in CompositeOp methods (#6375)
Browse files Browse the repository at this point in the history
By default, `qml.sum` and `qml.prod` set `lazy=True`, which keeps its
operands nested. Given the recursive nature of such structures, if there
are too many levels of nesting, a `RecursionError` would occur when
accessing many of the properties and methods. Catching these errors and
re-raising a more sensible error message suggesting that the user could
either set `lazy=False` or use the `+` and `@` operators instead, which
sets `lazy=False`.

**Update**
Extending the behaviour to `SProd`

Fixes #5948
[sc-67745]
  • Loading branch information
astralcai authored and mudit2812 committed Nov 11, 2024
1 parent d6f3556 commit c293405
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 6 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@
unprocessed diagonalizing gates.
[(#6290)](https://github.com/PennyLaneAI/pennylane/pull/6290)

* A more sensible error message is raised from a `RecursionError` encountered when accessing properties and methods of a nested `CompositeOp` or `SProd`.
[(#6375)](https://github.com/PennyLaneAI/pennylane/pull/6375)

<h4>Capturing and representing hybrid programs</h4>

* `qml.wires.Wires` now accepts JAX arrays as input. Furthermore, a `FutureWarning` is no longer raised in `JAX 0.4.30+`
Expand Down
29 changes: 29 additions & 0 deletions pennylane/ops/op_math/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import abc
import copy
from collections.abc import Callable
from functools import wraps

import pennylane as qml
from pennylane import math
Expand All @@ -27,6 +28,24 @@
# pylint: disable=too-many-instance-attributes


def handle_recursion_error(func):
"""Handles any recursion errors raised from too many levels of nesting."""

@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except RecursionError as e:
raise RuntimeError(
"Maximum recursion depth reached! This is likely due to nesting too many levels "
"of composite operators. Try setting lazy=False when calling qml.sum, qml.prod, "
"and qml.s_prod, or use the +, @, and * operators instead. Alternatively, you "
"can periodically call qml.simplify on your operators."
) from e

return wrapper


class CompositeOp(Operator):
"""A base class for operators that are composed of other operators.
Expand Down Expand Up @@ -70,6 +89,7 @@ def __init__(
self.queue()
self._batch_size = _UNSET_BATCH_SIZE

@handle_recursion_error
def _check_batching(self):
batch_sizes = {op.batch_size for op in self if op.batch_size is not None}
if len(batch_sizes) > 1:
Expand All @@ -84,6 +104,7 @@ def __repr__(self):
[f"({op})" if op.arithmetic_depth > 0 else f"{op}" for op in self]
)

@handle_recursion_error
def __copy__(self):
cls = self.__class__
copied_op = cls.__new__(cls)
Expand Down Expand Up @@ -113,6 +134,7 @@ def _op_symbol(self) -> str:
"""The symbol used when visualizing the composite operator"""

@property
@handle_recursion_error
def data(self):
"""Create data property"""
return tuple(d for op in self for d in op.data)
Expand All @@ -132,6 +154,7 @@ def num_wires(self):
return len(self.wires)

@property
@handle_recursion_error
def num_params(self):
return sum(op.num_params for op in self)

Expand All @@ -152,9 +175,11 @@ def is_hermitian(self):

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
@handle_recursion_error
def has_matrix(self):
return all(op.has_matrix or isinstance(op, qml.ops.Hamiltonian) for op in self)

@handle_recursion_error
def eigvals(self):
"""Return the eigenvalues of the specified operator.
Expand Down Expand Up @@ -290,6 +315,7 @@ def diagonalizing_gates(self):
)
return diag_gates

@handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
r"""How the composite operator is represented in diagrams and drawings.
Expand Down Expand Up @@ -344,6 +370,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:
"""Sort composite operands by their wire indices."""

@property
@handle_recursion_error
def hash(self):
if self._hash is None:
self._hash = hash(
Expand All @@ -357,6 +384,7 @@ def basis(self):
return None

@property
@handle_recursion_error
def arithmetic_depth(self) -> int:
return 1 + max(op.arithmetic_depth for op in self)

Expand All @@ -365,6 +393,7 @@ def arithmetic_depth(self) -> int:
def _math_op(self) -> Callable:
"""The function used when combining the operands of the composite operator"""

@handle_recursion_error
def map_wires(self, wire_map: dict):
# pylint:disable=protected-access
cls = self.__class__
Expand Down
12 changes: 7 additions & 5 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pennylane.queuing import QueuingManager
from pennylane.typing import TensorLike

from .composite import CompositeOp
from .composite import CompositeOp, handle_recursion_error

MAX_NUM_WIRES_KRON_PRODUCT = 9
"""The maximum number of wires up to which using ``math.kron`` is faster than ``math.dot`` for
Expand Down Expand Up @@ -273,6 +273,7 @@ def decomposition(self):
return [qml.apply(op) for op in self[::-1]]
return list(self[::-1])

@handle_recursion_error
def matrix(self, wire_order=None):
"""Representation of the operator as a matrix in the computational basis."""
if self.pauli_rep:
Expand Down Expand Up @@ -309,6 +310,7 @@ def matrix(self, wire_order=None):
)
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)

@handle_recursion_error
def sparse_matrix(self, wire_order=None):
if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation
return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr")
Expand All @@ -326,11 +328,13 @@ def sparse_matrix(self, wire_order=None):
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)

@property
@handle_recursion_error
def has_sparse_matrix(self):
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)

# pylint: disable=protected-access
@property
@handle_recursion_error
def _queue_category(self):
"""Used for sorting objects into their respective lists in `QuantumTape` objects.
This property is a temporary solution that should not exist long-term and should not be
Expand All @@ -353,10 +357,6 @@ def has_adjoint(self):
def adjoint(self):
return Prod(*(qml.adjoint(factor) for factor in self[::-1]))

@property
def arithmetic_depth(self) -> int:
return 1 + max(factor.arithmetic_depth for factor in self)

def _build_pauli_rep(self):
"""PauliSentence representation of the Product of operations."""
if all(operand_pauli_reps := [op.pauli_rep for op in self.operands]):
Expand All @@ -378,6 +378,7 @@ def _simplify_factors(self, factors: tuple[Operator]) -> tuple[complex, Operator
new_factors.remove_factors(wires=self.wires)
return new_factors.global_phase, new_factors.factors

@handle_recursion_error
def simplify(self) -> Union["Prod", Sum]:
r"""
Transforms any nested Prod instance into the form :math:`\sum c_i O_i` where
Expand Down Expand Up @@ -432,6 +433,7 @@ def _sort(cls, op_list, wire_map: dict = None) -> list[Operator]:

return op_list

@handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
Expand Down
14 changes: 14 additions & 0 deletions pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from pennylane.ops.op_math.sum import Sum
from pennylane.queuing import QueuingManager

from .composite import handle_recursion_error
from .symbolicop import ScalarSymbolicOp


Expand Down Expand Up @@ -155,12 +156,14 @@ def __init__(
else:
self._pauli_rep = None

@handle_recursion_error
def __repr__(self):
"""Constructor-call-like representation."""
if isinstance(self.base, qml.ops.CompositeOp):
return f"{self.scalar} * ({self.base})"
return f"{self.scalar} * {self.base}"

@handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
"""The label produced for the SProd op."""
scalar_val = (
Expand All @@ -172,6 +175,7 @@ def label(self, decimals=None, base_label=None, cache=None):
return base_label or f"{scalar_val}*{self.base.label(decimals=decimals, cache=cache)}"

@property
@handle_recursion_error
def num_params(self):
"""Number of trainable parameters that the operator depends on.
Usually 1 + the number of trainable parameters for the base op.
Expand All @@ -181,6 +185,7 @@ def num_params(self):
"""
return 1 + self.base.num_params

@handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
Expand All @@ -200,17 +205,20 @@ def terms(self):
return [self.scalar], [self.base]

@property
@handle_recursion_error
def is_hermitian(self):
"""If the base operator is hermitian and the scalar is real,
then the scalar product operator is hermitian."""
return self.base.is_hermitian and not qml.math.iscomplex(self.scalar)

# pylint: disable=arguments-renamed,invalid-overridden-method
@property
@handle_recursion_error
def has_diagonalizing_gates(self):
"""Bool: Whether the Operator returns defined diagonalizing gates."""
return self.base.has_diagonalizing_gates

@handle_recursion_error
def diagonalizing_gates(self):
r"""Sequence of gates that diagonalize the operator in the computational basis.
Expand All @@ -230,6 +238,7 @@ def diagonalizing_gates(self):
"""
return self.base.diagonalizing_gates()

@handle_recursion_error
def eigvals(self):
r"""Return the eigenvalues of the specified operator.
Expand All @@ -244,6 +253,7 @@ def eigvals(self):
base_eigs = qml.math.convert_like(base_eigs, self.scalar)
return self.scalar * base_eigs

@handle_recursion_error
def sparse_matrix(self, wire_order=None):
"""Computes, by default, a `scipy.sparse.csr_matrix` representation of this Tensor.
Expand All @@ -264,15 +274,18 @@ def sparse_matrix(self, wire_order=None):
return mat

@property
@handle_recursion_error
def has_sparse_matrix(self):
return self.pauli_rep is not None or self.base.has_sparse_matrix

@property
@handle_recursion_error
def has_matrix(self):
"""Bool: Whether or not the Operator returns a defined matrix."""
return isinstance(self.base, qml.ops.Hamiltonian) or self.base.has_matrix

@staticmethod
@handle_recursion_error
def _matrix(scalar, mat):
return scalar * mat

Expand Down Expand Up @@ -303,6 +316,7 @@ def adjoint(self):
return SProd(scalar=qml.math.conjugate(self.scalar), base=qml.adjoint(self.base))

# pylint: disable=too-many-return-statements
@handle_recursion_error
def simplify(self) -> Operator:
"""Reduce the depth of nested operators to the minimum.
Expand Down
12 changes: 11 additions & 1 deletion pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pennylane.operation import Operator, convert_to_opmath
from pennylane.queuing import QueuingManager

from .composite import CompositeOp
from .composite import CompositeOp, handle_recursion_error


def sum(*summands, grouping_type=None, method="rlf", id=None, lazy=True):
Expand Down Expand Up @@ -238,6 +238,7 @@ def __init__(
self.compute_grouping(grouping_type=grouping_type, method=method)

@property
@handle_recursion_error
def hash(self):
# Since addition is always commutative, we do not need to sort
return hash(("Sum", hash(frozenset(Counter(self.operands).items()))))
Expand Down Expand Up @@ -277,11 +278,13 @@ def grouping_indices(self, value):
# make sure all tuples so can be hashable
self._grouping_indices = tuple(tuple(sublist) for sublist in value)

@handle_recursion_error
def __str__(self):
"""String representation of the Sum."""
ops = self.operands
return " + ".join(f"{str(op)}" if i == 0 else f"{str(op)}" for i, op in enumerate(ops))

@handle_recursion_error
def __repr__(self):
"""Terminal representation for Sum"""
# post-processing the flat str() representation
Expand All @@ -293,6 +296,7 @@ def __repr__(self):
return main_string

@property
@handle_recursion_error
def is_hermitian(self):
"""If all of the terms in the sum are hermitian, then the Sum is hermitian."""
if self.pauli_rep is not None:
Expand All @@ -304,10 +308,12 @@ def is_hermitian(self):

return all(s.is_hermitian for s in self)

@handle_recursion_error
def label(self, decimals=None, base_label=None, cache=None):
decimals = None if (len(self.parameters) > 3) else decimals
return Operator.label(self, decimals=decimals, base_label=base_label or "𝓗", cache=cache)

@handle_recursion_error
def matrix(self, wire_order=None):
r"""Representation of the operator as a matrix in the computational basis.
Expand Down Expand Up @@ -344,9 +350,11 @@ def matrix(self, wire_order=None):

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
@handle_recursion_error
def has_sparse_matrix(self) -> bool:
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)

@handle_recursion_error
def sparse_matrix(self, wire_order=None):
if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation
return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr")
Expand Down Expand Up @@ -417,6 +425,7 @@ def _simplify_summands(cls, summands: list[Operator]):

return new_summands

@handle_recursion_error
def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ
# try using pauli_rep:
if pr := self.pauli_rep:
Expand All @@ -428,6 +437,7 @@ def simplify(self, cutoff=1.0e-12) -> "Sum": # pylint: disable=arguments-differ
return Sum(*new_summands) if len(new_summands) > 1 else new_summands[0]
return qml.s_prod(0, qml.Identity(self.wires))

@handle_recursion_error
def terms(self):
r"""Representation of the operator as a linear combination of other operators.
Expand Down
Loading

0 comments on commit c293405

Please sign in to comment.