BugFix: adjoint metric tensor with jax (#5271)
**Context:**
The `adjoint_metric_tensor` transform does not work with JAX parameters,
because JAX variables are not considered trainable until they become
tracers.

**Description of the Change:**
1. Add a custom expand transform to `adjoint_metric_tensor` that expands
trainable multi-parameter operations based on `argnums` (see the usage
sketch below).
2. Add an optional `use_argnum_in_expand` argument to the `transform`
decorator that controls whether `jax_argnums_to_tape_trainable` is applied
to the parameters before the expand transform runs.
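
For example, after this change the transform can be applied to a JAX QNode by passing `argnums`. A usage sketch based on the updated test in `tests/gradients/core/test_adjoint_metric_tensor.py` (the device, circuit, and parameter values here are illustrative):

```python
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(weights):
    qml.RX(weights[0], wires=0)
    qml.RY(weights[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

weights = jnp.array([0.2, 0.9])

# `argnums` marks which QNode arguments are trainable, so the expand
# transform can identify the corresponding tape parameters under JAX.
mt = qml.adjoint_metric_tensor(circuit, argnums=[0])(weights)
```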

**Benefits:**
BugFix

**Possible Drawbacks:**
Adding yet another keyword argument to `transform` may make code look
messy.

**Related GitHub Issues:**
#5197

**Related Shortcut Stories:**
[sc-56734]
astralcai authored Feb 28, 2024
1 parent c1f3997 commit d4dacf9
Showing 8 changed files with 67 additions and 10 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.35.0.md
@@ -655,6 +655,9 @@
 * The overwriting of the class names of `I`, `X`, `Y`, and `Z` no longer happens in the init after causing problems with datasets. Now happens globally.
   [(#5252)](https://github.com/PennyLaneAI/pennylane/pull/5252)

+* The `adjoint_metric_tensor` transform now works with `jax`.
+  [(#5271)](https://github.com/PennyLaneAI/pennylane/pull/5271)
+
 <h3>Contributors ✍️</h3>

 This release contains contributions from (in alphabetical order):
25 changes: 20 additions & 5 deletions pennylane/gradients/adjoint_metric_tensor.py
@@ -38,7 +38,9 @@ def _group_operations(tape):
     # Extract tape operations list
     ops = tape.operations
     # Find the indices of trainable operations in the tape operations list
-    trainables = np.where([qml.operation.is_trainable(op) for op in ops])[0]
+    # pylint: disable=protected-access
+    trainable_par_info = [tape._par_info[i] for i in tape.trainable_params]
+    trainables = [info["op_idx"] for info in trainable_par_info]
     # Add the indices incremented by one to the trainable indices
     split_ids = list(chain.from_iterable([idx, idx + 1] for idx in trainables))

@@ -54,14 +54,27 @@
     return trainable_operations, group_after_trainable_op


+def _expand_trainable_multipar(
+    tape: qml.tape.QuantumTape,
+) -> (Sequence[qml.tape.QuantumTape], Callable):
+    """Expand trainable multi-parameter operations in a quantum tape."""
+
+    interface = qml.math.get_interface(*tape.get_parameters())
+    use_tape_argnum = interface == "jax"
+    expand_fn = qml.transforms.create_expand_trainable_multipar(
+        tape, use_tape_argnum=use_tape_argnum
+    )
+    return [expand_fn(tape)], lambda x: x[0]
+
+
 @partial(
     transform,
+    expand_transform=_expand_trainable_multipar,
     classical_cotransform=_contract_metric_tensor_with_cjac,
     is_informative=True,
+    use_argnum_in_expand=True,
 )
-def adjoint_metric_tensor(
-    tape: qml.tape.QuantumTape,
-) -> (Sequence[qml.tape.QuantumTape], Callable):
+def adjoint_metric_tensor(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
     r"""Implements the adjoint method outlined in
     `Jones <https://arxiv.org/abs/2011.02991>`__ to compute the metric tensor.
@@ -147,10 +162,10 @@ def processing_fn(tapes):
         wire_map = {w: i for i, w in enumerate(tape.wires)}
         tapes, fn = qml.map_wires(tape, wire_map)
         tape = fn(tapes)
-        tape = qml.transforms.expand_trainable_multipar(tape)

         # Divide all operations of a tape into trainable operations and blocks
         # of untrainable operations after each trainable one.
+
         trainable_operations, group_after_trainable_op = _group_operations(tape)

         dim = 2**tape.num_wires
1 change: 1 addition & 0 deletions pennylane/transforms/__init__.py
@@ -330,6 +330,7 @@ def circuit(x, y):
     expand_trainable_multipar,
     create_expand_fn,
     create_decomp_expand_fn,
+    create_expand_trainable_multipar,
     set_decomposition,
 )
 from .transpile import transpile
6 changes: 5 additions & 1 deletion pennylane/transforms/core/transform.py
@@ -25,7 +25,8 @@ def transform(
     classical_cotransform=None,
     is_informative=False,
     final_transform=False,
-):
+    use_argnum_in_expand=False,
+):  # pylint: disable=too-many-arguments
     """Generalizes a function that transforms tapes to work with additional circuit-like objects such as a
     :class:`~.QNode`.

@@ -55,6 +56,8 @@
             of the transform program and the tapes or qnode aren't executed.
         final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end
             of the transform program. ``is_informative`` supersedes ``final_transform``.
+        use_argnum_in_expand=False (bool): Whether or not to use ``argnum`` of the tape to determine trainable parameters
+            during the expansion transform process.

     Returns:

@@ -187,4 +190,5 @@ def qnode_circuit(a):
         classical_cotransform=classical_cotransform,
         is_informative=is_informative,
         final_transform=final_transform,
+        use_argnum_in_expand=use_argnum_in_expand,
     )
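
For reference, a transform author opts in at declaration time. A minimal sketch mirroring how `adjoint_metric_tensor` uses the new keyword above (the `my_expand`/`my_transform` placeholders are illustrative, not part of this commit):

```python
from functools import partial
from typing import Callable, Sequence

import pennylane as qml

def my_expand(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
    """Pre-processing step; with ``use_argnum_in_expand=True`` it runs after
    ``argnum``-selected parameters have been promoted to trainable ones."""
    return [tape], lambda results: results[0]

@partial(qml.transform, expand_transform=my_expand, use_argnum_in_expand=True)
def my_transform(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable):
    """A no-op transform, included only to illustrate the keyword."""
    return [tape], lambda results: results[0]
```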
10 changes: 9 additions & 1 deletion pennylane/transforms/core/transform_dispatcher.py
@@ -71,6 +71,7 @@ def __init__(
         classical_cotransform=None,
         is_informative=False,
         final_transform=False,
+        use_argnum_in_expand=False,
     ):  # pylint:disable=redefined-outer-name
         self._transform = transform
         self._expand_transform = expand_transform

@@ -79,6 +80,7 @@
         # is_informative supersedes final_transform
         self._final_transform = is_informative or final_transform
         self._qnode_transform = self.default_qnode_transform
+        self._use_argnum_in_expand = use_argnum_in_expand
         functools.update_wrapper(self, transform)

     def __call__(self, *targs, **tkwargs):  # pylint: disable=too-many-return-statements

@@ -208,7 +210,11 @@ def default_qnode_transform(self, qnode, targs, tkwargs):
         qnode = copy.copy(qnode)

         if self.expand_transform:
-            qnode.add_transform(TransformContainer(self._expand_transform, targs, tkwargs))
+            qnode.add_transform(
+                TransformContainer(
+                    self._expand_transform, targs, tkwargs, use_argnum=self._use_argnum_in_expand
+                )
+            )
         qnode.add_transform(
             TransformContainer(
                 self._transform,

@@ -382,13 +388,15 @@ def __init__(
         classical_cotransform=None,
         is_informative=False,
         final_transform=False,
+        use_argnum=False,
     ):  # pylint:disable=redefined-outer-name,too-many-arguments
         self._transform = transform
         self._args = args or []
         self._kwargs = kwargs or {}
         self._classical_cotransform = classical_cotransform
         self._is_informative = is_informative
         self._final_transform = is_informative or final_transform
+        self._use_argnum = use_argnum

     def __repr__(self):
         return f"<{self._transform.__name__}({self._args}, {self._kwargs})>"
3 changes: 2 additions & 1 deletion pennylane/transforms/core/transform_program.py
@@ -440,7 +440,8 @@ def _set_all_argnums(self, qnode, args, kwargs, argnums):
         argnums_list = []
         for index, transform in enumerate(self):
             argnums = [0] if qnode.interface in ["jax", "jax-jit"] and argnums is None else argnums
-            if transform.classical_cotransform and argnums:
+            # pylint: disable=protected-access
+            if (transform._use_argnum or transform.classical_cotransform) and argnums:
                 params = qml.math.jax_argnums_to_tape_trainable(
                     qnode, argnums, TransformProgram(self[0:index]), args, kwargs
                 )
26 changes: 26 additions & 0 deletions pennylane/transforms/tape_expand.py
@@ -163,6 +163,32 @@ def expand_fn(tape, depth=depth, **kwargs):
     docstring=_expand_trainable_multipar_doc,
 )

+
+def create_expand_trainable_multipar(tape, use_tape_argnum=False):
+    """Creates the expand_trainable_multipar expansion transform with an option to include argnums."""
+
+    if not use_tape_argnum:
+        return expand_trainable_multipar
+
+    # pylint: disable=protected-access
+    trainable_par_info = [tape._par_info[i] for i in tape.trainable_params]
+    trainable_ops = [info["op"] for info in trainable_par_info]
+
+    @qml.BooleanFn
+    def _is_trainable(obj):
+        return obj in trainable_ops
+
+    return create_expand_fn(
+        depth=10,
+        stop_at=not_tape
+        | is_measurement
+        | has_nopar
+        | (~_is_trainable)
+        | (has_gen & ~gen_is_multi_term_hamiltonian),
+        docstring=_expand_trainable_multipar_doc,
+    )
+
+
 _expand_nonunitary_gen_doc = """Expand out a tape so that all its parametrized
 operations have a unitary generator.
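
The `stop_at` criterion above composes `qml.BooleanFn` predicates: expansion halts at anything that is not a tape, is a measurement, has no parameters, is not one of the tape's trainable operations, or has a generator that is not a multi-term Hamiltonian. A small sketch of how such predicates combine (the predicates below are illustrative, not the ones used in the transform):

```python
import pennylane as qml

@qml.BooleanFn
def has_no_params(obj):
    return getattr(obj, "num_params", 0) == 0

@qml.BooleanFn
def is_rx(obj):
    return getattr(obj, "name", "") == "RX"

# BooleanFn instances compose with &, | and ~ into new BooleanFns.
stop_at = has_no_params | is_rx

print(stop_at(qml.RX(0.1, wires=0)))             # True: it is an RX
print(stop_at(qml.Rot(0.1, 0.2, 0.3, wires=0)))  # False: parametrized, not RX
```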
3 changes: 1 addition & 2 deletions tests/gradients/core/test_adjoint_metric_tensor.py
@@ -400,7 +400,6 @@ def circuit(*params):
         assert qml.math.allclose(mt, expected)

     @pytest.mark.jax
-    @pytest.mark.skip("JAX does not support forward pass execution of the metric tensor.")
     @pytest.mark.parametrize("ansatz, params", list(zip(fubini_ansatze, fubini_params)))
     def test_correct_output_qnode_jax(self, ansatz, params):
         """Test that the output is correct when using JAX and

@@ -418,7 +417,7 @@ def circuit(*params):
             ansatz(*params, dev.wires)
             return qml.expval(qml.PauliZ(0))

-        mt = qml.adjoint_metric_tensor(circuit)(*j_params)
+        mt = qml.adjoint_metric_tensor(circuit, argnums=list(range(len(j_params))))(*j_params)

         if isinstance(mt, tuple):
             assert all(qml.math.allclose(_mt, _exp) for _mt, _exp in zip(mt, expected))
