Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QSP iterative angle solver #6694

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions pennylane/templates/subroutines/qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from pennylane.queuing import QueuingManager
from pennylane.wires import Wires

from autograd import jacobian, hessian

Check notice on line 30 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L30

Unused hessian imported from autograd (unused-import)

Check notice on line 30 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L30

Unused jacobian imported from autograd (unused-import)


def qsvt(A, angles, wires, convention=None):
r"""Implements the
Expand Down Expand Up @@ -569,8 +571,154 @@
)

return rotation_angles
###########################################################################################################################
"""

Check notice on line 575 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L575

String statement has no effect (pointless-string-statement)
Implementation of the QSP (Quantum Signal Processing) algorithm proposed in https://arxiv.org/pdf/2002.11649
"""

import scipy

Check notice on line 579 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L579

Import "import scipy" should be placed at the top of the module (wrong-import-position)

def jit_if_available(func):
"""
decorator that jax.jit the qsp_iterative optimization functions if jax is available
the static arguments to the jit function are meant to be contained in kwargs
"""
try:
from functools import partial

Check notice on line 587 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L587

Import outside toplevel (functools.partial) (import-outside-toplevel)

Check notice on line 587 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L587

Unused partial imported from functools (unused-import)
from functools import wraps

Check notice on line 588 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L588

Import outside toplevel (functools.wraps) (import-outside-toplevel)

Check notice on line 588 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L588

Unused wraps imported from functools (unused-import)
import jax

Check notice on line 589 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L589

Import outside toplevel (jax) (import-outside-toplevel)

jax.config.update("jax_enable_x64", True)

Check notice on line 591 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L591

Anomalous backslash in string: '\s'. String constant might be missing an r prefix. (anomalous-backslash-in-string)

Check notice on line 591 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L591

Anomalous backslash in string: '\s'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
return jax.jit(func)
except ImportError:
return func



@jit_if_available
def cheby_pol(x, degree):
"""cos(degree*arcos(x))"""
# if np.abs(x) > 1:
# raise ValueError()
return qml.math.cos(degree * qml.math.arccos(x))

@jit_if_available
def poly_func(
coeffs, degree, parity, x

Check notice on line 607 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L607

Unused argument 'degree' (unused-argument)
):

Check notice on line 608 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L608

Missing function or method docstring (missing-function-docstring)
"""\sum c_kT_{2k} if even else \sum c_kT_{2k+1} if odd where T_k(x)=cos(karccos(x))"""
ind = qml.math.arange(len(coeffs))
return sum(

Check notice on line 611 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L611

Consider using a generator instead 'sum(coeffs[i] * cheby_pol(x, degree=2 * i + parity) for i in ind)' (consider-using-generator)
[coeffs[i] * cheby_pol(x, degree=2 * i + parity) for i in ind]
)

@jit_if_available
def z_rotation(phi):
try:
import jax

Check notice on line 618 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L618

Import outside toplevel (jax) (import-outside-toplevel)

Check notice on line 618 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L618

Unused import jax (unused-import)
interface='jax'
except ModuleNotFoundError:
interface=None
return qml.math.array([[qml.math.exp(1j * phi), 0.0], [0.0, qml.math.exp(-1j * phi)]], like=interface)

@jit_if_available
def W_of_x(x):
"""W(x) defined in Theorem (1) of https://arxiv.org/pdf/2002.11649"""
try:
import jax

Check notice on line 628 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L628

Unused import jax (unused-import)

Check notice on line 628 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L628

Import outside toplevel (jax) (import-outside-toplevel)
interface='jax'
except ModuleNotFoundError:
interface=None
return qml.math.array(
[
[cheby_pol(x=x, degree=1.), 1j * qml.math.sqrt(1 - cheby_pol(x=x, degree=1.) ** 2)],
[1j * qml.math.sqrt(1 - cheby_pol(x=x, degree=1.) ** 2), cheby_pol(x=x, degree=1.)],
], like=interface
)

@jit_if_available

Check notice on line 639 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L639

Anomalous backslash in string: '\p'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
def qsp_iterate(phi, x):
"""defined in Theorem (1) of https://arxiv.org/pdf/2002.11649"""
return qml.math.dot(W_of_x(x=x), z_rotation(phi=phi))


@jit_if_available
def qsp_iterates(phis, x):
"""Eq (13) Resulting unitary of the QSP circuit (on reduced invariant subspace ofc)"""
mtx = qml.math.eye(2)
for phi in phis[::-1][:-1]:
mtx = qml.math.dot(qsp_iterate(x=x, phi=phi), mtx)
mtx = qml.math.dot(z_rotation(phi=phis[0]), mtx)

return mtx

import math

Check notice on line 655 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L655

Import "import math" should be placed at the top of the module (wrong-import-position)

Check notice on line 655 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L655

standard import "math" should be placed before third party imports "numpy", "numpy.polynomial.Polynomial", "pennylane" (...) "pennylane.wires.Wires", "autograd.jacobian", "scipy" (wrong-import-order)

def grid_pts(degree):
"""x_j = cos(\frac{(2j-1)\pi}{4\tilde{d}}) Grid over which the polynomials are evaluated and the optimization is carried defined in page 8"""
d = (degree + 1)// 2 + (degree+1) % 2
return qml.math.array([qml.math.cos((2 * j - 1) * math.pi / (4 * d)) for j in range(1, d + 1)])

def qsp_optimization(degree, coeffs_target_func, optimizer=scipy.optimize.minimize, opt_method="Newton-CG"):
"""Algorithm 1 in https://arxiv.org/pdf/2002.11649 produces the angle parameters by minimizing the distance between the target and qsp polynomail over the grid"""
parity = degree % 2

grid_points = grid_pts(degree)
initial_guess = [qml.numpy.pi / 4] + [0.0] * (degree - 1) + [qml.numpy.pi / 4]
initial_guess = qml.math.array(initial_guess)
targets = [poly_func(coeffs=coeffs_target_func, x=x, degree=degree, parity=parity) for x in grid_points]

def obj_function(phi):
# Equation (23)
obj_func = 0.0

for i,x in enumerate(grid_points):
obj_func += (
qml.math.real(qsp_iterates(phis=phi, x=x)[0, 0])
- targets[i]
) ** 2

return 1 / len(grid_points) * obj_func
opt_kwargs = {}
try:
from jax import jacobian, hessian

Check notice on line 684 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L684

Trailing whitespace (trailing-whitespace)

Check notice on line 684 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L684

Redefining name 'hessian' from outer scope (line 30) (redefined-outer-name)

Check notice on line 684 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L684

Redefining name 'jacobian' from outer scope (line 30) (redefined-outer-name)

Check notice on line 684 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L684

Import outside toplevel (jax.jacobian, jax.hessian) (import-outside-toplevel)
except:

Check notice on line 685 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L685

No exception type(s) specified (bare-except)
from autograd import jacobian, hessian

Check notice on line 686 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L686

Import outside toplevel (autograd.jacobian, autograd.hessian) (import-outside-toplevel)

opt_kwargs["jac"] = jacobian(obj_function)
if opt_method == "Newton-CG":
opt_kwargs["hess"] = hessian(obj_function)

results = optimizer(
fun=obj_function,
x0=initial_guess,
method=opt_method,
**opt_kwargs

Check notice on line 696 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L696

Trailing whitespace (trailing-whitespace)
)

Check notice on line 697 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L697

Trailing whitespace (trailing-whitespace)
phis = results.x
cost_func = results.fun

print(f"cost function {cost_func}")
return phis, cost_func

def _compute_qsp_angles_iteratively(polynomial_coeffs_in_cano_basis, opt_method="Newton-CG"):
polynomial_coeffs_in_cheby_basis = chebyshev.poly2cheb(polynomial_coeffs_in_cano_basis)
degree = len(polynomial_coeffs_in_cheby_basis) - 1

coeffs_odd = polynomial_coeffs_in_cheby_basis[1::2]
coeffs_even = polynomial_coeffs_in_cheby_basis[0::2]

if np.allclose(coeffs_odd, np.zeros_like(coeffs_odd)):
coeffs_target_func = qml.math.array(coeffs_even)
elif np.allclose(coeffs_even, np.zeros_like(coeffs_even)):
coeffs_target_func = qml.math.array(coeffs_odd)
else:
raise ValueError()

angles, *_ = qsp_optimization(degree=degree, coeffs_target_func=coeffs_target_func, opt_method=opt_method)

return angles
###########################################################################################################################
def transform_angles(angles, routine1, routine2):
r"""
Converts angles for quantum signal processing (QSP) and quantum singular value transformation (QSVT) routines.
Expand Down Expand Up @@ -637,7 +785,7 @@
num_angles = len(angles)
update_vals = np.empty(num_angles)

update_vals[0] = 3 * np.pi / 4 - (3 + num_angles % 4) * np.pi / 2

Check notice on line 788 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L788

Too many branches (13/12) (too-many-branches)
update_vals[1:-1] = np.pi / 2
update_vals[-1] = -np.pi / 4
update_vals = qml.math.convert_like(update_vals, angles)
Expand Down Expand Up @@ -735,7 +883,7 @@
for _ in range(len(poly)):
if not np.isclose(poly[-1], 0.0):
break
poly.pop()

Check notice on line 886 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L886

Unnecessary "elif" after "return", remove the leading "el" from "elif" (no-else-return)

if len(poly) == 1:
raise AssertionError("The polynomial must have at least degree 1.")
Expand All @@ -743,7 +891,7 @@
for x in [-1, 0, 1]:
if qml.math.abs(qml.math.sum(coeff * x**i for i, coeff in enumerate(poly))) > 1:
# Check that |P(x)| ≤ 1. Only points -1, 0, 1 will be checked.
raise AssertionError("The polynomial must satisfy that |P(x)| ≤ 1 for all x in [-1, 1]")

Check notice on line 894 in pennylane/templates/subroutines/qsvt.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/templates/subroutines/qsvt.py#L894

Unnecessary "elif" after "return", remove the leading "el" from "elif" (no-else-return)

if routine in ["QSVT", "QSP"]:
if not (
Expand All @@ -760,10 +908,15 @@
if routine == "QSVT":
if angle_solver == "root-finding":
return transform_angles(_compute_qsp_angle(poly), "QSP", "QSVT")
elif angle_solver == "iterative":
return transform_angles(_compute_qsp_angles_iteratively(poly), "QSP", "QSVT")

raise AssertionError("Invalid angle solver method. We currently support 'root-finding'")

if routine == "QSP":
if angle_solver == "root-finding":
return _compute_qsp_angle(poly)
elif angle_solver == "iterative":
return _compute_qsp_angles_iteratively(poly)
raise AssertionError("Invalid angle solver method. Valid value is 'root-finding'")
raise AssertionError("Invalid routine. Valid values are 'QSP' and 'QSVT'")
128 changes: 110 additions & 18 deletions tests/templates/test_subroutines/test_qsvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pennylane as qml
from pennylane import numpy as np
from pennylane.templates.subroutines.qsvt import _complementary_poly
from pennylane.templates.subroutines.qsvt import _complementary_poly, cheby_pol, poly_func, qsp_iterates, qsp_optimization


def qfunc(A):
Expand Down Expand Up @@ -625,12 +625,35 @@ def test_global_phase_not_alway_applied():


class TestRootFindingSolver:
def generate_polynomial_coeffs(degree, parity=None):
np.random.seed(123)
if parity is None:
polynomial_coeffs_in_canonical_basis = np.random.normal(size=degree+1)
return polynomial_coeffs_in_canonical_basis / np.sum(np.abs(polynomial_coeffs_in_canonical_basis))
if parity == 0:
assert degree % 2 == 0.
polynomial_coeffs_in_canonical_basis = np.zeros((degree + 1))
polynomial_coeffs_in_canonical_basis[0::2] = np.random.normal(size=degree//2+1)
return polynomial_coeffs_in_canonical_basis / np.sum(np.abs(polynomial_coeffs_in_canonical_basis))

if parity == 1:
assert degree % 2 == 1.
polynomial_coeffs_in_canonical_basis = np.zeros((degree + 1))
polynomial_coeffs_in_canonical_basis[0::2] = np.random.uniform(size=degree//2+1)
return polynomial_coeffs_in_canonical_basis / np.sum(np.abs(polynomial_coeffs_in_canonical_basis))

raise ValueError(f"parity must be None, 0 or 1 but got {parity}")

@pytest.mark.parametrize(
"P",
[
([0.1, 0, 0.3, 0, -0.1]),
([0, 0.2, 0, 0.3]),
([-0.4, 0, 0.4, 0, -0.1, 0, 0.1]),
# ([0.1, 0, 0.3, 0, -0.1]),
# ([0, 0.2, 0, 0.3]),
# ([-0.4, 0, 0.4, 0, -0.1, 0, 0.1]),
(generate_polynomial_coeffs(4,0)),
(generate_polynomial_coeffs(3,1)),
(generate_polynomial_coeffs(6,0)),

],
)
def test_complementary_polynomial(self, P):
Expand All @@ -648,15 +671,70 @@ def test_complementary_polynomial(self, P):

Q_val = np.polyval(Q, z)
Q_magnitude_squared = np.abs(Q_val) ** 2

assert np.isclose(P_magnitude_squared + Q_magnitude_squared, 1, atol=1e-7)


@pytest.mark.parametrize(
"polynomial_coeffs_in_cheby_basis",
[
(generate_polynomial_coeffs(10,0)),
(generate_polynomial_coeffs(7,1)),
(generate_polynomial_coeffs(12,0)),
]
)
def test_qsp_on_poly_with_parity(self, polynomial_coeffs_in_cheby_basis):
degree = len(polynomial_coeffs_in_cheby_basis) - 1
parity = degree % 2
if parity:
target_polynomial_coeffs = polynomial_coeffs_in_cheby_basis[1::2]
else:
target_polynomial_coeffs = polynomial_coeffs_in_cheby_basis[0::2]
phis, cost_func = qsp_optimization(degree, target_polynomial_coeffs)
x_point = np.random.uniform(size=1, high=1, low=-1)
x_point = x_point.item()
delta = abs(
np.real(qsp_iterates(phis, x_point)[0, 0])
- poly_func(target_polynomial_coeffs, degree, parity, x_point)
)
print(f"delta {delta}")
# Theorem 4: |\alpha_i-\beta_i|\leq 2\sqrt(cost_func) https://arxiv.org/pdf/2002.11649
# which \implies |target_poly(x)-approx_poly(x)|\leq 2\sqrt(cost_func) \sum_i |T_i(x)|
tolerance = (
np.sum(
np.array(
[
2 * np.sqrt(cost_func) * abs(cheby_pol(degree=2 * i, x=x_point))
for i in range(len(target_polynomial_coeffs))
]
)
)
if not parity
else np.sum(
np.array(
[
2 * np.sqrt(cost_func) * abs(cheby_pol(degree=2 * i + 1, x=x_point))
for i in range(len(target_polynomial_coeffs))
]
)
)
)

assert np.isclose(
np.real(qsp_iterates(phis, x_point)[0, 0]),
poly_func(coeffs=target_polynomial_coeffs, degree=degree, parity=parity, x=x_point),
atol=tolerance,
)

@pytest.mark.parametrize(
"angles",
[
([0.1, 2, 0.3, 3, -0.1]),
([0, 0.2, 1, 0.3, 4, 2.4]),
([-0.4, 2, 0.4, 0, -0.1, 0, 0.1]),
# ([0.1, 2, 0.3, 3, -0.1]),
# ([0, 0.2, 1, 0.3, 4, 2.4]),
# ([-0.4, 2, 0.4, 0, -0.1, 0, 0.1]),
(generate_polynomial_coeffs(4, None)),
(generate_polynomial_coeffs(5, None)),
(generate_polynomial_coeffs(6, None))
],
)
def test_transform_angles(self, angles):
Expand All @@ -674,15 +752,22 @@ def test_transform_angles(self, angles):
@pytest.mark.parametrize(
"poly",
[
([0.1, 0, 0.3, 0, -0.1]),
([0, 0.2, 0, 0.3]),
([-0.4, 0, 0.4, 0, -0.1, 0, 0.1]),
(generate_polynomial_coeffs(4,0)),
(generate_polynomial_coeffs(3,1)),
(generate_polynomial_coeffs(6,0)),
],
)
def test_correctness_QSP_angles_root_finding(self, poly):
@pytest.mark.parametrize(
"angle_solver",
[
"root-finding",
"iterative",
]
)
def test_correctness_QSP_angles_root_finding(self, poly, angle_solver):
"""Tests that angles generate desired poly"""

angles = qml.poly_to_angles(poly, "QSP", angle_solver="root-finding")
angles = qml.poly_to_angles(list(poly), "QSP", angle_solver=angle_solver)
x = 0.5

@qml.qnode(qml.device("default.qubit"))
Expand All @@ -701,15 +786,22 @@ def circuit_qsp():
@pytest.mark.parametrize(
"poly",
[
([0.1, 0, 0.3, 0, -0.1]),
([0, 0.2, 0, 0.3]),
([-0.4, 0, 0.4, 0, -0.1, 0, 0.1]),
(generate_polynomial_coeffs(4,0)),
(generate_polynomial_coeffs(3,1)),
(generate_polynomial_coeffs(6,0)),
],
)
def test_correctness_QSVT_angles(self, poly):
@pytest.mark.parametrize(
"angle_solver",
[
"root-finding",
"iterative",
]
)
def test_correctness_QSVT_angles(self, poly, angle_solver):
"""Tests that angles generate desired poly"""

angles = qml.poly_to_angles(poly, "QSVT")
angles = qml.poly_to_angles(list(poly), "QSVT", angle_solver=angle_solver)
x = 0.5

block_encoding = qml.RX(-2 * np.arccos(x), wires=0)
Expand Down