Torch boundary uses JacobianProductCalculator (#4654)
This PR is a follow-on to #4557.

It updates the torch interface to rely on the
`JacobianProductCalculator` class to compute VJPs (vector-Jacobian products).
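For illustration only, a minimal sketch of the torch boundary in action (not taken from the PR; `qml.devices.DefaultQubit` and the numerical values are assumed for the example). Per the changelog entry below, passing `device_vjp=True` to `qml.execute` additionally selects a device-provided VJP when one is available.

```python
import pennylane as qml
import torch

dev = qml.devices.DefaultQubit()

x = torch.tensor(0.543, requires_grad=True)
tape = qml.tape.QuantumScript([qml.RX(x, wires=0)], [qml.expval(qml.PauliZ(0))])

# On the backward pass, the torch boundary now delegates the VJP computation
# to a JacobianProductCalculator rather than computing it inline.
(res,) = qml.execute([tape], dev, gradient_fn=qml.gradients.param_shift, interface="torch")
res.backward()
print(x.grad)  # approximately -sin(0.543)
```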

Note: I cannot figure out how to replicate the failing CI test locally.
It uses finite shots, which always makes me a bit suspicious, but it passes
fine locally 😕

---------

Co-authored-by: David Wierichs <[email protected]>
Co-authored-by: Matthew Silverman <[email protected]>
3 people authored Nov 13, 2023
1 parent 768bea0 commit 1cea8c6
Showing 8 changed files with 202 additions and 241 deletions.
3 changes: 2 additions & 1 deletion doc/releases/changelog-dev.md
@@ -15,10 +15,11 @@
* `qml.draw` now supports drawing mid-circuit measurements.
[(#4775)](https://github.com/PennyLaneAI/pennylane/pull/4775)

* Autograd can now use vjps provided by the device from the new device API. If a device provides
* Autograd and torch can now use vjps provided by the device from the new device API. If a device provides
a vector Jacobian product, this can be selected by providing `device_vjp=True` to
`qml.execute`.
[(#4557)](https://github.com/PennyLaneAI/pennylane/pull/4557)
[(#4654)](https://github.com/PennyLaneAI/pennylane/pull/4654)

* Updates to some relevant Pytests to enable their use as a suite of benchmarks.
[(#4703)](https://github.com/PennyLaneAI/pennylane/pull/4703)
2 changes: 1 addition & 1 deletion pennylane/interfaces/execution.py
@@ -43,7 +43,7 @@

device_type = Union[qml.Device, "qml.devices.Device"]

jpc_interfaces = {"autograd", "numpy"}
jpc_interfaces = {"autograd", "numpy", "torch", "pytorch"}

INTERFACE_MAP = {
None: "Numpy",
1 change: 0 additions & 1 deletion pennylane/interfaces/jacobian_products.py
@@ -44,7 +44,6 @@ def _compute_vjps(jacs, dys, tapes):
vjps.append(qml.math.sum(qml.math.stack(shot_vjps), axis=0))
else:
vjps.append(f[multi](dy, jac))

return tuple(vjps)


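For context, the vector-Jacobian products assembled by this helper contract a cotangent vector ``dy`` with the measurement axis of each tape's Jacobian. A tiny NumPy illustration with made-up numbers (not part of the PR):

```python
import numpy as np

# Hypothetical Jacobian of one tape: 2 measurements x 3 trainable parameters.
jac = np.array([[0.1, 0.2, 0.3],
                [0.4, 0.5, 0.6]])
# Cotangent vector: one entry per measurement.
dy = np.array([1.0, -1.0])

# The VJP contracts dy with the measurement axis, leaving one entry per parameter.
vjp = dy @ jac
print(vjp)  # [-0.3 -0.3 -0.3]
```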
281 changes: 77 additions & 204 deletions pennylane/interfaces/torch.py
@@ -14,6 +14,51 @@
"""
This module contains functions for adding the PyTorch interface
to a PennyLane Device class.
**How to bind a custom derivative with Torch.**

See `the Torch documentation <https://pytorch.org/docs/stable/notes/extending.html>`_ for more complete
information.

Suppose we want to define a custom VJP for some function. We need to inherit from
``torch.autograd.Function`` and define ``forward`` and ``backward`` static methods.

.. code-block:: python

    class CustomFunction(torch.autograd.Function):

        @staticmethod
        def forward(ctx, x, exponent=2):
            ctx.saved_info = {'x': x, 'exponent': exponent}
            return x ** exponent

        @staticmethod
        def backward(ctx, dy):
            x = ctx.saved_info['x']
            exponent = ctx.saved_info['exponent']
            print(f"Calculating the gradient with x={x}, dy={dy}, exponent={exponent}")
            return dy * exponent * x ** (exponent - 1), None
To use the ``CustomFunction`` class, we call it with the static ``apply`` method:

>>> val = torch.tensor(2.0, requires_grad=True)
>>> res = CustomFunction.apply(val)
>>> res
tensor(4., grad_fn=<CustomFunctionBackward>)
>>> res.backward()
Calculating the gradient with x=2.0, dy=1.0, exponent=2
>>> val.grad
tensor(4.)
Note that for custom functions, the output of ``forward`` and the output of ``backward`` must be flat iterables of
Torch tensors. While autograd and jax can handle nested result objects like ``((np.array(1), np.array(2)), np.array(3))``,
torch requires that they be flattened like ``(np.array(1), np.array(2), np.array(3))``. The ``pytreeify`` class decorator
modifies the output of ``forward`` and the input to ``backward`` to unpack and repack the nested structure of the PennyLane
result object.
"""
# pylint: disable=too-many-arguments,protected-access,abstract-method
import inspect
@@ -64,27 +109,6 @@ def new_backward(ctx, *flat_grad_outputs):
return cls


def _compute_vjps(dys, jacs, multi_measurements):
"""Compute the vjps of multiple tapes, directly for a Jacobian and tangents."""
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Entry with args=(dys=%s, jacs=%s, multi_measurements=%s) called by=%s",
dys,
jacs,
multi_measurements,
"::L".join(str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]),
)

vjps = []

for i, multi in enumerate(multi_measurements):
compute_func = (
qml.gradients.compute_vjp_multi if multi else qml.gradients.compute_vjp_single
)
vjps.extend(compute_func(dys[i], jacs[i]))
return vjps


@pytreeify
class ExecuteTapes(torch.autograd.Function):
"""The signature of this ``torch.autograd.Function`` is designed to
@@ -96,26 +120,17 @@ class ExecuteTapes(torch.autograd.Function):
as the first argument ``kwargs``. This dictionary **must** contain:
* ``"tapes"``: the quantum tapes to batch evaluate
* ``"device"``: the quantum device to use to evaluate the tapes
* ``"execute_fn"``: the execution function to use on forward passes
* ``"gradient_fn"``: the gradient transform function to use
for backward passes
* ``"gradient_kwargs"``: gradient keyword arguments to pass to the
gradient function
* ``"max_diff``: the maximum order of derivatives to support
* ``"execute_fn"``: a function that calculates the results of the tapes
* ``"jpc"``: a :class:`~.JacobianProductCalculator` that can compute the vjp.
Further, note that the ``parameters`` argument is dependent on the
``tapes``; this function should always be called
with the parameters extracted directly from the tapes as follows:
>>> parameters = []
>>> [parameters.extend(t.get_parameters()) for t in tapes]
>>> kwargs = {"tapes": tapes, "device": device, "gradient_fn": gradient_fn, ...}
>>> parameters = [p for t in tapes for p in t.get_parameters()]
>>> kwargs = {"tapes": tapes, "execute_fn": execute_fn, "jpc": jpc}
>>> ExecuteTapes.apply(kwargs, *parameters)
The private argument ``_n`` is used to track nesting of derivatives, for example
if the nth-order derivative is requested. Do not set this argument unless you
understand the consequences!
"""

@staticmethod
@@ -133,16 +148,9 @@ def forward(ctx, kwargs, *parameters): # pylint: disable=arguments-differ
)

ctx.tapes = kwargs["tapes"]
ctx.device = kwargs["device"]
ctx.jpc = kwargs["jpc"]

ctx.execute_fn = kwargs["execute_fn"]
ctx.gradient_fn = kwargs["gradient_fn"]

ctx.gradient_kwargs = kwargs["gradient_kwargs"]
ctx.max_diff = kwargs["max_diff"]
ctx._n = kwargs.get("_n", 1)

res, ctx.jacs = ctx.execute_fn(ctx.tapes, **ctx.gradient_kwargs)
res = tuple(kwargs["execute_fn"](ctx.tapes))

# if any input tensor uses the GPU, the output should as well
ctx.torch_device = None
@@ -151,12 +159,7 @@ def forward(ctx, kwargs, *parameters): # pylint: disable=arguments-differ
if isinstance(p, torch.Tensor) and p.is_cuda: # pragma: no cover
ctx.torch_device = p.get_device()
break

res = tuple(_res_to_torch(r, ctx) for r in res)
for i, _ in enumerate(res):
# In place change of the numpy array Jacobians to Torch objects
_jac_to_torch(i, ctx)

return res

@staticmethod
@@ -173,124 +176,39 @@ def backward(ctx, *dy):
),
)

multi_measurements = [len(tape.measurements) > 1 for tape in ctx.tapes]

if ctx.jacs:
# Jacobians were computed on the forward pass (mode="forward")
# No additional quantum evaluations needed; simply compute the VJPs directly.
vjps = _compute_vjps(dy, ctx.jacs, multi_measurements)

else:
# Need to compute the Jacobians on the backward pass (accumulation="backward")

if isinstance(ctx.gradient_fn, qml.transforms.core.TransformDispatcher):
# Gradient function is a gradient transform.

# Generate and execute the required gradient tapes
if ctx._n < ctx.max_diff:
# The derivative order is less than the max derivative order.
# Compute the VJP recursively by using the gradient transform
# and calling ``execute`` to compute the results.
# This will allow higher-order derivatives to be computed
# if requested.

vjp_tapes, processing_fn = qml.gradients.batch_vjp(
ctx.tapes,
dy,
ctx.gradient_fn,
reduction="extend",
gradient_kwargs=ctx.gradient_kwargs,
)
# This is where the magic happens. Note that we call ``execute``.
# This recursion, coupled with the fact that the gradient transforms
# are differentiable, allows for arbitrary order differentiation.
res = execute(
vjp_tapes,
ctx.device,
ctx.execute_fn,
ctx.gradient_fn,
ctx.gradient_kwargs,
_n=ctx._n + 1,
max_diff=ctx.max_diff,
)
vjps = processing_fn(res)

else:
# The derivative order is at the maximum. Compute the VJP
# in a non-differentiable manner to reduce overhead.
vjp_tapes, processing_fn = qml.gradients.batch_vjp(
ctx.tapes,
dy,
ctx.gradient_fn,
reduction="extend",
gradient_kwargs=ctx.gradient_kwargs,
)

vjps = processing_fn(ctx.execute_fn(vjp_tapes)[0])

else:
# Gradient function is not a gradient transform
# (e.g., it might be a device method).
# Note that unlike the previous branch:
#
# - there is no recursion here
# - gradient_fn is not differentiable
#
# so we cannot support higher-order derivatives.

jacs = ctx.gradient_fn(ctx.tapes, **ctx.gradient_kwargs)

vjps = _compute_vjps(dy, jacs, multi_measurements)

# Remove empty vjps (from tape with non trainable params)
vjps = [vjp for vjp in vjps if list(vjp.shape) != [0]]
vjps = ctx.jpc.compute_vjp(ctx.tapes, dy)

# split tensor into separate entries
unpacked_vjps = []
for vjp_slice in vjps:
if vjp_slice is not None and np.squeeze(vjp_slice).shape != (0,):
unpacked_vjps.extend(_res_to_torch(vjp_slice, ctx))
vjps = tuple(unpacked_vjps)
# The output of backward must match the input of forward.
# Therefore, we return `None` for the gradient of `kwargs`.
return (None,) + tuple(vjps)
return (None,) + vjps


def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1):
def execute(tapes, execute_fn, jpc):
"""Execute a batch of tapes with Torch parameters on a device.
This function may be called recursively, if ``gradient_fn`` is a differentiable
transform, and ``_n < max_diff``.
Args:
tapes (Sequence[.QuantumTape]): batch of tapes to execute
device (pennylane.Device): Device to use to execute the batch of tapes.
If the device does not provide a ``batch_execute`` method,
by default the tapes will be executed in serial.
execute_fn (callable): The execution function used to execute the tapes
during the forward pass. This function must return a tuple ``(results, jacobians)``.
If ``jacobians`` is an empty list, then ``gradient_fn`` is used to
compute the gradients during the backwards pass.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes
gradient_fn (callable): the gradient function to use to compute quantum gradients
_n (int): a positive integer used to track nesting of derivatives, for example
if the nth-order derivative is requested.
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
the maximum order of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
execute_fn (Callable[[Sequence[.QuantumTape]], ResultBatch]): a function that turns a batch of circuits into results
jpc (JacobianProductCalculator): a class that can compute the vector jacobian product for the input tapes.
Returns:
list[list[torch.Tensor]]: A nested list of tape results. Each element in
the returned list corresponds in order to the provided tapes.
TensorLike: A nested tuple of tape results. Each element in
the returned tuple corresponds in order to the provided tapes.
"""
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Entry with args=(tapes=%s, device=%s, execute_fn=%s, gradient_fn=%s, gradient_kwargs=%s, _n=%s, max_diff=%s) called by=%s",
"Entry with args=(tapes=%s, execute_fn=%s, jpc=%s",
tapes,
repr(device),
execute_fn
if not (logger.isEnabledFor(qml.logging.TRACE) and inspect.isfunction(execute_fn))
else "\n" + inspect.getsource(execute_fn) + "\n",
gradient_fn
if not (logger.isEnabledFor(qml.logging.TRACE) and inspect.isfunction(gradient_fn))
else "\n" + inspect.getsource(gradient_fn) + "\n",
gradient_kwargs,
_n,
max_diff,
"::L".join(str(i) for i in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]),
f"\n{inspect.getsource(execute_fn)}\n"
if logger.isEnabledFor(qml.logging.TRACE)
else execute_fn,
jpc,
)

# pylint: disable=unused-argument
@@ -302,63 +220,18 @@ def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1):
parameters.extend(tape.get_parameters())

kwargs = {
"tapes": tapes,
"device": device,
"tapes": tuple(tapes),
"execute_fn": execute_fn,
"gradient_fn": gradient_fn,
"gradient_kwargs": gradient_kwargs,
"_n": _n,
"max_diff": max_diff,
"jpc": jpc,
}

return ExecuteTapes.apply(kwargs, *parameters)


def _res_to_torch(r, ctx):
"""Convert results from unwrapped execution to torch."""
if isinstance(r, dict):
return r
if isinstance(r, (list, tuple)):
res = []
for t in r:
if isinstance(t, dict) or isinstance(t, list) and all(isinstance(i, dict) for i in t):
# count result, single or broadcasted
res.append(t)
else:
if isinstance(t, tuple):
res.append(tuple(torch.as_tensor(el, device=ctx.torch_device) for el in t))
else:
res.append(torch.as_tensor(t, device=ctx.torch_device))
if isinstance(r, tuple):
res = tuple(res)
elif isinstance(r, dict):
res = r
else:
res = torch.as_tensor(r, device=ctx.torch_device)

return res


def _jac_to_torch(i, ctx):
"""Convert Jacobian from unwrapped execution to torch in the given ctx."""
if ctx.jacs:
ctx_jacs = list(ctx.jacs)
multi_m = len(ctx.tapes[i].measurements) > 1
multi_p = len(ctx.tapes[i].trainable_params) > 1

# Multiple measurements and parameters: Jacobian is a tuple of tuple
if multi_p and multi_m:
jacobians = []
for jacobian in ctx_jacs[i]:
inside_nested_jacobian = [
torch.as_tensor(j, device=ctx.torch_device) for j in jacobian
]
inside_nested_jacobian_tuple = tuple(inside_nested_jacobian)
jacobians.append(inside_nested_jacobian_tuple)
ctx_jacs[i] = tuple(jacobians)
# Single measurement and single parameter: Jacobian is a tensor
elif not multi_p and not multi_m:
ctx_jacs[i] = torch.as_tensor(np.array(ctx_jacs[i]), device=ctx.torch_device)
# Multiple measurements or multiple parameters: Jacobian is a tuple
else:
jacobian = [torch.as_tensor(jac, device=ctx.torch_device) for jac in ctx_jacs[i]]
ctx_jacs[i] = tuple(jacobian)
ctx.jacs = tuple(ctx_jacs)
return type(r)(_res_to_torch(t, ctx) for t in r)
return torch.as_tensor(r, device=ctx.torch_device)
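To tie the new signature together, here is a rough, untested sketch of how this boundary might be driven directly. The ``TransformJacobianProducts`` constructor shown here is an assumption about `pennylane/interfaces/jacobian_products.py`; in normal use, `qml.execute` builds the calculator for you.

```python
import pennylane as qml
import torch
from pennylane.interfaces.jacobian_products import TransformJacobianProducts
from pennylane.interfaces.torch import execute

dev = qml.devices.DefaultQubit()

def inner_execute(tapes):
    """Plain forward execution of a batch of tapes on the device."""
    return dev.execute(tapes)

# Assumed constructor: (inner execution function, gradient transform).
jpc = TransformJacobianProducts(inner_execute, qml.gradients.param_shift)

x = torch.tensor(0.1, requires_grad=True)
tape = qml.tape.QuantumScript([qml.RX(x, wires=0)], [qml.expval(qml.PauliZ(0))])

(res,) = execute([tape], inner_execute, jpc)  # forward pass via execute_fn
res.backward()                                # backward pass via jpc.compute_vjp
print(x.grad)
```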