Skip to content

Commit

Permalink
Allow pulses to be applied to broadcasted states (#4863)
Browse files Browse the repository at this point in the history
**Context:**
Currently, we are raising an error in `apply_operation` of the new qubit
device whenever a broadcasted state is supplied together with a
non-broadcasted `ParametrizedEvolution` operation.
We actually can support this scenario via the evolved matrix code, and
via the evolved state code if `return_intermediate` is `False`.

**Description of the Change:**
Evolve batched states instead of raising an error.
If the pulse has `return_intermediate=True`, introducing a batch
dimension itself, we always use the matrix evolution method, modifying
the branching slightly.

**Benefits:**
More code is supported, in particular the stochastic pulse
parameter-shift rule in conjunction with multiple pulses within a QNode.

**Possible Drawbacks:**
N/A

**Related GitHub Issues:**
Fixes #4859

---------

Co-authored-by: Christina Lee <[email protected]>
  • Loading branch information
dwierichs and albi3ro authored Nov 22, 2023
1 parent 34aba12 commit c025b98
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 38 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@

<h3>Improvements 🛠</h3>

* `default.qubit` now can evolve already batched states with `ParametrizedEvolution`
[(#4863)](https://github.com/PennyLaneAI/pennylane/pull/4863)

* `default.qubit` no longer uses a dense matrix for `MultiControlledX` for more than 8 operation wires.
[(#4673)](https://github.com/PennyLaneAI/pennylane/pull/4673)

Expand Down
39 changes: 25 additions & 14 deletions pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,23 +334,26 @@ def apply_parametrized_evolution(
):
"""Apply ParametrizedEvolution by evolving the state rather than the operator matrix
if we are operating on more than half of the subsystem"""
if is_state_batched and op.batch_size is None:
raise RuntimeError(
"ParameterizedEvolution does not support standard broadcasting, but received a batched state"
)

# shape(state) is static (not a tracer), we can use an if statement
num_wires = len(qml.math.shape(state))
num_wires = len(qml.math.shape(state)) - is_state_batched
state = qml.math.cast(state, complex)
if 2 * len(op.wires) > num_wires and not op.hyperparameters["complementary"]:
# the subsystem operated is more than half of the system based on the state vector --> evolve state
return _evolve_state_vector_under_parametrized_evolution(op, state, num_wires)
# otherwise --> evolve matrix
return _apply_operation_default(op, state, is_state_batched, debugger)
if (
2 * len(op.wires) <= num_wires
or op.hyperparameters["complementary"]
or (is_state_batched and op.hyperparameters["return_intermediate"])
):
# the subsystem operated on is half as big as the total system, or less
# or we want complementary time evolution
# or both the state and the operation have a batch dimension
# --> evolve matrix
return _apply_operation_default(op, state, is_state_batched, debugger)
# otherwise --> evolve state
return _evolve_state_vector_under_parametrized_evolution(op, state, num_wires, is_state_batched)


def _evolve_state_vector_under_parametrized_evolution(
operation: qml.pulse.ParametrizedEvolution, state, num_wires
operation: qml.pulse.ParametrizedEvolution, state, num_wires, is_state_batched
):
"""Uses an odeint solver to compute the evolution of the input ``state`` under the given
``ParametrizedEvolution`` operation.
Expand Down Expand Up @@ -385,7 +388,13 @@ def _evolve_state_vector_under_parametrized_evolution(
"You can update these values by calling the ParametrizedEvolution class: EV(params, t)."
)

state = state.flatten()
if is_state_batched:
batch_dim = state.shape[0]
state = qml.math.moveaxis(state.reshape((batch_dim, 2**num_wires)), 1, 0)
out_shape = [2] * num_wires + [batch_dim] # this shape is before moving the batch_dim back
else:
state = state.flatten()
out_shape = [2] * num_wires

with jax.ensure_compile_time_eval():
H_jax = ParametrizedHamiltonianPytree.from_hamiltonian( # pragma: no cover
Expand All @@ -399,7 +408,9 @@ def fun(y, t):
return (-1j * H_jax(operation.data, t=t)) @ y

result = odeint(fun, state, operation.t, **operation.odeint_kwargs)
out_shape = [2] * num_wires
if operation.hyperparameters["return_intermediate"]:
return qml.math.reshape(result, [-1] + out_shape)
return qml.math.reshape(result[-1], out_shape)
result = qml.math.reshape(result[-1], out_shape)
if is_state_batched:
return qml.math.moveaxis(result, -1, 0)
return result
55 changes: 33 additions & 22 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,11 @@ def test_globalphase(self, method, wire, ml_framework):
assert qml.math.allclose(shift * initial_state, new_state_no_wire)


# pylint:disable = unused-argument
def time_independent_hamiltonian():
"""Create a time-independent Hamiltonian on two qubits."""
ops = [qml.PauliX(0), qml.PauliZ(1), qml.PauliY(0), qml.PauliX(1)]

def f1(params, t):
return params # constant

def f2(params, t):
return params # constant

coeffs = [f1, f2, 4, 9]
coeffs = [qml.pulse.constant, qml.pulse.constant, 0.4, 0.9]

return qml.pulse.ParametrizedHamiltonian(coeffs, ops)

Expand All @@ -275,10 +268,10 @@ def f2(params, t):


@pytest.mark.jax
class TestApplyParameterizedEvolution:
class TestApplyParametrizedEvolution:
@pytest.mark.parametrize("method", methods)
def test_parameterized_evolution_time_independent(self, method):
"""Test that applying a ParameterizedEvolution gives the expected state
"""Test that applying a ParametrizedEvolution gives the expected state
for a time-independent hamiltonian"""

import jax.numpy as jnp
Expand All @@ -292,7 +285,7 @@ def test_parameterized_evolution_time_independent(self, method):

H = time_independent_hamiltonian()
params = jnp.array([1.0, 2.0])
t = 4
t = 0.4

op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t)

Expand All @@ -306,7 +299,7 @@ def test_parameterized_evolution_time_independent(self, method):

@pytest.mark.parametrize("method", methods)
def test_parameterized_evolution_time_dependent(self, method):
"""Test that applying a ParameterizedEvolution gives the expected state
"""Test that applying a ParametrizedEvolution gives the expected state
for a time dependent Hamiltonian"""

import jax
Expand All @@ -321,7 +314,7 @@ def test_parameterized_evolution_time_dependent(self, method):

H = time_dependent_hamiltonian()
params = jnp.array([1.0, 2.0])
t = 4
t = 0.4

op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t)

Expand All @@ -340,7 +333,7 @@ def generator(params):
assert np.allclose(new_state, new_state_expected, atol=0.002)

def test_large_state_small_matrix_evolves_matrix(self, mocker):
"""Test that applying a ParameterizedEvolution operating on less
"""Test that applying a ParametrizedEvolution operating on less
than half of the wires in the state uses the default function to evolve
the matrix"""

Expand All @@ -357,7 +350,7 @@ def test_large_state_small_matrix_evolves_matrix(self, mocker):

H = time_independent_hamiltonian()
params = jnp.array([1.0, 2.0])
t = 4
t = 0.4

op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t)

Expand All @@ -374,7 +367,7 @@ def test_large_state_small_matrix_evolves_matrix(self, mocker):
assert spy.call_count == 1

def test_small_evolves_state(self, mocker):
"""Test that applying a ParameterizedEvolution operating on less
"""Test that applying a ParametrizedEvolution operating on less
than half of the wires in the state uses the default function to evolve
the matrix"""

Expand Down Expand Up @@ -433,7 +426,7 @@ def test_small_evolves_state(self, mocker):

H = time_independent_hamiltonian()
params = jnp.array([1.0, 2.0])
t = 4
t = 0.4

op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t)

Expand Down Expand Up @@ -480,23 +473,41 @@ def test_parametrized_evolution_state_vector_return_intermediate(self, mocker):
assert spy.call_count == 2
assert qml.math.allclose(state_ev, state_rx, atol=1e-6)

def test_batched_state_raises_an_error(self):
"""Test that if is_state_batche=True, an error is raised"""
@pytest.mark.parametrize("num_state_wires", [2, 4])
def test_with_batched_state(self, num_state_wires, mocker):
"""Test that a ParametrizedEvolution is applied correctly to a batched state.
Note that the branching logic is different for batched input states, because
evolving the state vector does not support batching of the state. Instead,
the evolved matrix is used always."""
spy_einsum = mocker.spy(qml.math, "einsum")
H = time_independent_hamiltonian()
params = np.array([1.0, 2.0])
t = 4
t = 0.1

op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t)

initial_state = np.array(
[
[[0.81677345 + 0.0j, 0.0 + 0.0j], [0.0 - 0.57695852j, 0.0 + 0.0j]],
[[0.33894597 + 0.0j, 0.0 + 0.0j], [0.0 - 0.94080584j, 0.0 + 0.0j]],
[[0.33894597 + 0.0j, 0.0 + 0.0j], [0.0 - 0.94080584j, 0.0 + 0.0j]],
]
)
if num_state_wires == 4:
zero_state_two_wires = np.eye(4)[0].reshape((2, 2))
initial_state = np.tensordot(initial_state, zero_state_two_wires, axes=0)

with pytest.raises(RuntimeError, match="does not support standard broadcasting"):
_ = apply_operation(op, initial_state, is_state_batched=True)
true_mat = qml.math.expm(-1j * qml.matrix(H(params, t=t)) * t)
U = qml.QubitUnitary(U=true_mat, wires=[0, 1])

new_state = apply_operation(op, initial_state, is_state_batched=True)
new_state_expected = apply_operation(U, initial_state, is_state_batched=True)
assert np.allclose(new_state, new_state_expected, atol=0.002)

if num_state_wires == 4:
assert spy_einsum.call_count == 2
else:
assert spy_einsum.call_count == 1


@pytest.mark.parametrize("ml_framework", ml_frameworks_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def cost(x, cache):

def atol_for_shots(shots):
"""Return higher tolerance if finite shots."""
return 2e-2 if shots else 1e-6
return 3e-2 if shots else 1e-6


@pytest.mark.parametrize("execute_kwargs, shots, device", test_matrix)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulse/test_parametrized_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def U(params):

@pytest.mark.jax
def test_map_wires():
"""Test that map wires returns a new ParameterizedEvolution, with wires updated on
"""Test that map wires returns a new ParametrizedEvolution, with wires updated on
both the operator and the corresponding Hamiltonian"""

def f1(p, t):
Expand Down

0 comments on commit c025b98

Please sign in to comment.