Gradient transforms update (#4595)
**Description of the Change:**

- Update the gradient transforms to the new transform program system

**Benefits:**

- [x] Parameter shift
- [x] Parameter shift CV
- [x] Finite diff
- [x] Hadamard
- [x] SPSA
- [x] Pulse gradient
- [x] Pulse ODE
- [x] Metric tensor
- [x] Adjoint metric tensor
- [x] Quantum Fisher

~Classical Fisher~: no tape equivalent
~Hessian param shift~: separate PR

Major changes:

- The full transform program is constructed in the QNode, because we need
access to the QNode when building classical Jacobians (see the sketch after
this list).
- Update the drawer to apply the QNode's transform program.
- Add two private methods to construct argnums and classical Jacobians in
the transform program.
- The metric tensor now raises errors instead of warnings.
- Remove the spy on gradient transforms in tests.
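Concretely, each gradient transform is now a tape-to-batch function registered
through `pennylane.transforms.core.transform`, and the QNode assembles these into
its transform program. A minimal sketch of the pattern, mirroring the
`finite_diff`/`param_shift` diffs below (`my_grad` and its trivial body are
illustrative only, not part of this PR):

```python
from functools import partial
from typing import Callable, Sequence

import pennylane as qml
from pennylane.transforms.core import transform
from pennylane.transforms.tape_expand import expand_invalid_trainable


def _expand_transform_my_grad(
    tape: qml.tape.QuantumTape,
) -> (Sequence[qml.tape.QuantumTape], Callable):
    """Decompose operations whose trainable parameters lack a gradient rule."""
    return [expand_invalid_trainable(tape)], lambda results: results[0]


@partial(
    transform,
    expand_transform=_expand_transform_my_grad,
    classical_cotransform=None,  # real gradient transforms pass _contract_qjac_with_cjac
    final_transform=True,  # gradient transforms terminate the transform program
)
def my_grad(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
    """Toy transform: execute the tape unchanged and return its single result."""
    return [tape], lambda results: results[0]
```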
-------

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Mudit Pandey <[email protected]>
rmoyard and mudit2812 authored Oct 13, 2023
1 parent 82bab31 commit e68a56f
Showing 75 changed files with 1,160 additions and 1,741 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -63,6 +63,9 @@
 * Transforms can be applied on devices following the new device API.
   [(#4667)](https://github.com/PennyLaneAI/pennylane/pull/4667)
 
+* All gradient transforms are updated to the new transform program system.
+  [(#4595)](https://github.com/PennyLaneAI/pennylane/pull/4595)
+
 * All quantum function transforms are updated to the new transform program system.
   [(#4439)](https://github.com/PennyLaneAI/pennylane/pull/4439)
1 change: 0 additions & 1 deletion pennylane/_grad.py
@@ -327,7 +327,6 @@ def _jacobian_function(*args, **kwargs):
             "If this is unintended, please add trainable parameters via the "
             "'requires_grad' attribute or 'argnum' keyword."
         )
-
         jac = tuple(_jacobian(func, arg)(*args, **kwargs) for arg in _argnum)
 
         return jac[0] if unpack else jac
5 changes: 4 additions & 1 deletion pennylane/drawer/draw.py
@@ -260,10 +260,13 @@ def wrapper(*args, **kwargs):
             _wire_order = wire_order or qnode.tape.wires
         else:
             original_expansion_strategy = getattr(qnode, "expansion_strategy", None)
 
             try:
                 qnode.expansion_strategy = expansion_strategy or original_expansion_strategy
                 tapes = qnode.construct(args, kwargs)
+                if isinstance(qnode.device, qml.devices.Device):
+                    program = qnode.transform_program
+                    tapes = program([qnode.tape])
+
             finally:
                 qnode.expansion_strategy = original_expansion_strategy
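With this change, drawing a QNode on a new-API device goes through the QNode's
full transform program. A hedged sketch of the intended behavior (assuming a
gradient-transformed QNode can be drawn, which is what this diff enables; the
circuit is illustrative):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# The drawer now applies the QNode's transform program, so the
# parameter-shifted tapes are what get rendered.
print(qml.draw(qml.gradients.param_shift(circuit))(0.5))
```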
43 changes: 38 additions & 5 deletions pennylane/gradients/finite_difference.py
@@ -15,23 +15,28 @@
 This module contains functions for computing the finite-difference gradient
 of a quantum tape.
 """
-# pylint: disable=protected-access,too-many-arguments,too-many-branches,too-many-statements
+# pylint: disable=protected-access,too-many-arguments,too-many-branches,too-many-statements,unused-argument
+from typing import Sequence, Callable
 import functools
+from functools import partial
 from warnings import warn
 
 import numpy as np
 from scipy.special import factorial
 
 import pennylane as qml
 from pennylane.measurements import ProbabilityMP
+from pennylane.transforms.core import transform
 from pennylane.transforms.tape_expand import expand_invalid_trainable
+from pennylane.gradients.gradient_transform import _contract_qjac_with_cjac
 
 
 from .general_shift_rules import generate_shifted_tapes
 from .gradient_transform import (
     _all_zero_grad,
     assert_no_tape_batching,
     choose_grad_methods,
     gradient_analysis_and_validation,
-    gradient_transform,
     _no_trainable_grad,
 )
@@ -167,17 +172,44 @@ def _processing_fn(results, shots, single_shot_batch_fn):
     return tuple(grads_tuple)
 
 
-@gradient_transform
+def _expand_transform_finite_diff(
+    tape: qml.tape.QuantumTape,
+    argnum=None,
+    h=1e-7,
+    approx_order=1,
+    n=1,
+    strategy="forward",
+    f0=None,
+    validate_params=True,
+) -> (Sequence[qml.tape.QuantumTape], Callable):
+    """Expand function to be applied before finite difference."""
+    expanded_tape = expand_invalid_trainable(tape)
+
+    def null_postprocessing(results):
+        """A postprocessing function returned by a transform that only converts the batch of results
+        into a result for a single ``QuantumTape``.
+        """
+        return results[0]
+
+    return [expanded_tape], null_postprocessing
+
+
+@partial(
+    transform,
+    expand_transform=_expand_transform_finite_diff,
+    classical_cotransform=_contract_qjac_with_cjac,
+    final_transform=True,
+)
 def finite_diff(
-    tape,
+    tape: qml.tape.QuantumTape,
     argnum=None,
     h=1e-7,
     approx_order=1,
     n=1,
     strategy="forward",
     f0=None,
     validate_params=True,
-):
+) -> (Sequence[qml.tape.QuantumTape], Callable):
     r"""Transform a QNode to compute the finite-difference gradient of all gate parameters with respect to its inputs.
 
     Args:
@@ -318,6 +350,7 @@ def finite_diff(
         The outermost tuple contains results corresponding to each element of the shot vector.
     """
+
     transform_name = "finite difference"
     assert_no_tape_batching(tape, transform_name)
 
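For reference, a usage sketch of the updated transform (standard
`qml.gradients.finite_diff` call on a QNode; the circuit and step size are
illustrative):

```python
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = np.array(0.4, requires_grad=True)
# Forward finite difference of cos(x); approximately -sin(0.4).
grad = qml.gradients.finite_diff(circuit, h=1e-7)(x)
```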
14 changes: 9 additions & 5 deletions pennylane/gradients/gradient_transform.py
@@ -391,13 +391,19 @@ def reorder_grads(grads, tape_specs):
 
 
 # pylint: disable=too-many-return-statements,too-many-branches
-def _contract_qjac_with_cjac(qjac, cjac, num_measurements, has_partitioned_shots):
+def _contract_qjac_with_cjac(qjac, cjac, tape):
     """Contract a quantum Jacobian with a classical preprocessing Jacobian.
 
     Essentially, this function computes the generalized version of
     ``tensordot(qjac, cjac)`` over the tape parameter axis, adapted to the new
    return type system. This function takes the measurement shapes and different
     QNode arguments into account.
     """
+    num_measurements = len(tape.measurements)
+    has_partitioned_shots = tape.shots.has_partitioned_shots
+
     if isinstance(qjac, tuple) and len(qjac) == 1:
         qjac = qjac[0]
 
     if isinstance(cjac, tuple) and len(cjac) == 1:
         cjac = cjac[0]
@@ -453,7 +459,7 @@ def _reshape(x):
     return tuple(tuple(tdot(qml.math.stack(q), c) for c in cjac if c is not None) for q in qjac)
 
 
-class gradient_transform(qml.batch_transform):
+class gradient_transform(qml.batch_transform):  # pragma: no cover
     """Decorator for defining quantum gradient transforms.
 
     Quantum gradient transforms are a specific case of :class:`~.batch_transform`.
@@ -601,8 +607,6 @@ def jacobian_wrapper(
                 qnode, argnum=argnum_cjac, expand_fn=self.expand_fn
             )(*args, **kwargs)
 
-            num_measurements = len(qnode.tape.measurements)
-            has_partitioned_shots = qnode.tape.shots.has_partitioned_shots
-            return _contract_qjac_with_cjac(qjac, cjac, num_measurements, has_partitioned_shots)
+            return _contract_qjac_with_cjac(qjac, cjac, qnode.tape)  # pragma: no cover
 
         return jacobian_wrapper
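Conceptually, the contraction is the chain rule: the derivative with respect to
the QNode arguments is `tensordot(qjac, cjac)` over the shared tape-parameter
axis. A toy sketch with hand-picked numbers (not the branchy implementation
above, which also handles tuple-valued Jacobians, multiple measurements, and
shot vectors):

```python
import pennylane as qml

# qjac: d(measurement) / d(tape parameters), one entry per tape parameter.
qjac = qml.math.stack([0.1, -0.3])
# cjac: d(tape parameters) / d(QNode arguments), shape (tape params, args).
cjac = qml.math.stack([[1.0, 0.0], [0.0, 2.0]])
# Chain rule over the tape-parameter axis: result is d(measurement)/d(args).
jac = qml.math.tensordot(qjac, cjac, [[0], [0]])  # [0.1, -0.6]
```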
42 changes: 33 additions & 9 deletions pennylane/gradients/hadamard_gradient.py
@@ -15,9 +15,14 @@
 This module contains functions for computing the Hadamard-test gradient
 of a qubit-based quantum tape.
 """
+# pylint: disable=unused-argument
+from typing import Sequence, Callable
+from functools import partial
 import pennylane as qml
 import pennylane.numpy as np
 from pennylane.transforms.metric_tensor import _get_aux_wire
+from pennylane.transforms.core import transform
+from pennylane.gradients.gradient_transform import _contract_qjac_with_cjac
 from pennylane.transforms.tape_expand import expand_invalid_trainable_hadamard_gradient
 
 from .gradient_transform import (
@@ -27,17 +32,40 @@
     assert_no_variance,
     choose_grad_methods,
     gradient_analysis_and_validation,
-    gradient_transform,
     _no_trainable_grad,
 )
 
 
-def _hadamard_grad(
-    tape,
+def _expand_transform_hadamard(
+    tape: qml.tape.QuantumTape,
     argnum=None,
     aux_wire=None,
     device_wires=None,
-):
+) -> (Sequence[qml.tape.QuantumTape], Callable):
+    """Expand function to be applied before the Hadamard gradient."""
+    expanded_tape = expand_invalid_trainable_hadamard_gradient(tape)
+
+    def null_postprocessing(results):
+        """A postprocessing function returned by a transform that only converts the batch of results
+        into a result for a single ``QuantumTape``.
+        """
+        return results[0]
+
+    return [expanded_tape], null_postprocessing
+
+
+@partial(
+    transform,
+    expand_transform=_expand_transform_hadamard,
+    classical_cotransform=_contract_qjac_with_cjac,
+    final_transform=True,
+)
+def hadamard_grad(
+    tape: qml.tape.QuantumTape,
+    argnum=None,
+    aux_wire=None,
+    device_wires=None,
+) -> (Sequence[qml.tape.QuantumTape], Callable):
     r"""Transform a QNode to compute the Hadamard test gradient of all gates with respect to their inputs.
 
     Args:
@@ -174,6 +202,7 @@ def _hadamard_grad(
         The number of trainable parameters may increase due to the decomposition.
     """
+
     transform_name = "Hadamard test"
     assert_no_state_returns(tape.measurements, transform_name)
     assert_no_variance(tape.measurements, transform_name)
@@ -421,8 +450,3 @@ def _get_generators(trainable_op):
         coeffs = trainable_op.generator().coeffs
 
     return coeffs, generators
-
-
-hadamard_grad = gradient_transform(
-    _hadamard_grad, expand_fn=expand_invalid_trainable_hadamard_gradient
-)
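A hedged usage sketch of the now directly registered `hadamard_grad`
(illustrative circuit; the device needs one spare wire for the auxiliary qubit
of the Hadamard test):

```python
import pennylane as qml
from pennylane import numpy as np

# Wire 1 is unused by the circuit, so it can serve as the auxiliary wire.
dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = np.array(0.4, requires_grad=True)
grad = qml.gradients.hadamard_grad(circuit)(x)  # equals -sin(0.4) analytically
```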
43 changes: 37 additions & 6 deletions pennylane/gradients/parameter_shift.py
@@ -15,11 +15,17 @@
 This module contains functions for computing the parameter-shift gradient
 of a qubit-based quantum tape.
 """
-# pylint: disable=protected-access,too-many-arguments,too-many-statements
+# pylint: disable=protected-access,too-many-arguments,too-many-statements,unused-argument
+from typing import Sequence, Callable
+from functools import partial
 
 import numpy as np
 
 import pennylane as qml
 from pennylane.measurements import VarianceMP
+from pennylane.transforms.core import transform
+from pennylane.transforms.tape_expand import expand_invalid_trainable
+from pennylane.gradients.gradient_transform import _contract_qjac_with_cjac
 
 from .finite_difference import finite_diff
 from .general_shift_rules import (
@@ -35,7 +41,6 @@
     assert_multimeasure_not_broadcasted,
     choose_grad_methods,
     gradient_analysis_and_validation,
-    gradient_transform,
     _no_trainable_grad,
     reorder_grads,
 )
@@ -155,7 +160,6 @@ def _single_meas_grad(result, coeffs, unshifted_coeff, r0):
         )  # pragma: no cover
         # return the unshifted term, which is the only contribution
         return qml.math.array(unshifted_coeff * r0)
-
     result = qml.math.stack(result)
     coeffs = qml.math.convert_like(coeffs, result)
     g = qml.math.tensordot(result, coeffs, [[0], [0]])
@@ -719,16 +723,42 @@ def var_param_shift(tape, argnum, shifts=None, gradient_recipes=None, f0=None, b
     return gradient_tapes, processing_fn
 
 
-@gradient_transform
+def _expand_transform_param_shift(
+    tape: qml.tape.QuantumTape,
+    argnum=None,
+    shifts=None,
+    gradient_recipes=None,
+    fallback_fn=finite_diff,
+    f0=None,
+    broadcast=False,
+) -> (Sequence[qml.tape.QuantumTape], Callable):
+    """Expand function to be applied before parameter shift."""
+    expanded_tape = expand_invalid_trainable(tape)
+
+    def null_postprocessing(results):
+        """A postprocessing function returned by a transform that only converts the batch of results
+        into a result for a single ``QuantumTape``.
+        """
+        return results[0]
+
+    return [expanded_tape], null_postprocessing
+
+
+@partial(
+    transform,
+    expand_transform=_expand_transform_param_shift,
+    classical_cotransform=_contract_qjac_with_cjac,
+    final_transform=True,
+)
 def param_shift(
-    tape,
+    tape: qml.tape.QuantumTape,
     argnum=None,
     shifts=None,
     gradient_recipes=None,
     fallback_fn=finite_diff,
     f0=None,
     broadcast=False,
-):
+) -> (Sequence[qml.tape.QuantumTape], Callable):
     r"""Transform a QNode to compute the parameter-shift gradient of all gate
     parameters with respect to its inputs.
 
@@ -1004,6 +1034,7 @@ def param_shift(
     Note that ``broadcast=True`` requires additional memory by a factor of the largest
     batch_size of the created tapes.
     """
+
     transform_name = "parameter-shift rule"
     assert_no_state_returns(tape.measurements, transform_name)
     assert_multimeasure_not_broadcasted(tape.measurements, broadcast)
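Finally, a usage sketch for the updated `param_shift` (illustrative
two-parameter circuit; parameter-shift gradients are analytically exact):

```python
import pennylane as qml
from pennylane import numpy as np

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev)
def circuit(params):
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(1))

params = np.array([0.1, 0.2], requires_grad=True)
# Two shifted executions per trainable parameter.
grad = qml.gradients.param_shift(circuit)(params)
```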