add qml.math.grad and qml.math.jacobian for differentiating any interface #6741

Merged Dec 31, 2024 (17 commits)
Changes from 1 commit
2 changes: 2 additions & 0 deletions doc/releases/changelog-dev.md
@@ -341,6 +341,8 @@ such as `shots`, `rng` and `prng_key`.

<h4>Other Improvements</h4>

* `qml.math.grad` can now differentiate a function with inputs of any interface in a JAX-like manner.
  [(#6741)](https://github.com/PennyLaneAI/pennylane/pull/6741)
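
  An illustrative usage sketch; the calls and expected outputs below mirror the docstring added in this PR:

  ```pycon
  >>> def f(x, y):
  ...     return x * y
  >>> qml.math.grad(f)(jax.numpy.array(2.0), jax.numpy.array(3.0))
  Array(3., dtype=float32, weak_type=True)
  >>> qml.math.grad(f)(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
  tensor(3.)
  ```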

* `qml.GroverOperator` now has a `work_wires` property.
[(#6738)](https://github.com/PennyLaneAI/pennylane/pull/6738)

2 changes: 2 additions & 0 deletions pennylane/math/__init__.py
@@ -101,6 +101,7 @@
get_interface,
Interface,
)
from .grad import grad

sum = ar.numpy.sum
toarray = ar.numpy.to_numpy
@@ -168,6 +169,7 @@ def __getattr__(name):
"get_canonical_interface_name",
"get_deep_interface",
"get_trainable_indices",
"grad",
"in_backprop",
"is_abstract",
"is_independent",
98 changes: 98 additions & 0 deletions pennylane/math/grad.py
@@ -0,0 +1,98 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines grad and jacobian for differentiating circuits in an
interface-independent way.
"""

from typing import Callable, Sequence

from pennylane._grad import grad as autograd_grad

from .interface_utils import get_interface


# pylint: disable=import-outside-toplevel
def grad(f: Callable, argnums: Sequence[int] | int = 0) -> Callable:
"""Compute the gradient in a jax-like manner for any interface.

Args:
f (Callable): a function with a single 0-D scalar output
argnums (Sequence[int] | int): which argument(s) to differentiate. Defaults to ``0``.

Returns:
Callable: a function with the same signature as ``f`` that returns the gradient.

Note that this function follows the same calling convention as JAX: by default, it returns the gradient
of the first argument, whether or not other arguments are trainable.

>>> def f(x, y):
... return x * y
>>> qml.math.grad(f)(qml.numpy.array(2.0), qml.numpy.array(3.0))
tensor(3., requires_grad=True)
>>> qml.math.grad(f)(jax.numpy.array(2.0), jax.numpy.array(3.0))
Array(3., dtype=float32, weak_type=True)
>>> qml.math.grad(f)(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
tensor(3.)
>>> qml.math.grad(f)(tf.Variable(2.0), tf.Variable(3.0))
<tf.Tensor: shape=(), dtype=float32, numpy=3.0>

``argnums`` can be provided to differentiate multiple arguments.

>>> qml.math.grad(f, argnums=(0,1))(torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True))
(tensor(3.), tensor(2.))

Note that the selected arguments *must* be of an appropriately trainable datatype, or an error may occur.

>>> qml.math.grad(f)(torch.tensor(1.0), torch.tensor(2.))
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

"""

argnums_integer = False
if isinstance(argnums, int):
argnums = (argnums,)
argnums_integer = True

def compute_grad(*args, **kwargs):
interface = get_interface(*args)

if interface == "autograd":
g = autograd_grad(f, argnum=argnums)(*args, **kwargs)
return g[0] if argnums_integer else g

if interface == "jax":
import jax

g = jax.grad(f, argnums=argnums)(*args, **kwargs)
return g[0] if argnums_integer else g

if interface == "torch":
y = f(*args, **kwargs)
y.backward()
g = tuple(args[i].grad for i in argnums)
return g[0] if argnums_integer else g

if interface == "tensorflow":
import tensorflow as tf

with tf.GradientTape() as tape:
y = f(*args, **kwargs)

g = tape.gradient(y, tuple(args[i] for i in argnums))
return g[0] if argnums_integer else g

raise ValueError(f"Interface {interface} is not differentiable.")

return compute_grad
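
The PR title also names `qml.math.jacobian`, which does not appear in this commit. Purely as an illustration, here is a minimal sketch of how it could reuse the same dispatch pattern as `compute_grad` above; the helper name `_jacobian_sketch`, the choice of per-interface jacobian calls, and the omission of a TensorFlow branch are assumptions, not code from this PR:

```python
from typing import Callable, Sequence

from pennylane._grad import jacobian as autograd_jacobian
from pennylane.math import get_interface


def _jacobian_sketch(f: Callable, argnums: Sequence[int] | int = 0) -> Callable:
    """Hypothetical interface-agnostic jacobian following the same dispatch as grad."""
    argnums_integer = isinstance(argnums, int)
    argnums_tuple = (argnums,) if argnums_integer else tuple(argnums)

    def compute_jacobian(*args, **kwargs):
        interface = get_interface(*args)

        if interface == "autograd":
            # qml.jacobian already accepts an int or a sequence for argnum
            return autograd_jacobian(f, argnum=argnums)(*args, **kwargs)

        if interface == "jax":
            import jax

            jac = jax.jacobian(f, argnums=argnums_tuple)(*args, **kwargs)
            return jac[0] if argnums_integer else jac

        if interface == "torch":
            import torch

            # torch.autograd.functional.jacobian differentiates every input it
            # receives, so pass only the selected arguments through a closure.
            def partial_f(*trainable):
                new_args = list(args)
                for i, t in zip(argnums_tuple, trainable):
                    new_args[i] = t
                return f(*new_args, **kwargs)

            jac = torch.autograd.functional.jacobian(
                partial_f, tuple(args[i] for i in argnums_tuple)
            )
            return jac[0] if argnums_integer else jac

        raise ValueError(f"Interface {interface} is not differentiable.")

    return compute_jacobian
```
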
24 changes: 17 additions & 7 deletions pennylane/math/single_dispatch.py
@@ -301,10 +301,12 @@ def _tensorflow_allclose(a, b, **kwargs):
)


def _tf_convert_to_tensor(x, **kwargs):
def _tf_convert_to_tensor(x, requires_grad=False, **kwargs):
if isinstance(x, _i("tf").Tensor) and "dtype" in kwargs:
return _i("tf").cast(x, **kwargs)
return _i("tf").convert_to_tensor(x, **kwargs)
out = _i("tf").cast(x, **kwargs)
else:
out = _i("tf").convert_to_tensor(x, **kwargs)
return _i("tf").Variable(out) if requires_grad else out


ar.register_function("tensorflow", "asarray", _tf_convert_to_tensor)
@@ -541,7 +543,7 @@ def _to_numpy_torch(x):
ar.register_function("torch", "to_numpy", _to_numpy_torch)


def _asarray_torch(x, dtype=None, **kwargs):
def _asarray_torch(x, dtype=None, requires_grad=False, **kwargs):
import torch

dtype_map = {
@@ -556,9 +558,9 @@
np.complex128: torch.complex128,
"float64": torch.float64,
}

if dtype in dtype_map:
return torch.as_tensor(x, dtype=dtype_map[dtype], **kwargs)
dtype = dtype_map.get(dtype, dtype)
if requires_grad:
return torch.tensor(x, dtype=dtype, **kwargs, requires_grad=True)

return torch.as_tensor(x, dtype=dtype, **kwargs)

@@ -814,6 +816,14 @@ def _to_numpy_jax(x):
ar.register_function("jax", "gather", lambda x, indices: x[np.array(indices)])


# pylint: disable=unused-argument
def _asarray_jax(x, dtype=None, requires_grad=False, **kwargs):
return _i("jax").numpy.array(x, dtype=dtype, **kwargs)


ar.register_function("jax", "asarray", _asarray_jax)


def _ndim_jax(x):
import jax.numpy as jnp

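These `asarray` registrations are what let the new tests below build trainable inputs uniformly via the `requires_grad` keyword. A small sketch of the intended behaviour, inferred from the diff above (the comments describe the code paths shown here, not guarantees made by the PR):

```python
from pennylane import math

# torch: builds a fresh tensor with requires_grad=True
x_torch = math.asarray(2.0, like="torch", requires_grad=True)

# tensorflow: wraps the converted tensor in a tf.Variable so GradientTape watches it
x_tf = math.asarray(2.0, like="tensorflow", requires_grad=True)

# jax: the keyword is accepted but ignored; JAX does not mark trainability on arrays
x_jax = math.asarray(2.0, like="jax", requires_grad=True)
```
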
70 changes: 70 additions & 0 deletions tests/math/test_grad.py
@@ -0,0 +1,70 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test the qml.math.grad and qml.math.jacobian functions.
"""

import pytest

from pennylane import math

pytestmark = pytest.mark.all_interfaces


@pytest.mark.parametrize("interface", ("autograd", "jax", "tensorflow", "torch"))
def test_differentiate_first_arg(interface):
"""Test taht we just differentiate the first argument by default."""

def f(x, y):
return x * y

x = math.asarray(2.0, like=interface, requires_grad=True)
y = math.asarray(3.0, like=interface, requires_grad=True)

g = math.grad(f)(x, y)
if interface != "autograd":
assert math.get_interface(g) == interface
assert math.allclose(g, 3.0)


@pytest.mark.parametrize("interface", ("autograd", "jax", "tensorflow", "torch"))
def test_multiple_argnums(interface):
"""Test that we can differentiate multiple arguments."""

def g(x, y):
return 2 * x + 3 * y

x = math.asarray(0.5, like=interface, requires_grad=True)
y = math.asarray(2.5, like=interface, requires_grad=True)

g1, g2 = math.grad(g, argnums=(0, 1))(x, y)
if interface != "autograd":
assert math.get_interface(g1) == interface
assert math.get_interface(g2) == interface

assert math.allclose(g1, 2)
assert math.allclose(g2, 3)


@pytest.mark.parametrize("interface", ("autograd", "jax", "tensorflow", "torch"))
def test_keyword_arguments(interface):
"""Test that keyword arguments are considered."""

def f(x, *, constant):
return constant * x

x = math.asarray(2.0, like=interface, requires_grad=True)

g = math.grad(f)(x, constant=2.0)
assert math.allclose(g, 2.0)