diff --git a/doc/_static/templates/subroutines/trotter_product.png b/doc/_static/templates/subroutines/trotter_product.png new file mode 100644 index 00000000000..b5fb79ff6ad Binary files /dev/null and b/doc/_static/templates/subroutines/trotter_product.png differ diff --git a/doc/introduction/templates.rst b/doc/introduction/templates.rst index dbde3cfcddf..26ee024653c 100644 --- a/doc/introduction/templates.rst +++ b/doc/introduction/templates.rst @@ -227,6 +227,10 @@ Other useful templates which do not belong to the previous categories can be fou :description: :doc:`ApproxTimeEvolution <../code/api/pennylane.ApproxTimeEvolution>` :figure: _static/templates/subroutines/approx_time_evolution.png +.. gallery-item:: + :description: :doc:`TrotterProduct <../code/api/pennylane.TrotterProduct>` + :figure: _static/templates/subroutines/trotter_product.png + .. gallery-item:: :description: :doc:`Permute <../code/api/pennylane.Permute>` :figure: _static/templates/subroutines/permute.png diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3f4e6a3ae58..7fb363fdcd7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,41 @@

New features since last release

+

Exponentiate Hamiltonians with flexible Trotter products 🐖

+ +* Higher-order Trotter-Suzuki methods are now easily accessible through a new operation + called `TrotterProduct`. + [(#4661)](https://github.com/PennyLaneAI/pennylane/pull/4661) + + Trotterization techniques are an affective route towards accurate and efficient + Hamiltonian simulation. The Suzuki-Trotter product formula allows for the ability + to express higher-order approximations to the matrix exponential of a Hamiltonian, + and it is now available to use in PennyLane via the `TrotterProduct` operation. + Simply specify the `order` of the approximation and the evolution `time`. + + ```python + coeffs = [0.25, 0.75] + ops = [qml.PauliX(0), qml.PauliZ(0)] + H = qml.dot(coeffs, ops) + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circuit(): + qml.Hadamard(0) + qml.TrotterProduct(H, time=2.4, order=2) + return qml.state() + ``` + + ```pycon + >>> circuit() + [-0.13259524+0.59790098j 0. +0.j -0.13259524-0.77932754j 0. +0.j ] + ``` + + The already-available `ApproxTimeEvolution` operation represents the special case of `order=1`. + It is recommended to switch over to use of `TrotterProduct` because `ApproxTimeEvolution` will be + deprecated and removed in upcoming releases. + * Support drawing QJIT QNode from Catalyst. [(#4609)](https://github.com/PennyLaneAI/pennylane/pull/4609) diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py index b7af0f8f422..f6a513ac235 100644 --- a/pennylane/ops/functions/equal.py +++ b/pennylane/ops/functions/equal.py @@ -264,7 +264,9 @@ def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs): # pylint: disable=unused-argument def _equal_exp(op1: Exp, op2: Exp, **kwargs): """Determine whether two Exp objects are equal""" - if op1.coeff != op2.coeff: + rtol, atol = (kwargs["rtol"], kwargs["atol"]) + + if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol): return False return qml.equal(op1.base, op2.base) @@ -273,7 +275,9 @@ def _equal_exp(op1: Exp, op2: Exp, **kwargs): # pylint: disable=unused-argument def _equal_sprod(op1: SProd, op2: SProd, **kwargs): """Determine whether two SProd objects are equal""" - if op1.scalar != op2.scalar: + rtol, atol = (kwargs["rtol"], kwargs["atol"]) + + if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol): return False return qml.equal(op1.base, op2.base) diff --git a/pennylane/templates/subroutines/__init__.py b/pennylane/templates/subroutines/__init__.py index bff83756839..70c64d46fc6 100644 --- a/pennylane/templates/subroutines/__init__.py +++ b/pennylane/templates/subroutines/__init__.py @@ -35,3 +35,4 @@ from .basis_rotation import BasisRotation from .qsvt import QSVT, qsvt from .select import Select +from .trotter import TrotterProduct diff --git a/pennylane/templates/subroutines/trotter.py b/pennylane/templates/subroutines/trotter.py new file mode 100644 index 00000000000..f843994f5d6 --- /dev/null +++ b/pennylane/templates/subroutines/trotter.py @@ -0,0 +1,292 @@ +# Copyright 2018-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Contains templates for Suzuki-Trotter approximation based subroutines. +""" +import pennylane as qml +from pennylane.operation import Operation +from pennylane.ops import Sum + + +def _scalar(order): + """Compute the scalar used in the recursive expression. + + Args: + order (int): order of Trotter product (assume order is an even integer > 2). + + Returns: + float: scalar to be used in the recursive expression. + """ + root = 1 / (order - 1) + return (4 - 4**root) ** -1 + + +@qml.QueuingManager.stop_recording() +def _recursive_expression(x, order, ops): + """Generate a list of operations using the + recursive expression which defines the Trotter product. + + Args: + x (complex): the evolution 'time' + order (int): the order of the Trotter expansion + ops (Iterable(~.Operators)): a list of terms in the Hamiltonian + + Returns: + list: the approximation as product of exponentials of the Hamiltonian terms + """ + if order == 1: + return [qml.exp(op, x * 1j) for op in ops] + + if order == 2: + return [qml.exp(op, x * 0.5j) for op in ops + ops[::-1]] + + scalar_1 = _scalar(order) + scalar_2 = 1 - 4 * scalar_1 + + ops_lst_1 = _recursive_expression(scalar_1 * x, order - 2, ops) + ops_lst_2 = _recursive_expression(scalar_2 * x, order - 2, ops) + + return (2 * ops_lst_1) + ops_lst_2 + (2 * ops_lst_1) + + +class TrotterProduct(Operation): + r"""An operation representing the Suzuki-Trotter product approximation for the complex matrix + exponential of a given Hamiltonian. + + The Suzuki-Trotter product formula provides a method to approximate the matrix exponential of + Hamiltonian expressed as a linear combination of terms which in general do not commute. Consider + the Hamiltonian :math:`H = \Sigma^{N}_{j=0} O_{j}`, the product formula is constructed using + symmetrized products of the terms in the Hamiltonian. The symmetrized products of order + :math:`m \in [1, 2, 4, ..., 2k]` with :math:`k \in \mathbb{N}` are given by: + + .. math:: + + \begin{align} + S_{1}(t) &= \Pi_{j=0}^{N} \ e^{i t O_{j}} \\ + S_{2}(t) &= \Pi_{j=0}^{N} \ e^{i \frac{t}{2} O_{j}} \cdot \Pi_{j=N}^{0} \ e^{i \frac{t}{2} O_{j}} \\ + &\vdots \\ + S_{m}(t) &= S_{m-2}(p_{m}t)^{2} \cdot S_{m-2}((1-4p_{m})t) \cdot S_{m-2}(p_{m}t)^{2}, + \end{align} + + where the coefficient is :math:`p_{m} = 1 / (4 - \sqrt[m - 1]{4})`. The :math:`m`th order, + :math:`n`-step Suzuki-Trotter approximation is then defined as: + + .. math:: e^{iHt} \approx \left [S_{m}(t / n) \right ]^{n}. + + For more details see `J. Math. Phys. 32, 400 (1991) `_. + + Args: + hamiltonian (Union[.Hamiltonian, .Sum]): The Hamiltonian written as a linear combination + of operators with known matrix exponentials. + time (float): The time of evolution, namely the parameter :math:`t` in :math:`e^{iHt}` + n (int): An integer representing the number of Trotter steps to perform + order (int): An integer (:math:`m`) representing the order of the approximation (must be 1 or even) + check_hermitian (bool): A flag to enable the validation check to ensure this is a valid unitary operator + + Raises: + TypeError: The ``hamiltonian`` is not of type :class:`~.Hamiltonian`, or :class:`~.Sum`. + ValueError: The ``hamiltonian`` must have atleast two terms. + ValueError: One or more of the terms in ``hamiltonian`` are not Hermitian. + ValueError: The ``order`` is not one or a positive even integer. + + **Example** + + .. code-block:: python3 + + coeffs = [0.25, 0.75] + ops = [qml.PauliX(0), qml.PauliZ(0)] + H = qml.dot(coeffs, ops) + + dev = qml.device("default.qubit", wires=2) + @qml.qnode(dev) + def my_circ(): + # Prepare some state + qml.Hadamard(0) + + # Evolve according to H + qml.TrotterProduct(H, time=2.4, order=2) + + # Measure some quantity + return qml.state() + + >>> my_circ() + [-0.13259524+0.59790098j 0. +0.j -0.13259524-0.77932754j 0. +0.j ] + + .. details:: + :title: Usage Details + + One can recover the behaviour of :class:`~.ApproxTimeEvolution` by setting :code:`order=1`. + We can also compute the gradient with respect to the coefficients of the Hamiltonian and the + evolution time: + + .. code-block:: python3 + + @qml.qnode(dev) + def my_circ(c1, c2, time): + # Prepare H: + H = qml.dot([c1, c2], [qml.PauliX(0), qml.PauliZ(0)]) + + # Prepare some state + qml.Hadamard(0) + + # Evolve according to H + qml.TrotterProduct(H, time, order=2) + + # Measure some quantity + return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1)) + + >>> args = np.array([1.23, 4.5, 0.1]) + >>> qml.grad(my_circ)(*tuple(args)) + (tensor(0.00961064, requires_grad=True), tensor(-0.12338274, requires_grad=True), tensor(-5.43401259, requires_grad=True)) + """ + + def __init__( # pylint: disable=too-many-arguments + self, hamiltonian, time, n=1, order=1, check_hermitian=True, id=None + ): + r"""Initialize the TrotterProduct class""" + + if order <= 0 or order != 1 and order % 2 != 0: + raise ValueError( + f"The order of a TrotterProduct must be 1 or a positive even integer, got {order}." + ) + + if isinstance(hamiltonian, qml.Hamiltonian): + coeffs, ops = hamiltonian.terms() + if len(coeffs) < 2: + raise ValueError( + "There should be atleast 2 terms in the Hamiltonian. Otherwise use `qml.exp`" + ) + + hamiltonian = qml.dot(coeffs, ops) + + if not isinstance(hamiltonian, Sum): + raise TypeError( + f"The given operator must be a PennyLane ~.Hamiltonian or ~.Sum got {hamiltonian}" + ) + + if check_hermitian: + for op in hamiltonian.operands: + if not op.is_hermitian: + raise ValueError( + "One or more of the terms in the Hamiltonian may not be Hermitian" + ) + + self._hyperparameters = { + "n": n, + "order": order, + "base": hamiltonian, + "check_hermitian": check_hermitian, + } + super().__init__(time, wires=hamiltonian.wires, id=id) + + def _flatten(self): + """Serialize the operation into trainable and non-trainable components. + + Returns: + data, metadata: The trainable and non-trainable components. + + See ``Operator._unflatten``. + + The data component can be recursive and include other operations. For example, the trainable component of ``Adjoint(RX(1, wires=0))`` + will be the operator ``RX(1, wires=0)``. + + The metadata **must** be hashable. If the hyperparameters contain a non-hashable component, then this + method and ``Operator._unflatten`` should be overridden to provide a hashable version of the hyperparameters. + + **Example:** + + >>> op = qml.Rot(1.2, 2.3, 3.4, wires=0) + >>> qml.Rot._unflatten(*op._flatten()) + Rot(1.2, 2.3, 3.4, wires=[0]) + >>> op = qml.PauliRot(1.2, "XY", wires=(0,1)) + >>> qml.PauliRot._unflatten(*op._flatten()) + PauliRot(1.2, XY, wires=[0, 1]) + + Operators that have trainable components that differ from their ``Operator.data`` must implement their own + ``_flatten`` methods. + + >>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") ) + >>> op._flatten() + ((U2(3.4, 4.5, wires=['a']),), + (, (True, True), )) + """ + hamiltonian = self.hyperparameters["base"] + time = self.parameters[0] + + hashable_hyperparameters = tuple( + (key, value) for key, value in self.hyperparameters.items() if key != "base" + ) + return (hamiltonian, time), hashable_hyperparameters + + @classmethod + def _unflatten(cls, data, metadata): + """Recreate an operation from its serialized format. + + Args: + data: the trainable component of the operation + metadata: the non-trainable component of the operation. + + The output of ``Operator._flatten`` and the class type must be sufficient to reconstruct the original + operation with ``Operator._unflatten``. + + **Example:** + + >>> op = qml.Rot(1.2, 2.3, 3.4, wires=0) + >>> op._flatten() + ((1.2, 2.3, 3.4), (, ())) + >>> qml.Rot._unflatten(*op._flatten()) + >>> op = qml.PauliRot(1.2, "XY", wires=(0,1)) + >>> op._flatten() + ((1.2,), (, (('pauli_word', 'XY'),))) + >>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") ) + >>> type(op)._unflatten(*op._flatten()) + Controlled(U2(3.4, 4.5, wires=['a']), control_wires=['b', 'c']) + + """ + hyperparameters_dict = dict(metadata) + return cls(*data, **hyperparameters_dict) + + @staticmethod + def compute_decomposition(*args, **kwargs): + r"""Representation of the operator as a product of other operators (static method). + + .. math:: O = O_1 O_2 \dots O_n. + + .. note:: + + Operations making up the decomposition should be queued within the + ``compute_decomposition`` method. + + .. seealso:: :meth:`~.Operator.decomposition`. + + Args: + *params (list): trainable parameters of the operator, as stored in the ``parameters`` attribute + wires (Iterable[Any], Wires): wires that the operator acts on + **hyperparams (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute + + Returns: + list[Operator]: decomposition of the operator + """ + time = args[0] + n = kwargs["n"] + order = kwargs["order"] + ops = kwargs["base"].operands + + decomp = _recursive_expression(time / n, order, ops)[::-1] * n + + if qml.QueuingManager.recording(): + for op in decomp: # apply operators in reverse order of expression + qml.apply(op) + + return decomp diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py index c8c5e308002..679bddfa98f 100644 --- a/tests/ops/functions/test_equal.py +++ b/tests/ops/functions/test_equal.py @@ -1438,6 +1438,14 @@ def test_exp_comparison(self, bases_bases_match, params_params_match): op2 = qml.exp(base2, param2) assert qml.equal(op1, op2) == (bases_match and params_match) + def test_exp_comparison_with_tolerance(self): + """Test that equal compares the parameters within a provided tolerance.""" + op1 = qml.exp(qml.PauliX(0), 0.12345) + op2 = qml.exp(qml.PauliX(0), 0.12356) + + assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2) + assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4) + @pytest.mark.parametrize("bases_bases_match", BASES) @pytest.mark.parametrize("params_params_match", PARAMS) def test_s_prod_comparison(self, bases_bases_match, params_params_match): @@ -1448,6 +1456,14 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match): op2 = qml.s_prod(param2, base2) assert qml.equal(op1, op2) == (bases_match and params_match) + def test_s_prod_comparison_with_tolerance(self): + """Test that equal compares the parameters within a provided tolerance.""" + op1 = qml.s_prod(0.12345, qml.PauliX(0)) + op2 = qml.s_prod(0.12356, qml.PauliX(0)) + + assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2) + assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4) + class TestProdComparisons: """Tests comparisons between Prod operators""" diff --git a/tests/templates/test_subroutines/test_trotter.py b/tests/templates/test_subroutines/test_trotter.py new file mode 100644 index 00000000000..0cbc7f3e93a --- /dev/null +++ b/tests/templates/test_subroutines/test_trotter.py @@ -0,0 +1,842 @@ +# Copyright 2018-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the TrotterProduct template and helper functions. +""" +# pylint: disable=private-access, protected-access +import copy +from functools import reduce + +import pytest + +import pennylane as qml +from pennylane import numpy as qnp +from pennylane.math import allclose, get_interface +from pennylane.templates.subroutines.trotter import _recursive_expression, _scalar + +test_hamiltonians = ( + qml.dot([1, 1, 1], [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(1)]), + qml.dot( + [1.23, -0.45], [qml.s_prod(0.1, qml.PauliX(0)), qml.prod(qml.PauliX(0), qml.PauliZ(1))] + ), # op arith + qml.dot( + [1, -0.5, 0.5], [qml.Identity(wires=[0, 1]), qml.PauliZ(0), qml.PauliZ(0)] + ), # H = Identity +) + +p_4 = (4 - 4 ** (1 / 3)) ** -1 + +test_decompositions = ( + { # (hamiltonian_index, order): decomposition assuming t = 4.2, computed by hand + (0, 1): [ + qml.exp(qml.PauliX(0), 4.2j), + qml.exp(qml.PauliY(0), 4.2j), + qml.exp(qml.PauliZ(1), 4.2j), + ], + (0, 2): [ + qml.exp(qml.PauliX(0), 4.2j / 2), + qml.exp(qml.PauliY(0), 4.2j / 2), + qml.exp(qml.PauliZ(1), 4.2j / 2), + qml.exp(qml.PauliZ(1), 4.2j / 2), + qml.exp(qml.PauliY(0), 4.2j / 2), + qml.exp(qml.PauliX(0), 4.2j / 2), + ], + (0, 4): [ + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), # S_2(p * t) ^ 2 + qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 4.2j / 2), # S_2((1 - 4p) * t) + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(1), p_4 * 4.2j / 2), + qml.exp(qml.PauliY(0), p_4 * 4.2j / 2), + qml.exp(qml.PauliX(0), p_4 * 4.2j / 2), # S_2(p * t) ^ 2 + ], + (1, 1): [ + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j), + ], + (1, 2): [ + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), 1.23 * 4.2j / 2), + ], + (1, 4): [ + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), (1 - 4 * p_4) * 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), (1 - 4 * p_4) * -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), (1 - 4 * p_4) * -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), (1 - 4 * p_4) * 1.23 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.prod(qml.PauliX(0), qml.PauliZ(1)), p_4 * -0.45 * 4.2j / 2), + qml.exp(qml.s_prod(0.1, qml.PauliX(0)), p_4 * 1.23 * 4.2j / 2), + ], + (2, 1): [ + qml.exp(qml.Identity(wires=[0, 1]), 4.2j), + qml.exp(qml.PauliZ(0), -0.5 * 4.2j), + qml.exp(qml.PauliZ(0), 0.5 * 4.2j), + ], + (2, 2): [ + qml.exp(qml.Identity(wires=[0, 1]), 4.2j / 2), + qml.exp(qml.PauliZ(0), -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), 4.2j / 2), + ], + (2, 4): [ + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), (1 - 4 * p_4) * -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), (1 - 4 * p_4) * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * 0.5 * 4.2j / 2), + qml.exp(qml.PauliZ(0), p_4 * -0.5 * 4.2j / 2), + qml.exp(qml.Identity(wires=[0, 1]), p_4 * 4.2j / 2), + ], + } +) + + +def _generate_simple_decomp(coeffs, ops, time, order, n): + """Given coeffs, ops and a time argument in a given framework, generate the + Trotter product for order and number of trotter steps.""" + decomp = [] + if order == 1: + decomp.extend(qml.exp(op, coeff * (time / n) * 1j) for coeff, op in zip(coeffs, ops)) + + coeffs_ops = zip(coeffs, ops) + + if get_interface(coeffs) == "torch": + import torch + + coeffs_ops_reversed = zip(torch.flip(coeffs, dims=(0,)), ops[::-1]) + else: + coeffs_ops_reversed = zip(coeffs[::-1], ops[::-1]) + + if order == 2: + decomp.extend(qml.exp(op, coeff * (time / n) * 1j / 2) for coeff, op in coeffs_ops) + decomp.extend(qml.exp(op, coeff * (time / n) * 1j / 2) for coeff, op in coeffs_ops_reversed) + + if order == 4: + s_2 = [] + s_2_p = [] + + for coeff, op in coeffs_ops: + s_2.append(qml.exp(op, (p_4 * coeff) * (time / n) * 1j / 2)) + s_2_p.append(qml.exp(op, ((1 - (4 * p_4)) * coeff) * (time / n) * 1j / 2)) + + for coeff, op in coeffs_ops_reversed: + s_2.append(qml.exp(op, (p_4 * coeff) * (time / n) * 1j / 2)) + s_2_p.append(qml.exp(op, ((1 - (4 * p_4)) * coeff) * (time / n) * 1j / 2)) + + decomp = (s_2 * 2) + s_2_p + (s_2 * 2) + + return decomp * n + + +class TestInitialization: + """Test the TrotterProduct class initializes correctly.""" + + @pytest.mark.parametrize( + "hamiltonian, raise_error", + ( + (qml.PauliX(0), True), + (qml.prod(qml.PauliX(0), qml.PauliZ(1)), True), + (qml.Hamiltonian([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)]), False), + (qml.dot([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)]), False), + ), + ) + def test_error_type(self, hamiltonian, raise_error): + """Test an error is raised of an incorrect type is passed""" + if raise_error: + with pytest.raises( + TypeError, match="The given operator must be a PennyLane ~.Hamiltonian or ~.Sum" + ): + qml.TrotterProduct(hamiltonian, time=1.23) + + else: + try: + qml.TrotterProduct(hamiltonian, time=1.23) + except TypeError: + assert False # test should fail if an error was raised when we expect it not to + + @pytest.mark.parametrize( + "hamiltonian", + ( + qml.Hamiltonian([1.23, 4 + 5j], [qml.PauliX(0), qml.PauliZ(1)]), + qml.dot([1.23, 4 + 5j], [qml.PauliX(0), qml.PauliZ(1)]), + qml.dot([1.23, 0.5], [qml.RY(1.23, 0), qml.RZ(3.45, 1)]), + ), + ) + def test_error_hermiticity(self, hamiltonian): + """Test that an error is raised if any terms in + the Hamiltonian are not Hermitian and check_hermitian is True.""" + + with pytest.raises( + ValueError, match="One or more of the terms in the Hamiltonian may not be Hermitian" + ): + qml.TrotterProduct(hamiltonian, time=0.5) + + try: + qml.TrotterProduct(hamiltonian, time=0.5, check_hermitian=False) + except ValueError: + assert False # No error should be raised if the check_hermitian flag is disabled + + def test_error_hamiltonian(self): + """Test that an error is raised if the input hamultonian has only 1 term.""" + with pytest.raises(ValueError, match="There should be atleast 2 terms in the Hamiltonian."): + qml.TrotterProduct(qml.Hamiltonian([1.0], [qml.PauliX(0)]), 1.23, n=2, order=4) + + @pytest.mark.parametrize("order", (-1, 0, 0.5, 3, 7.0)) + def test_error_order(self, order): + """Test that an error is raised if 'order' is not one or positive even number.""" + time = 0.5 + hamiltonian = qml.dot([1.23, 3.45], [qml.PauliX(0), qml.PauliZ(1)]) + + with pytest.raises( + ValueError, match="The order of a TrotterProduct must be 1 or a positive even integer," + ): + qml.TrotterProduct(hamiltonian, time, order=order) + + @pytest.mark.parametrize("hamiltonian", test_hamiltonians) + def test_init_correctly(self, hamiltonian): + """Test that all of the attributes are initalized correctly.""" + time, n, order = (4.2, 10, 4) + op = qml.TrotterProduct(hamiltonian, time, n=n, order=order, check_hermitian=False) + + assert op.wires == hamiltonian.wires + assert op.parameters == [time] + assert op.data == (time,) + assert op.hyperparameters == { + "base": hamiltonian, + "n": n, + "order": order, + "check_hermitian": False, + } + + @pytest.mark.parametrize("n", (1, 2, 5, 10)) + @pytest.mark.parametrize("time", (0.5, 1.2)) + @pytest.mark.parametrize("order", (1, 2, 4)) + @pytest.mark.parametrize("hamiltonian", test_hamiltonians) + def test_copy(self, hamiltonian, time, n, order): + """Test that we can make deep and shallow copies of TrotterProduct correctly.""" + op = qml.TrotterProduct(hamiltonian, time, n=n, order=order) + new_op = copy.copy(op) + + assert op.wires == new_op.wires + assert op.parameters == new_op.parameters + assert op.data == new_op.data + assert op.hyperparameters == new_op.hyperparameters + assert op is not new_op + + @pytest.mark.parametrize("hamiltonian", test_hamiltonians) + def test_flatten_and_unflatten(self, hamiltonian): + """Test that the flatten and unflatten methods work correctly.""" + time, n, order = (4.2, 10, 4) + op = qml.TrotterProduct(hamiltonian, time, n=n, order=order) + + data, metadata = op._flatten() + assert qml.equal(data[0], hamiltonian) + assert data[1] == time + assert dict(metadata) == {"n": n, "order": order, "check_hermitian": True} + + new_op = type(op)._unflatten(data, metadata) + assert qml.equal(op, new_op) + assert new_op is not op + + +class TestPrivateFunctions: + """Test the private helper functions.""" + + @pytest.mark.parametrize( + "order, result", + ( + (4, 0.4144907717943757), + (6, 0.3730658277332728), + (8, 0.35958464934999224), + ), + ) # Computed by hand + def test_private_scalar(self, order, result): + """Test the _scalar function correctly computes the parameter scalar.""" + s = _scalar(order) + assert qnp.isclose(s, result) + + expected_expansions = ( # for H = X0 + Y0 + Z1, t = 1.23, computed by hand + [ # S_1(t) + qml.exp(qml.PauliX(0), 1.23j), + qml.exp(qml.PauliY(0), 1.23j), + qml.exp(qml.PauliZ(1), 1.23j), + ], + [ # S_2(t) + qml.exp(qml.PauliX(0), 1.23j / 2), + qml.exp(qml.PauliY(0), 1.23j / 2), + qml.exp(qml.PauliZ(1), 1.23j / 2), + qml.exp(qml.PauliZ(1), 1.23j / 2), + qml.exp(qml.PauliY(0), 1.23j / 2), + qml.exp(qml.PauliX(0), 1.23j / 2), + ], + [ # S_4(t) + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), # S_2(p * t) ^ 2 + qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 1.23j / 2), + qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 1.23j / 2), + qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 1.23j / 2), + qml.exp(qml.PauliZ(1), (1 - 4 * p_4) * 1.23j / 2), + qml.exp(qml.PauliY(0), (1 - 4 * p_4) * 1.23j / 2), + qml.exp(qml.PauliX(0), (1 - 4 * p_4) * 1.23j / 2), # S_2((1 - 4p) * t) + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliZ(1), p_4 * 1.23j / 2), + qml.exp(qml.PauliY(0), p_4 * 1.23j / 2), + qml.exp(qml.PauliX(0), p_4 * 1.23j / 2), # S_2(p * t) ^ 2 + ], + ) + + @pytest.mark.parametrize("order, expected_expansion", zip((1, 2, 4), expected_expansions)) + def test_recursive_expression_no_queue(self, order, expected_expansion): + """Test the _recursive_expression function correctly generates the decomposition""" + ops = [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(1)] + + with qml.tape.QuantumTape() as tape: + decomp = _recursive_expression(1.23, order, ops) + + assert tape.operations == [] # No queuing! + assert all( + qml.equal(op1, op2) for op1, op2 in zip(decomp, expected_expansion) + ) # Expected expression + + +class TestDecomposition: + """Test the decomposition of the TrotterProduct class.""" + + @pytest.mark.parametrize("order", (1, 2, 4)) + @pytest.mark.parametrize("hamiltonian_index, hamiltonian", enumerate(test_hamiltonians)) + def test_compute_decomposition(self, hamiltonian, hamiltonian_index, order): + """Test the decomposition is correct and queues""" + op = qml.TrotterProduct(hamiltonian, 4.2, order=order) + with qml.tape.QuantumTape() as tape: + decomp = op.compute_decomposition(*op.parameters, **op.hyperparameters) + + assert decomp == tape.operations # queue matches decomp with circuit ordering + + decomp = [qml.simplify(op) for op in decomp] + true_decomp = [ + qml.simplify(op) for op in test_decompositions[(hamiltonian_index, order)][::-1] + ] + assert all( + qml.equal(op1, op2) for op1, op2 in zip(decomp, true_decomp) + ) # decomp is correct + + @pytest.mark.parametrize("order", (1, 2)) + @pytest.mark.parametrize("num_steps", (1, 2, 3)) + def test_compute_decomposition_n_steps(self, num_steps, order): + """Test the decomposition is correct when we set the number of trotter steps""" + time = 0.5 + hamiltonian = qml.sum(qml.PauliX(0), qml.PauliZ(0)) + + if order == 1: + base_decomp = [ + qml.exp(qml.PauliZ(0), 0.5j / num_steps), + qml.exp(qml.PauliX(0), 0.5j / num_steps), + ] + if order == 2: + base_decomp = [ + qml.exp(qml.PauliX(0), 0.25j / num_steps), + qml.exp(qml.PauliZ(0), 0.25j / num_steps), + qml.exp(qml.PauliZ(0), 0.25j / num_steps), + qml.exp(qml.PauliX(0), 0.25j / num_steps), + ] + + true_decomp = base_decomp * num_steps + + op = qml.TrotterProduct(hamiltonian, time, n=num_steps, order=order) + decomp = op.compute_decomposition(*op.parameters, **op.hyperparameters) + assert all(qml.equal(op1, op2) for op1, op2 in zip(decomp, true_decomp)) + + +class TestIntegration: + """Test that the TrotterProduct can be executed and differentiated + through all interfaces.""" + + # Circuit execution tests: + @pytest.mark.parametrize("order", (1, 2, 4)) + @pytest.mark.parametrize("hamiltonian_index, hamiltonian", enumerate(test_hamiltonians)) + def test_execute_circuit(self, hamiltonian, hamiltonian_index, order): + """Test that the gate executes correctly in a circuit.""" + wires = hamiltonian.wires + dev = qml.device("default.qubit", wires=wires) + + @qml.qnode(dev) + def circ(): + qml.TrotterProduct(hamiltonian, time=4.2, order=order) + return qml.state() + + initial_state = qnp.zeros(2 ** (len(wires))) + initial_state[0] = 1 + + expected_state = ( + reduce( + lambda x, y: x @ y, + [ + qml.matrix(op, wire_order=wires) + for op in test_decompositions[(hamiltonian_index, order)] + ], + ) + @ initial_state + ) + state = circ() + + assert qnp.allclose(expected_state, state) + + @pytest.mark.parametrize("order", (1, 2)) + @pytest.mark.parametrize("num_steps", (1, 2, 3)) + def test_execute_circuit_n_steps(self, num_steps, order): + """Test that the circuit executes as expected when we set the number of trotter steps""" + time = 0.5 + hamiltonian = qml.sum(qml.PauliX(0), qml.PauliZ(0)) + + if order == 1: + base_decomp = [ + qml.exp(qml.PauliZ(0), 0.5j / num_steps), + qml.exp(qml.PauliX(0), 0.5j / num_steps), + ] + if order == 2: + base_decomp = [ + qml.exp(qml.PauliX(0), 0.25j / num_steps), + qml.exp(qml.PauliZ(0), 0.25j / num_steps), + qml.exp(qml.PauliZ(0), 0.25j / num_steps), + qml.exp(qml.PauliX(0), 0.25j / num_steps), + ] + + true_decomp = base_decomp * num_steps + + wires = hamiltonian.wires + dev = qml.device("default.qubit", wires=wires) + + @qml.qnode(dev) + def circ(): + qml.TrotterProduct(hamiltonian, time, n=num_steps, order=order) + return qml.state() + + initial_state = qnp.zeros(2 ** (len(wires))) + initial_state[0] = 1 + + expected_state = ( + reduce( + lambda x, y: x @ y, [qml.matrix(op, wire_order=wires) for op in true_decomp[::-1]] + ) + @ initial_state + ) + state = circ() + assert qnp.allclose(expected_state, state) + + @pytest.mark.jax + @pytest.mark.parametrize("time", (0.5, 1, 2)) + def test_jax_execute(self, time): + """Test that the gate executes correctly in the jax interface.""" + from jax import numpy as jnp + + time = jnp.array(time) + coeffs = jnp.array([1.23, -0.45]) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=2, order=2) + return qml.state() + + initial_state = jnp.array([1.0, 0.0, 0.0, 0.0]) + + expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2) + + expected_state = ( + reduce( + lambda x, y: x @ y, + [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence], + ) + @ initial_state + ) + + state = circ(time, coeffs) + assert allclose(expected_state, state) + + @pytest.mark.jax + @pytest.mark.parametrize("time", (0.5, 1, 2)) + def test_jaxjit_execute(self, time): + """Test that the gate executes correctly in the jax interface with jit.""" + import jax + from jax import numpy as jnp + + time = jnp.array(time) + c1 = jnp.array(1.23) + c2 = jnp.array(-0.45) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=2) + + @jax.jit + @qml.qnode(dev, interface="jax") + def circ(time, c1, c2): + h = qml.sum( + qml.s_prod(c1, terms[0]), + qml.s_prod(c2, terms[1]), + ) + qml.TrotterProduct(h, time, n=2, order=2, check_hermitian=False) + return qml.state() + + initial_state = jnp.array([1.0, 0.0, 0.0, 0.0]) + + expected_product_sequence = _generate_simple_decomp([c1, c2], terms, time, order=2, n=2) + + expected_state = ( + reduce( + lambda x, y: x @ y, + [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence], + ) + @ initial_state + ) + + state = circ(time, c1, c2) + assert allclose(expected_state, state) + + @pytest.mark.tf + @pytest.mark.parametrize("time", (0.5, 1, 2)) + def test_tf_execute(self, time): + """Test that the gate executes correctly in the tensorflow interface.""" + import tensorflow as tf + + time = tf.Variable(time, dtype=tf.complex128) + coeffs = tf.Variable([1.23, -0.45], dtype=tf.complex128) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.sum( + qml.s_prod(coeffs[0], terms[0]), + qml.s_prod(coeffs[1], terms[1]), + ) + qml.TrotterProduct(h, time, n=2, order=2) + + return qml.state() + + initial_state = tf.Variable([1.0, 0.0, 0.0, 0.0], dtype=tf.complex128) + + expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2) + + expected_state = tf.linalg.matvec( + reduce( + lambda x, y: x @ y, + [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence], + ), + initial_state, + ) + + state = circ(time, coeffs) + assert allclose(expected_state, state) + + @pytest.mark.torch + @pytest.mark.parametrize("time", (0.5, 1, 2)) + def test_torch_execute(self, time): + """Test that the gate executes correctly in the torch interface.""" + import torch + + time = torch.tensor(time, dtype=torch.complex64, requires_grad=True) + coeffs = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=2, order=2) + return qml.state() + + initial_state = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.complex64) + + expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2) + + expected_state = ( + reduce( + lambda x, y: x @ y, + [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence], + ) + @ initial_state + ) + + state = circ(time, coeffs) + assert allclose(expected_state, state) + + @pytest.mark.autograd + @pytest.mark.parametrize("time", (0.5, 1, 2)) + def test_autograd_execute(self, time): + """Test that the gate executes correctly in the autograd interface.""" + time = qnp.array(time) + coeffs = qnp.array([1.23, -0.45]) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=2) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=2, order=2) + return qml.state() + + initial_state = qnp.array([1.0, 0.0, 0.0, 0.0]) + + expected_product_sequence = _generate_simple_decomp(coeffs, terms, time, order=2, n=2) + + expected_state = ( + reduce( + lambda x, y: x @ y, + [qml.matrix(op, wire_order=range(2)) for op in expected_product_sequence], + ) + @ initial_state + ) + + state = circ(time, coeffs) + assert qnp.allclose(expected_state, state) + + @pytest.mark.autograd + @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1))) + def test_autograd_gradient(self, order, n): + """Test that the gradient is computed correctly""" + time = qnp.array(1.5) + coeffs = qnp.array([1.23, -0.45]) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=n, order=order) + return qml.expval(qml.Hadamard(0)) + + @qml.qnode(dev) + def reference_circ(time, coeffs): + with qml.QueuingManager.stop_recording(): + decomp = _generate_simple_decomp(coeffs, terms, time, order, n) + + for op in decomp[::-1]: + qml.apply(op) + + return qml.expval(qml.Hadamard(0)) + + measured_time_grad, measured_coeff_grad = qml.grad(circ)(time, coeffs) + reference_time_grad, reference_coeff_grad = qml.grad(reference_circ)(time, coeffs) + assert allclose(measured_time_grad, reference_time_grad) + assert allclose(measured_coeff_grad, reference_coeff_grad) + + @pytest.mark.torch + @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1))) + def test_torch_gradient(self, order, n): + """Test that the gradient is computed correctly using torch""" + import torch + + time = torch.tensor(1.5, dtype=torch.complex64, requires_grad=True) + coeffs = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True) + time_reference = torch.tensor(1.5, dtype=torch.complex64, requires_grad=True) + coeffs_reference = torch.tensor([1.23, -0.45], dtype=torch.complex64, requires_grad=True) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=n, order=order) + return qml.expval(qml.Hadamard(0)) + + @qml.qnode(dev) + def reference_circ(time, coeffs): + with qml.QueuingManager.stop_recording(): + decomp = _generate_simple_decomp(coeffs, terms, time, order, n) + + for op in decomp[::-1]: + qml.apply(op) + + return qml.expval(qml.Hadamard(0)) + + res_circ = circ(time, coeffs) + res_circ.backward() + measured_time_grad = time.grad + measured_coeff_grad = coeffs.grad + + ref_circ = reference_circ(time_reference, coeffs_reference) + ref_circ.backward() + reference_time_grad = time_reference.grad + reference_coeff_grad = coeffs_reference.grad + + assert allclose(measured_time_grad, reference_time_grad) + assert allclose(measured_coeff_grad, reference_coeff_grad) + + @pytest.mark.tf + @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1))) + def test_tf_gradient(self, order, n): + """Test that the gradient is computed correctly using tensorflow""" + import tensorflow as tf + + time = tf.Variable(1.5, dtype=tf.complex128) + coeffs = tf.Variable([1.23, -0.45], dtype=tf.complex128) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.sum( + qml.s_prod(coeffs[0], terms[0]), + qml.s_prod(coeffs[1], terms[1]), + ) + qml.TrotterProduct(h, time, n=n, order=order) + return qml.expval(qml.Hadamard(0)) + + @qml.qnode(dev) + def reference_circ(time, coeffs): + with qml.QueuingManager.stop_recording(): + decomp = _generate_simple_decomp(coeffs, terms, time, order, n) + + for op in decomp[::-1]: + qml.apply(op) + + return qml.expval(qml.Hadamard(0)) + + with tf.GradientTape() as tape: + result = circ(time, coeffs) + + measured_time_grad, measured_coeff_grad = tape.gradient(result, (time, coeffs)) + + with tf.GradientTape() as tape: + result = reference_circ(time, coeffs) + + reference_time_grad, reference_coeff_grad = tape.gradient(result, (time, coeffs)) + assert allclose(measured_time_grad, reference_time_grad) + assert allclose(measured_coeff_grad, reference_coeff_grad) + + @pytest.mark.jax + @pytest.mark.parametrize("order, n", ((1, 1), (1, 2), (2, 1), (4, 1))) + def test_jax_gradient(self, order, n): + """Test that the gradient is computed correctly""" + import jax + from jax import numpy as jnp + + time = jnp.array(1.5) + coeffs = jnp.array([1.23, -0.45]) + terms = [qml.PauliX(0), qml.PauliZ(0)] + + dev = qml.device("default.qubit", wires=1) + + @qml.qnode(dev) + def circ(time, coeffs): + h = qml.dot(coeffs, terms) + qml.TrotterProduct(h, time, n=n, order=order) + return qml.expval(qml.Hadamard(0)) + + @qml.qnode(dev) + def reference_circ(time, coeffs): + with qml.QueuingManager.stop_recording(): + decomp = _generate_simple_decomp(coeffs, terms, time, order, n) + + for op in decomp[::-1]: + qml.apply(op) + + return qml.expval(qml.Hadamard(0)) + + measured_time_grad, measured_coeff_grad = jax.grad(circ, argnums=[0, 1])(time, coeffs) + reference_time_grad, reference_coeff_grad = jax.grad(reference_circ, argnums=[0, 1])( + time, coeffs + ) + assert allclose(measured_time_grad, reference_time_grad) + assert allclose(measured_coeff_grad, reference_coeff_grad)