diff --git a/doc/_static/templates/subroutines/trotter_product.png b/doc/_static/templates/subroutines/trotter_product.png
new file mode 100644
index 00000000000..b5fb79ff6ad
Binary files /dev/null and b/doc/_static/templates/subroutines/trotter_product.png differ
diff --git a/doc/introduction/templates.rst b/doc/introduction/templates.rst
index dbde3cfcddf..26ee024653c 100644
--- a/doc/introduction/templates.rst
+++ b/doc/introduction/templates.rst
@@ -227,6 +227,10 @@ Other useful templates which do not belong to the previous categories can be fou
:description: :doc:`ApproxTimeEvolution <../code/api/pennylane.ApproxTimeEvolution>`
:figure: _static/templates/subroutines/approx_time_evolution.png
+.. gallery-item::
+ :description: :doc:`TrotterProduct <../code/api/pennylane.TrotterProduct>`
+ :figure: _static/templates/subroutines/trotter_product.png
+
.. gallery-item::
:description: :doc:`Permute <../code/api/pennylane.Permute>`
:figure: _static/templates/subroutines/permute.png
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 3f4e6a3ae58..7fb363fdcd7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -4,6 +4,41 @@
New features since last release
+Exponentiate Hamiltonians with flexible Trotter products 🐖
+
+* Higher-order Trotter-Suzuki methods are now easily accessible through a new operation
+ called `TrotterProduct`.
+ [(#4661)](https://github.com/PennyLaneAI/pennylane/pull/4661)
+
+ Trotterization techniques are an affective route towards accurate and efficient
+ Hamiltonian simulation. The Suzuki-Trotter product formula allows for the ability
+ to express higher-order approximations to the matrix exponential of a Hamiltonian,
+ and it is now available to use in PennyLane via the `TrotterProduct` operation.
+ Simply specify the `order` of the approximation and the evolution `time`.
+
+ ```python
+ coeffs = [0.25, 0.75]
+ ops = [qml.PauliX(0), qml.PauliZ(0)]
+ H = qml.dot(coeffs, ops)
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.Hadamard(0)
+ qml.TrotterProduct(H, time=2.4, order=2)
+ return qml.state()
+ ```
+
+ ```pycon
+ >>> circuit()
+ [-0.13259524+0.59790098j 0. +0.j -0.13259524-0.77932754j 0. +0.j ]
+ ```
+
+ The already-available `ApproxTimeEvolution` operation represents the special case of `order=1`.
+ It is recommended to switch over to use of `TrotterProduct` because `ApproxTimeEvolution` will be
+ deprecated and removed in upcoming releases.
+
* Support drawing QJIT QNode from Catalyst.
[(#4609)](https://github.com/PennyLaneAI/pennylane/pull/4609)
diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py
index b7af0f8f422..f6a513ac235 100644
--- a/pennylane/ops/functions/equal.py
+++ b/pennylane/ops/functions/equal.py
@@ -264,7 +264,9 @@ def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
"""Determine whether two Exp objects are equal"""
- if op1.coeff != op2.coeff:
+ rtol, atol = (kwargs["rtol"], kwargs["atol"])
+
+ if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)
@@ -273,7 +275,9 @@ def _equal_exp(op1: Exp, op2: Exp, **kwargs):
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
"""Determine whether two SProd objects are equal"""
- if op1.scalar != op2.scalar:
+ rtol, atol = (kwargs["rtol"], kwargs["atol"])
+
+ if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)
diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py
index bff83756839..70c64d46fc6 100644
--- a/pennylane/templates/subroutines/__init__.py
+++ b/pennylane/templates/subroutines/__init__.py
@@ -35,3 +35,4 @@
from .basis_rotation import BasisRotation
from .qsvt import QSVT, qsvt
from .select import Select
+from .trotter import TrotterProduct
diff --git a/pennylane/templates/subroutines/trotter.py b/pennylane/templates/subroutines/trotter.py
new file mode 100644
index 00000000000..f843994f5d6
--- /dev/null
+++ b/pennylane/templates/subroutines/trotter.py
@@ -0,0 +1,292 @@
+# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Contains templates for Suzuki-Trotter approximation based subroutines.
+"""
+import pennylane as qml
+from pennylane.operation import Operation
+from pennylane.ops import Sum
+
+
+def _scalar(order):
+ """Compute the scalar used in the recursive expression.
+
+ Args:
+ order (int): order of Trotter product (assume order is an even integer > 2).
+
+ Returns:
+ float: scalar to be used in the recursive expression.
+ """
+ root = 1 / (order - 1)
+ return (4 - 4**root) ** -1
+
+
+@qml.QueuingManager.stop_recording()
+def _recursive_expression(x, order, ops):
+ """Generate a list of operations using the
+ recursive expression which defines the Trotter product.
+
+ Args:
+ x (complex): the evolution 'time'
+ order (int): the order of the Trotter expansion
+ ops (Iterable(~.Operators)): a list of terms in the Hamiltonian
+
+ Returns:
+ list: the approximation as product of exponentials of the Hamiltonian terms
+ """
+ if order == 1:
+ return [qml.exp(op, x * 1j) for op in ops]
+
+ if order == 2:
+ return [qml.exp(op, x * 0.5j) for op in ops + ops[::-1]]
+
+ scalar_1 = _scalar(order)
+ scalar_2 = 1 - 4 * scalar_1
+
+ ops_lst_1 = _recursive_expression(scalar_1 * x, order - 2, ops)
+ ops_lst_2 = _recursive_expression(scalar_2 * x, order - 2, ops)
+
+ return (2 * ops_lst_1) + ops_lst_2 + (2 * ops_lst_1)
+
+
+class TrotterProduct(Operation):
+ r"""An operation representing the Suzuki-Trotter product approximation for the complex matrix
+ exponential of a given Hamiltonian.
+
+ The Suzuki-Trotter product formula provides a method to approximate the matrix exponential of
+ Hamiltonian expressed as a linear combination of terms which in general do not commute. Consider
+ the Hamiltonian :math:`H = \Sigma^{N}_{j=0} O_{j}`, the product formula is constructed using
+ symmetrized products of the terms in the Hamiltonian. The symmetrized products of order
+ :math:`m \in [1, 2, 4, ..., 2k]` with :math:`k \in \mathbb{N}` are given by:
+
+ .. math::
+
+ \begin{align}
+ S_{1}(t) &= \Pi_{j=0}^{N} \ e^{i t O_{j}} \\
+ S_{2}(t) &= \Pi_{j=0}^{N} \ e^{i \frac{t}{2} O_{j}} \cdot \Pi_{j=N}^{0} \ e^{i \frac{t}{2} O_{j}} \\
+ &\vdots \\
+ S_{m}(t) &= S_{m-2}(p_{m}t)^{2} \cdot S_{m-2}((1-4p_{m})t) \cdot S_{m-2}(p_{m}t)^{2},
+ \end{align}
+
+ where the coefficient is :math:`p_{m} = 1 / (4 - \sqrt[m - 1]{4})`. The :math:`m`th order,
+ :math:`n`-step Suzuki-Trotter approximation is then defined as:
+
+ .. math:: e^{iHt} \approx \left [S_{m}(t / n) \right ]^{n}.
+
+ For more details see `J. Math. Phys. 32, 400 (1991) `_.
+
+ Args:
+ hamiltonian (Union[.Hamiltonian, .Sum]): The Hamiltonian written as a linear combination
+ of operators with known matrix exponentials.
+ time (float): The time of evolution, namely the parameter :math:`t` in :math:`e^{iHt}`
+ n (int): An integer representing the number of Trotter steps to perform
+ order (int): An integer (:math:`m`) representing the order of the approximation (must be 1 or even)
+ check_hermitian (bool): A flag to enable the validation check to ensure this is a valid unitary operator
+
+ Raises:
+ TypeError: The ``hamiltonian`` is not of type :class:`~.Hamiltonian`, or :class:`~.Sum`.
+ ValueError: The ``hamiltonian`` must have atleast two terms.
+ ValueError: One or more of the terms in ``hamiltonian`` are not Hermitian.
+ ValueError: The ``order`` is not one or a positive even integer.
+
+ **Example**
+
+ .. code-block:: python3
+
+ coeffs = [0.25, 0.75]
+ ops = [qml.PauliX(0), qml.PauliZ(0)]
+ H = qml.dot(coeffs, ops)
+
+ dev = qml.device("default.qubit", wires=2)
+ @qml.qnode(dev)
+ def my_circ():
+ # Prepare some state
+ qml.Hadamard(0)
+
+ # Evolve according to H
+ qml.TrotterProduct(H, time=2.4, order=2)
+
+ # Measure some quantity
+ return qml.state()
+
+ >>> my_circ()
+ [-0.13259524+0.59790098j 0. +0.j -0.13259524-0.77932754j 0. +0.j ]
+
+ .. details::
+ :title: Usage Details
+
+ One can recover the behaviour of :class:`~.ApproxTimeEvolution` by setting :code:`order=1`.
+ We can also compute the gradient with respect to the coefficients of the Hamiltonian and the
+ evolution time:
+
+ .. code-block:: python3
+
+ @qml.qnode(dev)
+ def my_circ(c1, c2, time):
+ # Prepare H:
+ H = qml.dot([c1, c2], [qml.PauliX(0), qml.PauliZ(0)])
+
+ # Prepare some state
+ qml.Hadamard(0)
+
+ # Evolve according to H
+ qml.TrotterProduct(H, time, order=2)
+
+ # Measure some quantity
+ return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
+
+ >>> args = np.array([1.23, 4.5, 0.1])
+ >>> qml.grad(my_circ)(*tuple(args))
+ (tensor(0.00961064, requires_grad=True), tensor(-0.12338274, requires_grad=True), tensor(-5.43401259, requires_grad=True))
+ """
+
+ def __init__( # pylint: disable=too-many-arguments
+ self, hamiltonian, time, n=1, order=1, check_hermitian=True, id=None
+ ):
+ r"""Initialize the TrotterProduct class"""
+
+ if order <= 0 or order != 1 and order % 2 != 0:
+ raise ValueError(
+ f"The order of a TrotterProduct must be 1 or a positive even integer, got {order}."
+ )
+
+ if isinstance(hamiltonian, qml.Hamiltonian):
+ coeffs, ops = hamiltonian.terms()
+ if len(coeffs) < 2:
+ raise ValueError(
+ "There should be atleast 2 terms in the Hamiltonian. Otherwise use `qml.exp`"
+ )
+
+ hamiltonian = qml.dot(coeffs, ops)
+
+ if not isinstance(hamiltonian, Sum):
+ raise TypeError(
+ f"The given operator must be a PennyLane ~.Hamiltonian or ~.Sum got {hamiltonian}"
+ )
+
+ if check_hermitian:
+ for op in hamiltonian.operands:
+ if not op.is_hermitian:
+ raise ValueError(
+ "One or more of the terms in the Hamiltonian may not be Hermitian"
+ )
+
+ self._hyperparameters = {
+ "n": n,
+ "order": order,
+ "base": hamiltonian,
+ "check_hermitian": check_hermitian,
+ }
+ super().__init__(time, wires=hamiltonian.wires, id=id)
+
+ def _flatten(self):
+ """Serialize the operation into trainable and non-trainable components.
+
+ Returns:
+ data, metadata: The trainable and non-trainable components.
+
+ See ``Operator._unflatten``.
+
+ The data component can be recursive and include other operations. For example, the trainable component of ``Adjoint(RX(1, wires=0))``
+ will be the operator ``RX(1, wires=0)``.
+
+ The metadata **must** be hashable. If the hyperparameters contain a non-hashable component, then this
+ method and ``Operator._unflatten`` should be overridden to provide a hashable version of the hyperparameters.
+
+ **Example:**
+
+ >>> op = qml.Rot(1.2, 2.3, 3.4, wires=0)
+ >>> qml.Rot._unflatten(*op._flatten())
+ Rot(1.2, 2.3, 3.4, wires=[0])
+ >>> op = qml.PauliRot(1.2, "XY", wires=(0,1))
+ >>> qml.PauliRot._unflatten(*op._flatten())
+ PauliRot(1.2, XY, wires=[0, 1])
+
+ Operators that have trainable components that differ from their ``Operator.data`` must implement their own
+ ``_flatten`` methods.
+
+ >>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") )
+ >>> op._flatten()
+ ((U2(3.4, 4.5, wires=['a']),),
+ (, (True, True), ))
+ """
+ hamiltonian = self.hyperparameters["base"]
+ time = self.parameters[0]
+
+ hashable_hyperparameters = tuple(
+ (key, value) for key, value in self.hyperparameters.items() if key != "base"
+ )
+ return (hamiltonian, time), hashable_hyperparameters
+
+ @classmethod
+ def _unflatten(cls, data, metadata):
+ """Recreate an operation from its serialized format.
+
+ Args:
+ data: the trainable component of the operation
+ metadata: the non-trainable component of the operation.
+
+ The output of ``Operator._flatten`` and the class type must be sufficient to reconstruct the original
+ operation with ``Operator._unflatten``.
+
+ **Example:**
+
+ >>> op = qml.Rot(1.2, 2.3, 3.4, wires=0)
+ >>> op._flatten()
+ ((1.2, 2.3, 3.4), (, ()))
+ >>> qml.Rot._unflatten(*op._flatten())
+ >>> op = qml.PauliRot(1.2, "XY", wires=(0,1))
+ >>> op._flatten()
+ ((1.2,), (, (('pauli_word', 'XY'),)))
+ >>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") )
+ >>> type(op)._unflatten(*op._flatten())
+ Controlled(U2(3.4, 4.5, wires=['a']), control_wires=['b', 'c'])
+
+ """
+ hyperparameters_dict = dict(metadata)
+ return cls(*data, **hyperparameters_dict)
+
+ @staticmethod
+ def compute_decomposition(*args, **kwargs):
+ r"""Representation of the operator as a product of other operators (static method).
+
+ .. math:: O = O_1 O_2 \dots O_n.
+
+ .. note::
+
+ Operations making up the decomposition should be queued within the
+ ``compute_decomposition`` method.
+
+ .. seealso:: :meth:`~.Operator.decomposition`.
+
+ Args:
+ *params (list): trainable parameters of the operator, as stored in the ``parameters`` attribute
+ wires (Iterable[Any], Wires): wires that the operator acts on
+ **hyperparams (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute
+
+ Returns:
+ list[Operator]: decomposition of the operator
+ """
+ time = args[0]
+ n = kwargs["n"]
+ order = kwargs["order"]
+ ops = kwargs["base"].operands
+
+ decomp = _recursive_expression(time / n, order, ops)[::-1] * n
+
+ if qml.QueuingManager.recording():
+ for op in decomp: # apply operators in reverse order of expression
+ qml.apply(op)
+
+ return decomp
diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py
index c8c5e308002..679bddfa98f 100644
--- a/tests/ops/functions/test_equal.py
+++ b/tests/ops/functions/test_equal.py
@@ -1438,6 +1438,14 @@ def test_exp_comparison(self, bases_bases_match, params_params_match):
op2 = qml.exp(base2, param2)
assert qml.equal(op1, op2) == (bases_match and params_match)
+ def test_exp_comparison_with_tolerance(self):
+ """Test that equal compares the parameters within a provided tolerance."""
+ op1 = qml.exp(qml.PauliX(0), 0.12345)
+ op2 = qml.exp(qml.PauliX(0), 0.12356)
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
+
@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_s_prod_comparison(self, bases_bases_match, params_params_match):
@@ -1448,6 +1456,14 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match):
op2 = qml.s_prod(param2, base2)
assert qml.equal(op1, op2) == (bases_match and params_match)
+ def test_s_prod_comparison_with_tolerance(self):
+ """Test that equal compares the parameters within a provided tolerance."""
+ op1 = qml.s_prod(0.12345, qml.PauliX(0))
+ op2 = qml.s_prod(0.12356, qml.PauliX(0))
+
+ assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
+ assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)
+
class TestProdComparisons:
"""Tests comparisons between Prod operators"""
diff --git a/tests/templates/test_subroutines/test_trotter.py b/tests/templates/test_subroutines/test_trotter.py
new file mode 100644
index 00000000000..0cbc7f3e93a
--- /dev/null
+++ b/tests/templates/test_subroutines/test_trotter.py
@@ -0,0 +1,842 @@
+# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the TrotterProduct template and helper functions.
+"""
+# pylint: disable=private-access, protected-access
+import copy
+from functools import reduce
+
+import pytest
+
+import pennylane as qml
+from pennylane import numpy as qnp
+from pennylane.math import allclose, get_interface
+from pennylane.templates.subroutines.trotter import _recursive_expression, _scalar
+
+test_hamiltonians = (
+ qml.dot([1, 1, 1], [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(1)]),
+ qml.dot(
+ [1.23, -0.45], [qml.s_prod(0.1, qml.PauliX(0)), qml.prod(qml.PauliX(0), qml.PauliZ(1))]
+ ), # op arith
+ qml.dot(
+ [1, -0.5, 0.5], [qml.Identity(wires=[0, 1]), qml.PauliZ(0), qml.PauliZ(0)]
+ ), # H = Identity
+)
+
+p_4 = (4 - 4 ** (1 / 3)) ** -1
+
+test_decompositions = (
+ { # (hamiltonian_index, order): decomposition assuming t = 4.2, computed by hand
+ (0, 1): [
+ qml.exp(qml.PauliX(0), 4.2j),
+ qml.exp(qml.PauliY(0), 4.2j),
+ qml.exp(qml.PauliZ(1), 4.2j),
+ ],
+ (0, 2): [
+ qml.exp(qml.PauliX(0), 4.2j / 2),
+ qml.exp(qml.PauliY(0), 4.2j / 2),
+ qml.exp(qml.PauliZ(1), 4.2j / 2),
+ qml.exp(qml.PauliZ(1), 4.2j / 2),
+ qml.exp(qml.PauliY(0), 4.2j / 2),
+ qml.exp(qml.PauliX(0), 4.2j / 2),
+ ],
+ (0, 4): [
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), # S_2(p * t) ^ 2
+ qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 4.2j / 2), # S_2((1 - 4p) * t)
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), # S_2(p * t) ^ 2
+ ],
+ (1, 1): [
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j),
+ ],
+ (1, 2): [
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j / 2),
+ ],
+ (1, 4): [
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), (1 - 4 * p_4) * 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), (1 - 4 * p_4) * -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), (1 - 4 * p_4) * -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), (1 - 4 * p_4) * 1.23 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2),
+ qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2),
+ ],
+ (2, 1): [
+ qml.exp(qml.Identity(wires=[0, 1]), 4.2j),
+ qml.exp(qml.PauliZ(0), -0.5 * 4.2j),
+ qml.exp(qml.PauliZ(0), 0.5 * 4.2j),
+ ],
+ (2, 2): [
+ qml.exp(qml.Identity(wires=[0, 1]), 4.2j / 2),
+ qml.exp(qml.PauliZ(0), -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), 4.2j / 2),
+ ],
+ (2, 4): [
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), (1 - 4 * p_4) * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2),
+ qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2),
+ qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2),
+ ],
+ }
+)
+
+
+def _generate_simple_decomp(coeffs, ops, time, order, n):
+ """Given coeffs, ops and a time argument in a given framework, generate the
+ Trotter product for order and number of trotter steps."""
+ decomp = []
+ if order == 1:
+ decomp.extend(qml.exp(op, coeff * (time / n) * 1j) for coeff, op in zip(coeffs, ops))
+
+ coeffs_ops = zip(coeffs, ops)
+
+ if get_interface(coeffs) == "torch":
+ import torch
+
+ coeffs_ops_reversed = zip(torch.flip(coeffs, dims=(0,)), ops[::-1])
+ else:
+ coeffs_ops_reversed = zip(coeffs[::-1], ops[::-1])
+
+ if order == 2:
+ decomp.extend(qml.exp(op, coeff * (time / n) * 1j / 2) for coeff, op in coeffs_ops)
+ decomp.extend(qml.exp(op, coeff * (time / n) * 1j / 2) for coeff, op in coeffs_ops_reversed)
+
+ if order == 4:
+ s_2 = []
+ s_2_p = []
+
+ for coeff, op in coeffs_ops:
+ s_2.append(qml.exp(op, (p_4 * coeff) * (time / n) * 1j / 2))
+ s_2_p.append(qml.exp(op, ((1 - (4 * p_4)) * coeff) * (time / n) * 1j / 2))
+
+ for coeff, op in coeffs_ops_reversed:
+ s_2.append(qml.exp(op, (p_4 * coeff) * (time / n) * 1j / 2))
+ s_2_p.append(qml.exp(op, ((1 - (4 * p_4)) * coeff) * (time / n) * 1j / 2))
+
+ decomp = (s_2 * 2) + s_2_p + (s_2 * 2)
+
+ return decomp * n
+
+
+class TestInitialization:
+ """Test the TrotterProduct class initializes correctly."""
+
+ @pytest.mark.parametrize(
+ "hamiltonian, raise_error",
+ (
+ (qml.PauliX(0), True),
+ (qml.prod(qml.PauliX(0), qml.PauliZ(1)), True),
+ (qml.Hamiltonian([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)]), False),
+ (qml.dot([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)]), False),
+ ),
+ )
+ def test_error_type(self, hamiltonian, raise_error):
+ """Test an error is raised of an incorrect type is passed"""
+ if raise_error:
+ with pytest.raises(
+ TypeError, match="The given operator must be a PennyLane ~.Hamiltonian or ~.Sum"
+ ):
+ qml.TrotterProduct(hamiltonian, time=1.23)
+
+ else:
+ try:
+ qml.TrotterProduct(hamiltonian, time=1.23)
+ except TypeError:
+ assert False # test should fail if an error was raised when we expect it not to
+
+ @pytest.mark.parametrize(
+ "hamiltonian",
+ (
+ qml.Hamiltonian([1.23, 4 + 5j], [qml.PauliX(0), qml.PauliZ(1)]),
+ qml.dot([1.23, 4 + 5j], [qml.PauliX(0), qml.PauliZ(1)]),
+ qml.dot([1.23, 0.5], [qml.RY(1.23, 0), qml.RZ(3.45, 1)]),
+ ),
+ )
+ def test_error_hermiticity(self, hamiltonian):
+ """Test that an error is raised if any terms in
+ the Hamiltonian are not Hermitian and check_hermitian is True."""
+
+ with pytest.raises(
+ ValueError, match="One or more of the terms in the Hamiltonian may not be Hermitian"
+ ):
+ qml.TrotterProduct(hamiltonian, time=0.5)
+
+ try:
+ qml.TrotterProduct(hamiltonian, time=0.5, check_hermitian=False)
+ except ValueError:
+ assert False # No error should be raised if the check_hermitian flag is disabled
+
+ def test_error_hamiltonian(self):
+ """Test that an error is raised if the input hamultonian has only 1 term."""
+ with pytest.raises(ValueError, match="There should be atleast 2 terms in the Hamiltonian."):
+ qml.TrotterProduct(qml.Hamiltonian([1.0], [qml.PauliX(0)]), 1.23, n=2, order=4)
+
+ @pytest.mark.parametrize("order", (-1, 0, 0.5, 3, 7.0))
+ def test_error_order(self, order):
+ """Test that an error is raised if 'order' is not one or positive even number."""
+ time = 0.5
+ hamiltonian = qml.dot([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)])
+
+ with pytest.raises(
+ ValueError, match="The order of a TrotterProduct must be 1 or a positive even integer,"
+ ):
+ qml.TrotterProduct(hamiltonian, time, order=order)
+
+ @pytest.mark.parametrize("hamiltonian", test_hamiltonians)
+ def test_init_correctly(self, hamiltonian):
+ """Test that all of the attributes are initalized correctly."""
+ time, n, order = (4.2, 10, 4)
+ op = qml.TrotterProduct(hamiltonian, time, n=n, order=order, check_hermitian=False)
+
+ assert op.wires == hamiltonian.wires
+ assert op.parameters == [time]
+ assert op.data == (time,)
+ assert op.hyperparameters == {
+ "base": hamiltonian,
+ "n": n,
+ "order": order,
+ "check_hermitian": False,
+ }
+
+ @pytest.mark.parametrize("n", (1, 2, 5, 10))
+ @pytest.mark.parametrize("time", (0.5, 1.2))
+ @pytest.mark.parametrize("order", (1, 2, 4))
+ @pytest.mark.parametrize("hamiltonian", test_hamiltonians)
+ def test_copy(self, hamiltonian, time, n, order):
+ """Test that we can make deep and shallow copies of TrotterProduct correctly."""
+ op = qml.TrotterProduct(hamiltonian, time, n=n, order=order)
+ new_op = copy.copy(op)
+
+ assert op.wires == new_op.wires
+ assert op.parameters == new_op.parameters
+ assert op.data == new_op.data
+ assert op.hyperparameters == new_op.hyperparameters
+ assert op is not new_op
+
+ @pytest.mark.parametrize("hamiltonian", test_hamiltonians)
+ def test_flatten_and_unflatten(self, hamiltonian):
+ """Test that the flatten and unflatten methods work correctly."""
+ time, n, order = (4.2, 10, 4)
+ op = qml.TrotterProduct(hamiltonian, time, n=n, order=order)
+
+ data, metadata = op._flatten()
+ assert qml.equal(data[0], hamiltonian)
+ assert data[1] == time
+ assert dict(metadata) == {"n": n, "order": order, "check_hermitian": True}
+
+ new_op = type(op)._unflatten(data, metadata)
+ assert qml.equal(op, new_op)
+ assert new_op is not op
+
+
+class TestPrivateFunctions:
+ """Test the private helper functions."""
+
+ @pytest.mark.parametrize(
+ "order, result",
+ (
+ (4, 0.4144907717943757),
+ (6, 0.3730658277332728),
+ (8, 0.35958464934999224),
+ ),
+ ) # Computed by hand
+ def test_private_scalar(self, order, result):
+ """Test the _scalar function correctly computes the parameter scalar."""
+ s = _scalar(order)
+ assert qnp.isclose(s, result)
+
+ expected_expansions = ( # for H = X0 + Y0 + Z1, t = 1.23, computed by hand
+ [ # S_1(t)
+ qml.exp(qml.PauliX(0), 1.23j),
+ qml.exp(qml.PauliY(0), 1.23j),
+ qml.exp(qml.PauliZ(1), 1.23j),
+ ],
+ [ # S_2(t)
+ qml.exp(qml.PauliX(0), 1.23j / 2),
+ qml.exp(qml.PauliY(0), 1.23j / 2),
+ qml.exp(qml.PauliZ(1), 1.23j / 2),
+ qml.exp(qml.PauliZ(1), 1.23j / 2),
+ qml.exp(qml.PauliY(0), 1.23j / 2),
+ qml.exp(qml.PauliX(0), 1.23j / 2),
+ ],
+ [ # S_4(t)
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), # S_2(p * t) ^ 2
+ qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 1.23j / 2),
+ qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 1.23j / 2),
+ qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 1.23j / 2),
+ qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 1.23j / 2), # S_2((1 - 4p) * t)
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliY(0), p_4 * 1.23j / 2),
+ qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), # S_2(p * t) ^ 2
+ ],
+ )
+
+ @pytest.mark.parametrize("order, expected_expansion", zip((1, 2, 4), expected_expansions))
+ def test_recursive_expression_no_queue(self, order, expected_expansion):
+ """Test the _recursive_expression function correctly generates the decomposition"""
+ ops = [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(1)]
+
+ with qml.tape.QuantumTape() as tape:
+ decomp = _recursive_expression(1.23, order, ops)
+
+ assert tape.operations == [] # No queuing!
+ assert all(
+ qml.equal(op1, op2) for op1, op2 in zip(decomp, expected_expansion)
+ ) # Expected expression
+
+
+class TestDecomposition:
+ """Test the decomposition of the TrotterProduct class."""
+
+ @pytest.mark.parametrize("order", (1, 2, 4))
+ @pytest.mark.parametrize("hamiltonian_index, hamiltonian", enumerate(test_hamiltonians))
+ def test_compute_decomposition(self, hamiltonian, hamiltonian_index, order):
+ """Test the decomposition is correct and queues"""
+ op = qml.TrotterProduct(hamiltonian, 4.2, order=order)
+ with qml.tape.QuantumTape() as tape:
+ decomp = op.compute_decomposition(*op.parameters, **op.hyperparameters)
+
+ assert decomp == tape.operations # queue matches decomp with circuit ordering
+
+ decomp = [qml.simplify(op) for op in decomp]
+ true_decomp = [
+ qml.simplify(op) for op in test_decompositions[(hamiltonian_index, order)][::-1]
+ ]
+ assert all(
+ qml.equal(op1, op2) for op1, op2 in zip(decomp, true_decomp)
+ ) # decomp is correct
+
+ @pytest.mark.parametrize("order", (1, 2))
+ @pytest.mark.parametrize("num_steps", (1, 2, 3))
+ def test_compute_decomposition_n_steps(self, num_steps, order):
+ """Test the decomposition is correct when we set the number of trotter steps"""
+ time = 0.5
+ hamiltonian = qml.sum(qml.PauliX(0), qml.PauliZ(0))
+
+ if order == 1:
+ base_decomp = [
+ qml.exp(qml.PauliZ(0), 0.5j / num_steps),
+ qml.exp(qml.PauliX(0), 0.5j / num_steps),
+ ]
+ if order == 2:
+ base_decomp = [
+ qml.exp(qml.PauliX(0), 0.25j / num_steps),
+ qml.exp(qml.PauliZ(0), 0.25j / num_steps),
+ qml.exp(qml.PauliZ(0), 0.25j / num_steps),
+ qml.exp(qml.PauliX(0), 0.25j / num_steps),
+ ]
+
+ true_decomp = base_decomp * num_steps
+
+ op = qml.TrotterProduct(hamiltonian, time, n=num_steps, order=order)
+ decomp = op.compute_decomposition(*op.parameters, **op.hyperparameters)
+ assert all(qml.equal(op1, op2) for op1, op2 in zip(decomp, true_decomp))
+
+
+class TestIntegration:
+ """Test that the TrotterProduct can be executed and differentiated
+ through all interfaces."""
+
+ # Circuit execution tests:
+ @pytest.mark.parametrize("order", (1, 2, 4))
+ @pytest.mark.parametrize("hamiltonian_index, hamiltonian", enumerate(test_hamiltonians))
+ def test_execute_circuit(self, hamiltonian, hamiltonian_index, order):
+ """Test that the gate executes correctly in a circuit."""
+ wires = hamiltonian.wires
+ dev = qml.device("default.qubit", wires=wires)
+
+ @qml.qnode(dev)
+ def circ():
+ qml.TrotterProduct(hamiltonian, time=4.2, order=order)
+ return qml.state()
+
+ initial_state = qnp.zeros(2 ** (len(wires)))
+ initial_state[0] = 1
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y,
+ [
+ qml.matrix(op, wire_order=wires)
+ for op in test_decompositions[(hamiltonian_index, order)]
+ ],
+ )
+ @ initial_state
+ )
+ state = circ()
+
+ assert qnp.allclose(expected_state, state)
+
+ @pytest.mark.parametrize("order", (1, 2))
+ @pytest.mark.parametrize("num_steps", (1, 2, 3))
+ def test_execute_circuit_n_steps(self, num_steps, order):
+ """Test that the circuit executes as expected when we set the number of trotter steps"""
+ time = 0.5
+ hamiltonian = qml.sum(qml.PauliX(0), qml.PauliZ(0))
+
+ if order == 1:
+ base_decomp = [
+ qml.exp(qml.PauliZ(0), 0.5j / num_steps),
+ qml.exp(qml.PauliX(0), 0.5j / num_steps),
+ ]
+ if order == 2:
+ base_decomp = [
+ qml.exp(qml.PauliX(0), 0.25j / num_steps),
+ qml.exp(qml.PauliZ(0), 0.25j / num_steps),
+ qml.exp(qml.PauliZ(0), 0.25j / num_steps),
+ qml.exp(qml.PauliX(0), 0.25j / num_steps),
+ ]
+
+ true_decomp = base_decomp * num_steps
+
+ wires = hamiltonian.wires
+ dev = qml.device("default.qubit", wires=wires)
+
+ @qml.qnode(dev)
+ def circ():
+ qml.TrotterProduct(hamiltonian, time, n=num_steps, order=order)
+ return qml.state()
+
+ initial_state = qnp.zeros(2 ** (len(wires)))
+ initial_state[0] = 1
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y, [qml.matrix(op, wire_order=wires) for op in true_decomp[::-1]]
+ )
+ @ initial_state
+ )
+ state = circ()
+ assert qnp.allclose(expected_state, state)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("time", (0.5, 1, 2))
+ def test_jax_execute(self, time):
+ """Test that the gate executes correctly in the jax interface."""
+ from jax import numpy as jnp
+
+ time = jnp.array(time)
+ coeffs = jnp.array([1.23, -0.45])
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=2, order=2)
+ return qml.state()
+
+ initial_state = jnp.array([1.0, 0.0, 0.0, 0.0])
+
+ expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2)
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y,
+ [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence],
+ )
+ @ initial_state
+ )
+
+ state = circ(time, coeffs)
+ assert allclose(expected_state, state)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("time", (0.5, 1, 2))
+ def test_jaxjit_execute(self, time):
+ """Test that the gate executes correctly in the jax interface with jit."""
+ import jax
+ from jax import numpy as jnp
+
+ time = jnp.array(time)
+ c1 = jnp.array(1.23)
+ c2 = jnp.array(-0.45)
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @jax.jit
+ @qml.qnode(dev, interface="jax")
+ def circ(time, c1, c2):
+ h = qml.sum(
+ qml.s_prod(c1, terms[0]),
+ qml.s_prod(c2, terms[1]),
+ )
+ qml.TrotterProduct(h, time, n=2, order=2, check_hermitian=False)
+ return qml.state()
+
+ initial_state = jnp.array([1.0, 0.0, 0.0, 0.0])
+
+ expected_product_sequence = _generate_simple_decomp([c1, c2], terms, time, order=2, n=2)
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y,
+ [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence],
+ )
+ @ initial_state
+ )
+
+ state = circ(time, c1, c2)
+ assert allclose(expected_state, state)
+
+ @pytest.mark.tf
+ @pytest.mark.parametrize("time", (0.5, 1, 2))
+ def test_tf_execute(self, time):
+ """Test that the gate executes correctly in the tensorflow interface."""
+ import tensorflow as tf
+
+ time = tf.Variable(time, dtype=tf.complex128)
+ coeffs = tf.Variable([1.23, -0.45], dtype=tf.complex128)
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.sum(
+ qml.s_prod(coeffs[0], terms[0]),
+ qml.s_prod(coeffs[1], terms[1]),
+ )
+ qml.TrotterProduct(h, time, n=2, order=2)
+
+ return qml.state()
+
+ initial_state = tf.Variable([1.0, 0.0, 0.0, 0.0], dtype=tf.complex128)
+
+ expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2)
+
+ expected_state = tf.linalg.matvec(
+ reduce(
+ lambda x, y: x @ y,
+ [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence],
+ ),
+ initial_state,
+ )
+
+ state = circ(time, coeffs)
+ assert allclose(expected_state, state)
+
+ @pytest.mark.torch
+ @pytest.mark.parametrize("time", (0.5, 1, 2))
+ def test_torch_execute(self, time):
+ """Test that the gate executes correctly in the torch interface."""
+ import torch
+
+ time = torch.tensor(time, dtype=torch.complex64, requires_grad=True)
+ coeffs = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True)
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=2, order=2)
+ return qml.state()
+
+ initial_state = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.complex64)
+
+ expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2)
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y,
+ [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence],
+ )
+ @ initial_state
+ )
+
+ state = circ(time, coeffs)
+ assert allclose(expected_state, state)
+
+ @pytest.mark.autograd
+ @pytest.mark.parametrize("time", (0.5, 1, 2))
+ def test_autograd_execute(self, time):
+ """Test that the gate executes correctly in the autograd interface."""
+ time = qnp.array(time)
+ coeffs = qnp.array([1.23, -0.45])
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=2)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=2, order=2)
+ return qml.state()
+
+ initial_state = qnp.array([1.0, 0.0, 0.0, 0.0])
+
+ expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2)
+
+ expected_state = (
+ reduce(
+ lambda x, y: x @ y,
+ [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence],
+ )
+ @ initial_state
+ )
+
+ state = circ(time, coeffs)
+ assert qnp.allclose(expected_state, state)
+
+ @pytest.mark.autograd
+ @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1)))
+ def test_autograd_gradient(self, order, n):
+ """Test that the gradient is computed correctly"""
+ time = qnp.array(1.5)
+ coeffs = qnp.array([1.23, -0.45])
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=n, order=order)
+ return qml.expval(qml.Hadamard(0))
+
+ @qml.qnode(dev)
+ def reference_circ(time, coeffs):
+ with qml.QueuingManager.stop_recording():
+ decomp = _generate_simple_decomp(coeffs, terms, time, order, n)
+
+ for op in decomp[::-1]:
+ qml.apply(op)
+
+ return qml.expval(qml.Hadamard(0))
+
+ measured_time_grad, measured_coeff_grad = qml.grad(circ)(time, coeffs)
+ reference_time_grad, reference_coeff_grad = qml.grad(reference_circ)(time, coeffs)
+ assert allclose(measured_time_grad, reference_time_grad)
+ assert allclose(measured_coeff_grad, reference_coeff_grad)
+
+ @pytest.mark.torch
+ @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1)))
+ def test_torch_gradient(self, order, n):
+ """Test that the gradient is computed correctly using torch"""
+ import torch
+
+ time = torch.tensor(1.5, dtype=torch.complex64, requires_grad=True)
+ coeffs = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True)
+ time_reference = torch.tensor(1.5, dtype=torch.complex64, requires_grad=True)
+ coeffs_reference = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True)
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=n, order=order)
+ return qml.expval(qml.Hadamard(0))
+
+ @qml.qnode(dev)
+ def reference_circ(time, coeffs):
+ with qml.QueuingManager.stop_recording():
+ decomp = _generate_simple_decomp(coeffs, terms, time, order, n)
+
+ for op in decomp[::-1]:
+ qml.apply(op)
+
+ return qml.expval(qml.Hadamard(0))
+
+ res_circ = circ(time, coeffs)
+ res_circ.backward()
+ measured_time_grad = time.grad
+ measured_coeff_grad = coeffs.grad
+
+ ref_circ = reference_circ(time_reference, coeffs_reference)
+ ref_circ.backward()
+ reference_time_grad = time_reference.grad
+ reference_coeff_grad = coeffs_reference.grad
+
+ assert allclose(measured_time_grad, reference_time_grad)
+ assert allclose(measured_coeff_grad, reference_coeff_grad)
+
+ @pytest.mark.tf
+ @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1)))
+ def test_tf_gradient(self, order, n):
+ """Test that the gradient is computed correctly using tensorflow"""
+ import tensorflow as tf
+
+ time = tf.Variable(1.5, dtype=tf.complex128)
+ coeffs = tf.Variable([1.23, -0.45], dtype=tf.complex128)
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.sum(
+ qml.s_prod(coeffs[0], terms[0]),
+ qml.s_prod(coeffs[1], terms[1]),
+ )
+ qml.TrotterProduct(h, time, n=n, order=order)
+ return qml.expval(qml.Hadamard(0))
+
+ @qml.qnode(dev)
+ def reference_circ(time, coeffs):
+ with qml.QueuingManager.stop_recording():
+ decomp = _generate_simple_decomp(coeffs, terms, time, order, n)
+
+ for op in decomp[::-1]:
+ qml.apply(op)
+
+ return qml.expval(qml.Hadamard(0))
+
+ with tf.GradientTape() as tape:
+ result = circ(time, coeffs)
+
+ measured_time_grad, measured_coeff_grad = tape.gradient(result, (time, coeffs))
+
+ with tf.GradientTape() as tape:
+ result = reference_circ(time, coeffs)
+
+ reference_time_grad, reference_coeff_grad = tape.gradient(result, (time, coeffs))
+ assert allclose(measured_time_grad, reference_time_grad)
+ assert allclose(measured_coeff_grad, reference_coeff_grad)
+
+ @pytest.mark.jax
+ @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1)))
+ def test_jax_gradient(self, order, n):
+ """Test that the gradient is computed correctly"""
+ import jax
+ from jax import numpy as jnp
+
+ time = jnp.array(1.5)
+ coeffs = jnp.array([1.23, -0.45])
+ terms = [qml.PauliX(0), qml.PauliZ(0)]
+
+ dev = qml.device("default.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def circ(time, coeffs):
+ h = qml.dot(coeffs, terms)
+ qml.TrotterProduct(h, time, n=n, order=order)
+ return qml.expval(qml.Hadamard(0))
+
+ @qml.qnode(dev)
+ def reference_circ(time, coeffs):
+ with qml.QueuingManager.stop_recording():
+ decomp = _generate_simple_decomp(coeffs, terms, time, order, n)
+
+ for op in decomp[::-1]:
+ qml.apply(op)
+
+ return qml.expval(qml.Hadamard(0))
+
+ measured_time_grad, measured_coeff_grad = jax.grad(circ, argnums=[0, 1])(time, coeffs)
+ reference_time_grad, reference_coeff_grad = jax.grad(reference_circ, argnums=[0, 1])(
+ time, coeffs
+ )
+ assert allclose(measured_time_grad, reference_time_grad)
+ assert allclose(measured_coeff_grad, reference_coeff_grad)