Skip to content

Commit

Permalink
add vmap support for devices that do not support parameter broadcasti…
Browse files Browse the repository at this point in the history
…ng (#5286)

Fixes #5240 [sc-57137] [sc-57848] Fixes #5289

Basically when we set `vectorized=True` inside the `pure_callback` call,
we assumed that the device natively supports broadcasting. And then we
only tested with devices that did indeed natively support parameter
broadcasting.

This problem was made worse by the fact that our `vmap` tests included a
`Hamiltonian` expectation value, which caused up to skip many of the
test cases that we really should have been testing. So I got rid of the
`Hamiltonian` from the test so we could actually test more situations.

I also added more `lightning.qubit` tests to the test configuration.
That forced one or two other changes.

The major problem with `jax.vmap` is that it adds in a parameter
broadcasting dimension *after* we have already handled all of our
preprocessing and breaking up parameter broadcasting.

---------

Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: Nathan Killoran <[email protected]>
Co-authored-by: Matthew Silverman <[email protected]>
Co-authored-by: Astral Cai <[email protected]>
Co-authored-by: Mikhail Andrenkov <[email protected]>
Co-authored-by: Korbinian Kottmann <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
Co-authored-by: Isaac De Vlugt <[email protected]>
Co-authored-by: Isaac De Vlugt <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
Co-authored-by: soranjh <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
  • Loading branch information
13 people authored Apr 8, 2024
1 parent b1d460c commit 992fb00
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 92 deletions.
15 changes: 2 additions & 13 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,19 +759,8 @@ def shape(self, device):
>>> qs.shape(dev)
((4,), (), (4,))
"""

if isinstance(device, qml.devices.Device):
# MP.shape (called below) takes 2 arguments: `device` and `shots`.
# With the new device interface, shots are stored on tapes rather than the device
# TODO: refactor MP.shape to accept `wires` instead of device (not currently done
# because probs.shape uses device.cutoff)
shots = self.shots
else:
shots = (
Shots(device._raw_shot_sequence)
if device.shot_vector is not None
else Shots(device.shots)
)
shots = self.shots
# even with the legacy device interface, the shots on the tape will agree with the shots used by the device for the execution

if len(shots.shot_vector) > 1 and self.batch_size is not None:
raise NotImplementedError(
Expand Down
27 changes: 13 additions & 14 deletions pennylane/workflow/interfaces/jax_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,19 @@ def _execute_wrapper_inner(params, tapes, execute_fn, _, device, is_vjp=False) -

def pure_callback_wrapper(p):
new_tapes = _set_fn(tapes.vals, p)
res = tuple(_to_jax(execute_fn(new_tapes)))
# When executed under `jax.vmap` the `result_shapes_dtypes` will contain
# the shape without the vmap dimensions, while the function here will be
# executed with objects containing the vmap dimensions. So res[i].ndim
# will have an extra dimension for every `jax.vmap` wrapping this execution.
#
# The execute_fn will return an object with shape `(n_observables, batches)`
# but the default behaviour for `jax.pure_callback` is to add the extra
# dimension at the beginning, so `(batches, n_observables)`. So in here
# we detect with the heuristic above if we are executing under vmap and we
# swap the order in that case.
return jax.tree_map(lambda r, s: r.T if r.ndim > s.ndim else r, res, shape_dtype_structs)

out = jax.pure_callback(pure_callback_wrapper, shape_dtype_structs, params, vectorized=True)
return _to_jax(execute_fn(new_tapes))

if isinstance(device, qml.Device):
device_supports_vectorization = device.capabilities().get("supports_broadcasting")
else:
# first order way of determining native parameter broadcasting support
# will be inaccurate when inclusion of broadcast_expand depends on ExecutionConfig values (like adjoint)
device_supports_vectorization = (
qml.transforms.broadcast_expand not in device.preprocess()[0]
)
out = jax.pure_callback(
pure_callback_wrapper, shape_dtype_structs, params, vectorized=device_supports_vectorization
)
return out


Expand Down
Loading

0 comments on commit 992fb00

Please sign in to comment.