Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BugFix: adjoint metric tensor with jax #5271

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.35.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@
* The overwriting of the class names of `I`, `X`, `Y`, and `Z` no longer happens in the init after causing problems with datasets. Now happens globally.
[(#5252)](https://github.com/PennyLaneAI/pennylane/pull/5252)

* The `adjoint_metric_tensor` transform now works with `jax`.
[(#5271)](https://github.com/PennyLaneAI/pennylane/pull/5271)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand Down
31 changes: 26 additions & 5 deletions pennylane/gradients/adjoint_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,20 @@ def _reshape_real_imag(state, dim):
return qml.math.real(state), qml.math.imag(state)


def _group_operations(tape):
def _group_operations(tape, argnums):
"""Divide all operations of a tape into trainable operations and blocks
of untrainable operations after each trainable one."""

# Extract tape operations list
ops = tape.operations
# Find the indices of trainable operations in the tape operations list
trainables = np.where([qml.operation.is_trainable(op) for op in ops])[0]
if argnums is None:
argnums = tape.trainable_params
elif isinstance(argnums, int):
argnums = [argnums]
# pylint: disable=protected-access
trainable_par_info = [tape._par_info[i] for i in argnums]
trainables = [info["op_idx"] for info in trainable_par_info]
# Add the indices incremented by one to the trainable indices
split_ids = list(chain.from_iterable([idx, idx + 1] for idx in trainables))

Expand All @@ -54,13 +60,28 @@ def _group_operations(tape):
return trainable_operations, group_after_trainable_op


def _expand_trainable_multipar(
tape: qml.tape.QuantumTape,
) -> (Sequence[qml.tape.QuantumTape], Callable):
"""Expand trainable multi-parameter operations in a quantum tape."""

interface = qml.math.get_interface(*tape.get_parameters())
use_tape_argnum = interface == "jax"
expand_fn = qml.transforms.create_expand_trainable_multipar(
tape, use_tape_argnum=use_tape_argnum
)
return [expand_fn(tape)], lambda x: x[0]


@partial(
transform,
expand_transform=_expand_trainable_multipar,
classical_cotransform=_contract_metric_tensor_with_cjac,
is_informative=True,
use_argnum_in_expand=True,
)
def adjoint_metric_tensor(
tape: qml.tape.QuantumTape,
tape: qml.tape.QuantumTape, argnums=None
astralcai marked this conversation as resolved.
Show resolved Hide resolved
) -> (Sequence[qml.tape.QuantumTape], Callable):
r"""Implements the adjoint method outlined in
`Jones <https://arxiv.org/abs/2011.02991>`__ to compute the metric tensor.
Expand Down Expand Up @@ -147,11 +168,11 @@ def processing_fn(tapes):
wire_map = {w: i for i, w in enumerate(tape.wires)}
tapes, fn = qml.map_wires(tape, wire_map)
tape = fn(tapes)
tape = qml.transforms.expand_trainable_multipar(tape)

# Divide all operations of a tape into trainable operations and blocks
# of untrainable operations after each trainable one.
trainable_operations, group_after_trainable_op = _group_operations(tape)

trainable_operations, group_after_trainable_op = _group_operations(tape, argnums)

dim = 2**tape.num_wires
# generate and extract initial state
Expand Down
1 change: 1 addition & 0 deletions pennylane/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def circuit(x, y):
expand_trainable_multipar,
create_expand_fn,
create_decomp_expand_fn,
create_expand_trainable_multipar,
set_decomposition,
)
from .transpile import transpile
Expand Down
6 changes: 5 additions & 1 deletion pennylane/transforms/core/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def transform(
classical_cotransform=None,
is_informative=False,
final_transform=False,
):
use_argnum_in_expand=False,
): # pylint: disable=too-many-arguments
"""Generalizes a function that transforms tapes to work with additional circuit-like objects such as a
:class:`~.QNode`.

Expand Down Expand Up @@ -55,6 +56,8 @@ def transform(
of the transform program and the tapes or qnode aren't executed.
final_transform=False (bool): Whether or not the transform is terminal. If true the transform is queued at the end
of the transform program. ``is_informative`` supersedes ``final_transform``.
use_argnum_in_expand=False (bool): Whether or not to use ``argnum`` of the tape to determine trainable parameters
during the expansion transform process.

Returns:

Expand Down Expand Up @@ -187,4 +190,5 @@ def qnode_circuit(a):
classical_cotransform=classical_cotransform,
is_informative=is_informative,
final_transform=final_transform,
use_argnum_in_expand=use_argnum_in_expand,
)
10 changes: 9 additions & 1 deletion pennylane/transforms/core/transform_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
classical_cotransform=None,
is_informative=False,
final_transform=False,
use_argnum_in_expand=False,
): # pylint:disable=redefined-outer-name
self._transform = transform
self._expand_transform = expand_transform
Expand All @@ -79,6 +80,7 @@ def __init__(
# is_informative supersedes final_transform
self._final_transform = is_informative or final_transform
self._qnode_transform = self.default_qnode_transform
self._use_argnum_in_expand = use_argnum_in_expand
functools.update_wrapper(self, transform)

def __call__(self, *targs, **tkwargs): # pylint: disable=too-many-return-statements
Expand Down Expand Up @@ -208,7 +210,11 @@ def default_qnode_transform(self, qnode, targs, tkwargs):
qnode = copy.copy(qnode)

if self.expand_transform:
qnode.add_transform(TransformContainer(self._expand_transform, targs, tkwargs))
qnode.add_transform(
TransformContainer(
self._expand_transform, targs, tkwargs, use_argnum=self._use_argnum_in_expand
)
)
qnode.add_transform(
TransformContainer(
self._transform,
Expand Down Expand Up @@ -382,13 +388,15 @@ def __init__(
classical_cotransform=None,
is_informative=False,
final_transform=False,
use_argnum=False,
): # pylint:disable=redefined-outer-name,too-many-arguments
self._transform = transform
self._args = args or []
self._kwargs = kwargs or {}
self._classical_cotransform = classical_cotransform
self._is_informative = is_informative
self._final_transform = is_informative or final_transform
self._use_argnum = use_argnum

def __repr__(self):
return f"<{self._transform.__name__}({self._args}, {self._kwargs})>"
Expand Down
3 changes: 2 additions & 1 deletion pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ def _set_all_argnums(self, qnode, args, kwargs, argnums):
argnums_list = []
for index, transform in enumerate(self):
argnums = [0] if qnode.interface in ["jax", "jax-jit"] and argnums is None else argnums
if transform.classical_cotransform and argnums:
# pylint: disable=protected-access
if (transform._use_argnum or transform.classical_cotransform) and argnums:
params = qml.math.jax_argnums_to_tape_trainable(
qnode, argnums, TransformProgram(self[0:index]), args, kwargs
)
Expand Down
26 changes: 26 additions & 0 deletions pennylane/transforms/tape_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,32 @@ def expand_fn(tape, depth=depth, **kwargs):
docstring=_expand_trainable_multipar_doc,
)


def create_expand_trainable_multipar(tape, use_tape_argnum=False):
"""Creates the expand_trainable_multipar expansion transform with an option to include argnums."""

if not use_tape_argnum:
return expand_trainable_multipar

# pylint: disable=protected-access
trainable_par_info = [tape._par_info[i] for i in tape.trainable_params]
trainable_ops = [info["op"] for info in trainable_par_info]

@qml.BooleanFn
def _is_trainable(obj):
return obj in trainable_ops

return create_expand_fn(
depth=10,
stop_at=not_tape
| is_measurement
| has_nopar
| (~_is_trainable)
| (has_gen & ~gen_is_multi_term_hamiltonian),
docstring=_expand_trainable_multipar_doc,
)


_expand_nonunitary_gen_doc = """Expand out a tape so that all its parametrized
operations have a unitary generator.

Expand Down
3 changes: 1 addition & 2 deletions tests/gradients/core/test_adjoint_metric_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,6 @@ def circuit(*params):
assert qml.math.allclose(mt, expected)

@pytest.mark.jax
@pytest.mark.skip("JAX does not support forward pass execution of the metric tensor.")
@pytest.mark.parametrize("ansatz, params", list(zip(fubini_ansatze, fubini_params)))
def test_correct_output_qnode_jax(self, ansatz, params):
"""Test that the output is correct when using JAX and
Expand All @@ -418,7 +417,7 @@ def circuit(*params):
ansatz(*params, dev.wires)
return qml.expval(qml.PauliZ(0))

mt = qml.adjoint_metric_tensor(circuit)(*j_params)
mt = qml.adjoint_metric_tensor(circuit, argnums=list(range(len(j_params))))(*j_params)

if isinstance(mt, tuple):
assert all(qml.math.allclose(_mt, _exp) for _mt, _exp in zip(mt, expected))
Expand Down
Loading