From d4dacf951e1f7a1c51074dd50f96bd62206584c9 Mon Sep 17 00:00:00 2001 From: Astral Cai Date: Wed, 28 Feb 2024 17:21:32 -0500 Subject: [PATCH] BugFix: adjoint metric tensor with jax (#5271) **Context:** The `adjoint_metric_tensor` transform does not work with jax variables because jax variables are not considered trainable parameters until they become tracers. **Description of the Change:** 1. Add a custom expand transform to `adjoint_metric_tensor` that expands trainable parameters based on argnums. 2. Add an optional `use_argnum_in_expand` argument to transform programs to determine whether or not to perform `jax_argnums_to_tape_trainable` on the parameters. **Benefits:** BugFix **Possible Drawbacks:** Adding yet another keyword argument to `transform` may make code look messy. **Related GitHub Issues:** https://github.com/PennyLaneAI/pennylane/issues/5197 **Related Shortcut Stories:** [sc-56734] --- doc/releases/changelog-0.35.0.md | 3 +++ pennylane/gradients/adjoint_metric_tensor.py | 25 ++++++++++++++---- pennylane/transforms/__init__.py | 1 + pennylane/transforms/core/transform.py | 6 ++++- .../transforms/core/transform_dispatcher.py | 10 ++++++- .../transforms/core/transform_program.py | 3 ++- pennylane/transforms/tape_expand.py | 26 +++++++++++++++++++ .../core/test_adjoint_metric_tensor.py | 3 +-- 8 files changed, 67 insertions(+), 10 deletions(-) diff --git a/doc/releases/changelog-0.35.0.md b/doc/releases/changelog-0.35.0.md index 09b80093664..6df308c90ba 100644 --- a/doc/releases/changelog-0.35.0.md +++ b/doc/releases/changelog-0.35.0.md @@ -655,6 +655,9 @@ * The overwriting of the class names of `I`, `X`, `Y`, and `Z` no longer happens in the init after causing problems with datasets. Now happens globally. [(#5252)](https://github.com/PennyLaneAI/pennylane/pull/5252) +* The `adjoint_metric_tensor` transform now works with `jax`. + [(#5271)](https://github.com/PennyLaneAI/pennylane/pull/5271) +

Contributors ✍️

This release contains contributions from (in alphabetical order): diff --git a/pennylane/gradients/adjoint_metric_tensor.py b/pennylane/gradients/adjoint_metric_tensor.py index 47ebb766bde..db21ec27b41 100644 --- a/pennylane/gradients/adjoint_metric_tensor.py +++ b/pennylane/gradients/adjoint_metric_tensor.py @@ -38,7 +38,9 @@ def _group_operations(tape): # Extract tape operations list ops = tape.operations # Find the indices of trainable operations in the tape operations list - trainables = np.where([qml.operation.is_trainable(op) for op in ops])[0] + # pylint: disable=protected-access + trainable_par_info = [tape._par_info[i] for i in tape.trainable_params] + trainables = [info["op_idx"] for info in trainable_par_info] # Add the indices incremented by one to the trainable indices split_ids = list(chain.from_iterable([idx, idx + 1] for idx in trainables)) @@ -54,14 +56,27 @@ def _group_operations(tape): return trainable_operations, group_after_trainable_op +def _expand_trainable_multipar( + tape: qml.tape.QuantumTape, +) -> (Sequence[qml.tape.QuantumTape], Callable): + """Expand trainable multi-parameter operations in a quantum tape.""" + + interface = qml.math.get_interface(*tape.get_parameters()) + use_tape_argnum = interface == "jax" + expand_fn = qml.transforms.create_expand_trainable_multipar( + tape, use_tape_argnum=use_tape_argnum + ) + return [expand_fn(tape)], lambda x: x[0] + + @partial( transform, + expand_transform=_expand_trainable_multipar, classical_cotransform=_contract_metric_tensor_with_cjac, is_informative=True, + use_argnum_in_expand=True, ) -def adjoint_metric_tensor( - tape: qml.tape.QuantumTape, -) -> (Sequence[qml.tape.QuantumTape], Callable): +def adjoint_metric_tensor(tape: qml.tape.QuantumTape) -> (Sequence[qml.tape.QuantumTape], Callable): r"""Implements the adjoint method outlined in `Jones `__ to compute the metric tensor. @@ -147,10 +162,10 @@ def processing_fn(tapes): wire_map = {w: i for i, w in enumerate(tape.wires)} tapes, fn = qml.map_wires(tape, wire_map) tape = fn(tapes) - tape = qml.transforms.expand_trainable_multipar(tape) # Divide all operations of a tape into trainable operations and blocks # of untrainable operations after each trainable one. + trainable_operations, group_after_trainable_op = _group_operations(tape) dim = 2**tape.num_wires diff --git a/pennylane/transforms/__init__.py b/pennylane/transforms/__init__.py index bb1f378610a..653244aa307 100644 --- a/pennylane/transforms/__init__.py +++ b/pennylane/transforms/__init__.py @@ -330,6 +330,7 @@ def circuit(x, y): expand_trainable_multipar, create_expand_fn, create_decomp_expand_fn, + create_expand_trainable_multipar, set_decomposition, ) from .transpile import transpile diff --git a/pennylane/transforms/core/transform.py b/pennylane/transforms/core/transform.py index e887f9178ef..35c19f96f70 100644 --- a/pennylane/transforms/core/transform.py +++ b/pennylane/transforms/core/transform.py @@ -25,7 +25,8 @@ def transform( classical_cotransform=None, is_informative=False, final_transform=False, -): + use_argnum_in_expand=False, +): # pylint: disable=too-many-arguments """Generalizes a function that transforms tapes to work with additional circuit-like objects such as a :class:`~.QNode`. @@ -55,6 +56,8 @@ def transform( of the transform program and the tapes or qnode aren't executed. final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end of the transform program. ``is_informative`` supersedes ``final_transform``. + use_argnum_in_expand=False (bool): Whether or not to use ``argnum`` of the tape to determine trainable parameters + during the expansion transform process. Returns: @@ -187,4 +190,5 @@ def qnode_circuit(a): classical_cotransform=classical_cotransform, is_informative=is_informative, final_transform=final_transform, + use_argnum_in_expand=use_argnum_in_expand, ) diff --git a/pennylane/transforms/core/transform_dispatcher.py b/pennylane/transforms/core/transform_dispatcher.py index f9b6faeca96..873449d6f8d 100644 --- a/pennylane/transforms/core/transform_dispatcher.py +++ b/pennylane/transforms/core/transform_dispatcher.py @@ -71,6 +71,7 @@ def __init__( classical_cotransform=None, is_informative=False, final_transform=False, + use_argnum_in_expand=False, ): # pylint:disable=redefined-outer-name self._transform = transform self._expand_transform = expand_transform @@ -79,6 +80,7 @@ def __init__( # is_informative supersedes final_transform self._final_transform = is_informative or final_transform self._qnode_transform = self.default_qnode_transform + self._use_argnum_in_expand = use_argnum_in_expand functools.update_wrapper(self, transform) def __call__(self, *targs, **tkwargs): # pylint: disable=too-many-return-statements @@ -208,7 +210,11 @@ def default_qnode_transform(self, qnode, targs, tkwargs): qnode = copy.copy(qnode) if self.expand_transform: - qnode.add_transform(TransformContainer(self._expand_transform, targs, tkwargs)) + qnode.add_transform( + TransformContainer( + self._expand_transform, targs, tkwargs, use_argnum=self._use_argnum_in_expand + ) + ) qnode.add_transform( TransformContainer( self._transform, @@ -382,6 +388,7 @@ def __init__( classical_cotransform=None, is_informative=False, final_transform=False, + use_argnum=False, ): # pylint:disable=redefined-outer-name,too-many-arguments self._transform = transform self._args = args or [] @@ -389,6 +396,7 @@ def __init__( self._classical_cotransform = classical_cotransform self._is_informative = is_informative self._final_transform = is_informative or final_transform + self._use_argnum = use_argnum def __repr__(self): return f"<{self._transform.__name__}({self._args}, {self._kwargs})>" diff --git a/pennylane/transforms/core/transform_program.py b/pennylane/transforms/core/transform_program.py index 857cdc3b8ae..58f5045cf5d 100644 --- a/pennylane/transforms/core/transform_program.py +++ b/pennylane/transforms/core/transform_program.py @@ -440,7 +440,8 @@ def _set_all_argnums(self, qnode, args, kwargs, argnums): argnums_list = [] for index, transform in enumerate(self): argnums = [0] if qnode.interface in ["jax", "jax-jit"] and argnums is None else argnums - if transform.classical_cotransform and argnums: + # pylint: disable=protected-access + if (transform._use_argnum or transform.classical_cotransform) and argnums: params = qml.math.jax_argnums_to_tape_trainable( qnode, argnums, TransformProgram(self[0:index]), args, kwargs ) diff --git a/pennylane/transforms/tape_expand.py b/pennylane/transforms/tape_expand.py index adc0ad065a0..0ee01262071 100644 --- a/pennylane/transforms/tape_expand.py +++ b/pennylane/transforms/tape_expand.py @@ -163,6 +163,32 @@ def expand_fn(tape, depth=depth, **kwargs): docstring=_expand_trainable_multipar_doc, ) + +def create_expand_trainable_multipar(tape, use_tape_argnum=False): + """Creates the expand_trainable_multipar expansion transform with an option to include argnums.""" + + if not use_tape_argnum: + return expand_trainable_multipar + + # pylint: disable=protected-access + trainable_par_info = [tape._par_info[i] for i in tape.trainable_params] + trainable_ops = [info["op"] for info in trainable_par_info] + + @qml.BooleanFn + def _is_trainable(obj): + return obj in trainable_ops + + return create_expand_fn( + depth=10, + stop_at=not_tape + | is_measurement + | has_nopar + | (~_is_trainable) + | (has_gen & ~gen_is_multi_term_hamiltonian), + docstring=_expand_trainable_multipar_doc, + ) + + _expand_nonunitary_gen_doc = """Expand out a tape so that all its parametrized operations have a unitary generator. diff --git a/tests/gradients/core/test_adjoint_metric_tensor.py b/tests/gradients/core/test_adjoint_metric_tensor.py index 50f6513ea6e..c7ce5b689f4 100644 --- a/tests/gradients/core/test_adjoint_metric_tensor.py +++ b/tests/gradients/core/test_adjoint_metric_tensor.py @@ -400,7 +400,6 @@ def circuit(*params): assert qml.math.allclose(mt, expected) @pytest.mark.jax - @pytest.mark.skip("JAX does not support forward pass execution of the metric tensor.") @pytest.mark.parametrize("ansatz, params", list(zip(fubini_ansatze, fubini_params))) def test_correct_output_qnode_jax(self, ansatz, params): """Test that the output is correct when using JAX and @@ -418,7 +417,7 @@ def circuit(*params): ansatz(*params, dev.wires) return qml.expval(qml.PauliZ(0)) - mt = qml.adjoint_metric_tensor(circuit)(*j_params) + mt = qml.adjoint_metric_tensor(circuit, argnums=list(range(len(j_params))))(*j_params) if isinstance(mt, tuple): assert all(qml.math.allclose(_mt, _exp) for _mt, _exp in zip(mt, expected))