Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TrotterProduct template #4661

Merged
merged 38 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
96e10ea
trotter
Jaybsoni Sep 13, 2023
d28f1fd
testing
Jaybsoni Sep 21, 2023
a604d90
Adding test file and doc-strings
Jaybsoni Oct 4, 2023
176b274
Adding tests for TrotterProduct
Jaybsoni Oct 10, 2023
e17d3b7
Merge branch 'master' into trotter_product
Jaybsoni Oct 10, 2023
7a151d1
fix doc-string
Jaybsoni Oct 10, 2023
ec3c585
Fix decomp test and doc-string
Jaybsoni Oct 10, 2023
a46ae64
add execution test
Jaybsoni Oct 10, 2023
80de04f
Apply suggestions from code review
Jaybsoni Oct 11, 2023
805b50a
Final tests and polishing
Jaybsoni Oct 11, 2023
d75bfc6
codefactor
Jaybsoni Oct 11, 2023
02e735a
Merge branch 'master' into trotter_product
Jaybsoni Oct 11, 2023
4c60eb4
Apply suggestions from code review
Jaybsoni Oct 11, 2023
725e34b
lint
Jaybsoni Oct 12, 2023
9536b48
lint
Jaybsoni Oct 12, 2023
7a5211a
add template image
soranjh Oct 14, 2023
8bb964e
correct typos
soranjh Oct 14, 2023
03fbb87
add and update args
soranjh Oct 14, 2023
ff7738b
fix pylint
soranjh Oct 14, 2023
78579ef
add reference
soranjh Oct 14, 2023
341e4a1
correct typo
soranjh Oct 14, 2023
fc9bacd
correct docstring math
soranjh Oct 14, 2023
b20d1fb
modify docstring
soranjh Oct 14, 2023
42ac504
fix pylint
soranjh Oct 14, 2023
946d6ad
modify doc and image
soranjh Oct 16, 2023
5c3fdda
Adding Interface and Gradient tests to `TrotterProduct` (#4677)
Jaybsoni Oct 18, 2023
6f8ae41
update changelog
soranjh Oct 18, 2023
f691986
update changelog
soranjh Oct 18, 2023
2f13527
Apply suggestions from code review
Jaybsoni Oct 18, 2023
d69fe6f
Apply suggestions from code review
Jaybsoni Oct 18, 2023
8b974b8
address review comments
Jaybsoni Oct 18, 2023
7a7bacd
Merge branch 'trotter_product' of https://github.com/PennyLaneAI/penn…
soranjh Oct 18, 2023
c0e744b
update changelog and image
soranjh Oct 18, 2023
2ddf0ea
apply code review comments
soranjh Oct 18, 2023
285f420
lint
Jaybsoni Oct 19, 2023
198ecf6
lint
Jaybsoni Oct 19, 2023
c08127e
Update pennylane/templates/subroutines/trotter.py
Jaybsoni Oct 19, 2023
971d560
Merge branch 'master' into trotter_product
Jaybsoni Oct 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
soranjh marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions doc/introduction/templates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ Other useful templates which do not belong to the previous categories can be fou
:description: :doc:`ApproxTimeEvolution <../code/api/pennylane.ApproxTimeEvolution>`
:figure: _static/templates/subroutines/approx_time_evolution.png

.. gallery-item::
:description: :doc:`TrotterProduct <../code/api/pennylane.TrotterProduct>`
:figure: _static/templates/subroutines/trotter_product.png

.. gallery-item::
:description: :doc:`Permute <../code/api/pennylane.Permute>`
:figure: _static/templates/subroutines/permute.png
Expand Down
8 changes: 6 additions & 2 deletions pennylane/ops/functions/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,9 @@ def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs):
# pylint: disable=unused-argument
def _equal_exp(op1: Exp, op2: Exp, **kwargs):
"""Determine whether two Exp objects are equal"""
if op1.coeff != op2.coeff:
rtol, atol = (kwargs["rtol"], kwargs["atol"])

if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol):
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
return False
return qml.equal(op1.base, op2.base)

Expand All @@ -273,7 +275,9 @@ def _equal_exp(op1: Exp, op2: Exp, **kwargs):
# pylint: disable=unused-argument
def _equal_sprod(op1: SProd, op2: SProd, **kwargs):
"""Determine whether two SProd objects are equal"""
if op1.scalar != op2.scalar:
rtol, atol = (kwargs["rtol"], kwargs["atol"])

if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol):
return False
return qml.equal(op1.base, op2.base)

Expand Down
1 change: 1 addition & 0 deletions pennylane/templates/subroutines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
from .basis_rotation import BasisRotation
from .qsvt import QSVT, qsvt
from .select import Select
from .trotter import TrotterProduct
285 changes: 285 additions & 0 deletions pennylane/templates/subroutines/trotter.py
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains templates for Suzuki-Trotter approximation based subroutines.
"""
import pennylane as qml
from pennylane.operation import Operation
from pennylane.ops import Sum


def _scalar(order):
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the scalar used in the recursive expression.
soranjh marked this conversation as resolved.
Show resolved Hide resolved

Args:
order (int): order of Trotter product (assume order is an even integer > 2).

Returns:
float: scalar to be used in the recursive expression.
"""
root = 1 / (order - 1)
return (4 - 4**root) ** -1
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved


@qml.QueuingManager.stop_recording()
def _recursive_expression(x, order, ops):
"""Generate a list of operations using the
recursive expression which defines the Trotter product.

Args:
x (complex): the evolution 'time'
order (int): the order of the Trotter Expansion
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
ops (Iterable(~.Operators)): a list of terms in the Hamiltonian

Returns:
List: the approximation as product of exponentials of the Hamiltonian terms
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
"""
if order == 1:
return [qml.exp(op, x * 1j) for op in ops]

if order == 2:
return [qml.exp(op, x * 0.5j) for op in ops + ops[::-1]]

scalar_1 = _scalar(order)
scalar_2 = 1 - 4 * scalar_1

ops_lst_1 = _recursive_expression(scalar_1 * x, order - 2, ops)
ops_lst_2 = _recursive_expression(scalar_2 * x, order - 2, ops)
trbromley marked this conversation as resolved.
Show resolved Hide resolved

return (2 * ops_lst_1) + ops_lst_2 + (2 * ops_lst_1)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved


class TrotterProduct(Operation):
r"""An operation representing the Suzuki-Trotter product approximation for the complex matrix
exponential of a given Hamiltonian.

The Suzuki-Trotter product formula provides a method to approximate the matrix exponential of
Hamiltonian expressed as a linear combination of terms which in general do not commute. Consider
the Hamiltonian :math:`H = \Sigma^{N}_{j=0} O_{j}`, the product formula is constructed using
symmetrized products of the terms in the Hamiltonian. The symmetrized products of order
:math:`m \in [1, 2, 4, ..., 2k]` with :math:`k \in \mathbb{N}` are given by:

.. math::

\begin{align}
S_{1}(t) &= \Pi_{j=0}^{N} \ e^{i t O_{j}} \\
S_{2}(t) &= \Pi_{j=0}^{N} \ e^{i \frac{t}{2} O_{j}} \cdot \Pi_{j=N}^{0} \ e^{i \frac{t}{2} O_{j}} \\
&\vdots \\
S_{2k}(t) &= S_{2k-2}(p_{2k}t)^{2} \cdot S_{2k-2}((1-4p_{2k})t) \cdot S_{2k-2}(p_{2k}t)^{2},
\end{align}

where the coefficient is :math:`p_{2k} = 1 / (4 - \sqrt[2k - 1]{4})`. The :math:`2k`th order,
:math:`n`-step Suzuki-Trotter approximation is then defined as:

.. math:: e^{iHt} \approx \left [S_{2k}(t / n) \right ]^{n}.
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

For more details see `J. Math. Phys. 32, 400 (1991) <https://pubs.aip.org/aip/jmp/article-abstract/32/2/400/229229>`_.

Args:
hamiltonian (Union[~.Hamiltonian, ~.Sum]): The Hamiltonian written in terms of products of
Pauli gates
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
time (int or float): The time of evolution, namely the parameter :math:`t` in :math:`e^{-iHt}`
n (int): An integer representing the number of Trotter steps to perform
order (int): An integer representing the order of the approximation (must be 1 or even)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
check_hermitian (bool): A flag to enable the validation check to ensure this is a valid unitary operator

Raises:
TypeError: The ``hamiltonian`` is not of type :class:`~.Hamiltonian`, or :class:`~.Sum`
ValueError: One or more of the terms in ``hamiltonian`` are not Hermitian
ValueError: The ``order`` is not one or a positive even integer

**Example**

.. code-block:: python3

coeffs = [0.25, 0.75]
ops = [qml.PauliX(0), qml.PauliZ(0)]
H = qml.dot(coeffs, ops)

dev = qml.device("default.qubit", wires=2)
@qml.qnode(dev)
def my_circ():
# Prepare some state
qml.Hadamard(0)

# Evolve according to H
qml.TrotterProduct(H, time=2.4, order=2)

# Measure some quantity
return qml.state()

>>> my_circ()
[-0.13259524+0.59790098j 0. +0.j -0.13259524-0.77932754j 0. +0.j ]

.. details::
:title: Usage Details

We can also compute the gradient with respect to the coefficients of the Hamiltonian and the
evolution time:

.. code-block:: python3

@qml.qnode(dev)
def my_circ(c1, c2, time):
# Prepare H:
H = qml.dot([c1, c2], [qml.PauliX(0), qml.PauliZ(1)])
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

# Prepare some state
qml.Hadamard(0)

# Evolve according to H
qml.TrotterProduct(H, time, order=2)

# Measure some quantity
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

>>> args = qnp.array([1.23, 4.5, 0.1])
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
>>> qml.grad(my_circ)(*tuple(args))
(tensor(0.00961064, requires_grad=True), tensor(-0.12338274, requires_grad=True), tensor(-5.43401259, requires_grad=True))
"""
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
def __init__( # pylint: disable=too-many-arguments
self, hamiltonian, time, n=1, order=1, check_hermitian=True, id=None
):
r"""Initialize the TrotterProduct class"""

if order <= 0 or order != 1 and order % 2 != 0:
raise ValueError(
f"The order of a TrotterProduct must be 1 or a positive even integer, got {order}."
)

if isinstance(hamiltonian, qml.Hamiltonian):
coeffs, ops = hamiltonian.terms()
hamiltonian = qml.dot(coeffs, ops)

if not isinstance(hamiltonian, Sum):
raise TypeError(
f"The given operator must be a PennyLane ~.Hamiltonian or ~.Sum got {hamiltonian}"
)

if check_hermitian:
for op in hamiltonian.operands:
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved
if not op.is_hermitian:
raise ValueError(
"One or more of the terms in the Hamiltonian may not be Hermitian"
)

self._hyperparameters = {
"n": n,
"order": order,
"base": hamiltonian,
"check_hermitian": check_hermitian,
}
super().__init__(time, wires=hamiltonian.wires, id=id)

def _flatten(self):
"""Serialize the operation into trainable and non-trainable components.

Returns:
data, metadata: The trainable and non-trainable components.

See ``Operator._unflatten``.

The data component can be recursive and include other operations. For example, the trainable component of ``Adjoint(RX(1, wires=0))``
will be the operator ``RX(1, wires=0)``.

The metadata **must** be hashable. If the hyperparameters contain a non-hashable component, then this
method and ``Operator._unflatten`` should be overridden to provide a hashable version of the hyperparameters.

**Example:**

>>> op = qml.Rot(1.2, 2.3, 3.4, wires=0)
>>> qml.Rot._unflatten(*op._flatten())
Rot(1.2, 2.3, 3.4, wires=[0])
>>> op = qml.PauliRot(1.2, "XY", wires=(0,1))
>>> qml.PauliRot._unflatten(*op._flatten())
PauliRot(1.2, XY, wires=[0, 1])

Operators that have trainable components that differ from their ``Operator.data`` must implement their own
``_flatten`` methods.

>>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") )
>>> op._flatten()
((U2(3.4, 4.5, wires=['a']),),
(<Wires = ['b', 'c']>, (True, True), <Wires = []>))
"""
hamiltonian = self.hyperparameters["base"]
time = self.parameters[0]

hashable_hyperparameters = tuple(
(key, value) for key, value in self.hyperparameters.items() if key != "base"
)
return (hamiltonian, time), hashable_hyperparameters

@classmethod
def _unflatten(cls, data, metadata):
"""Recreate an operation from its serialized format.

Args:
data: the trainable component of the operation
metadata: the non-trainable component of the operation.

The output of ``Operator._flatten`` and the class type must be sufficient to reconstruct the original
operation with ``Operator._unflatten``.

**Example:**

>>> op = qml.Rot(1.2, 2.3, 3.4, wires=0)
>>> op._flatten()
((1.2, 2.3, 3.4), (<Wires = [0]>, ()))
>>> qml.Rot._unflatten(*op._flatten())
>>> op = qml.PauliRot(1.2, "XY", wires=(0,1))
>>> op._flatten()
((1.2,), (<Wires = [0, 1]>, (('pauli_word', 'XY'),)))
>>> op = qml.ctrl(qml.U2(3.4, 4.5, wires="a"), ("b", "c") )
>>> type(op)._unflatten(*op._flatten())
Controlled(U2(3.4, 4.5, wires=['a']), control_wires=['b', 'c'])

"""
hyperparameters_dict = dict(metadata)
return cls(*data, **hyperparameters_dict)

@staticmethod
def compute_decomposition(*args, **kwargs):
r"""Representation of the operator as a product of other operators (static method).

.. math:: O = O_1 O_2 \dots O_n.

.. note::

Operations making up the decomposition should be queued within the
``compute_decomposition`` method.

.. seealso:: :meth:`~.Operator.decomposition`.

Args:
*params (list): trainable parameters of the operator, as stored in the ``parameters`` attribute
wires (Iterable[Any], Wires): wires that the operator acts on
**hyperparams (dict): non-trainable hyperparameters of the operator, as stored in the ``hyperparameters`` attribute

Returns:
list[Operator]: decomposition of the operator
"""
time = args[0]
n = kwargs["n"]
order = kwargs["order"]
ops = kwargs["base"].operands

decomp = _recursive_expression(time / n, order, ops)[-1::-1] * n
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

if qml.QueuingManager.recording():
for op in decomp: # apply operators in reverse order of expression
qml.apply(op)
Jaybsoni marked this conversation as resolved.
Show resolved Hide resolved

return decomp
16 changes: 16 additions & 0 deletions tests/ops/functions/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,14 @@ def test_exp_comparison(self, bases_bases_match, params_params_match):
op2 = qml.exp(base2, param2)
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_exp_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.exp(qml.PauliX(0), 0.12345)
op2 = qml.exp(qml.PauliX(0), 0.12356)

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)

@pytest.mark.parametrize("bases_bases_match", BASES)
@pytest.mark.parametrize("params_params_match", PARAMS)
def test_s_prod_comparison(self, bases_bases_match, params_params_match):
Expand All @@ -1448,6 +1456,14 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match):
op2 = qml.s_prod(param2, base2)
assert qml.equal(op1, op2) == (bases_match and params_match)

def test_s_prod_comparison_with_tolerance(self):
"""Test that equal compares the parameters within a provided tolerance."""
op1 = qml.s_prod(0.12345, qml.PauliX(0))
op2 = qml.s_prod(0.12356, qml.PauliX(0))

assert qml.equal(op1, op2, atol=1e-3, rtol=1e-2)
assert not qml.equal(op1, op2, atol=1e-5, rtol=1e-4)


class TestProdComparisons:
"""Tests comparisons between Prod operators"""
Expand Down
Loading
Loading