From c025b9822ab47f995647f143e8400cd1f4e352b3 Mon Sep 17 00:00:00 2001 From: David Wierichs Date: Wed, 22 Nov 2023 08:38:35 +0100 Subject: [PATCH] Allow pulses to be applied to broadcasted states (#4863) **Context:** Currently, we are raising an error in `apply_operation` of the new qubit device whenever a broadcasted state is supplied together with a non-broadcasted `ParametrizedEvolution` operation. We actually can support this scenario via the evolved matrix code, and via the evolved state code if `return_intermediate` is `False`. **Description of the Change:** Evolve batched states instead of raising an error. If the pulse has `return_intermediate=True`, introducing a batch dimension itself, we always use the matrix evolution method, modifying the branching slightly. **Benefits:** More code is supported, in particular the stochastic pulse parameter-shift rule in conjunction with multiple pulses within a QNode. **Possible Drawbacks:** N/A **Related GitHub Issues:** Fixes #4859 --------- Co-authored-by: Christina Lee --- doc/releases/changelog-dev.md | 3 + pennylane/devices/qubit/apply_operation.py | 39 ++++++++----- tests/devices/qubit/test_apply_operation.py | 55 +++++++++++-------- .../test_jax_default_qubit_2.py | 2 +- tests/pulse/test_parametrized_evolution.py | 2 +- 5 files changed, 63 insertions(+), 38 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1d1150fce87..5f7b9c1cd0c 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -39,6 +39,9 @@

Improvements 🛠

+* `default.qubit` now can evolve already batched states with `ParametrizedEvolution` + [(#4863)](https://github.com/PennyLaneAI/pennylane/pull/4863) + * `default.qubit` no longer uses a dense matrix for `MultiControlledX` for more than 8 operation wires. [(#4673)](https://github.com/PennyLaneAI/pennylane/pull/4673) diff --git a/pennylane/devices/qubit/apply_operation.py b/pennylane/devices/qubit/apply_operation.py index 51cdaa0e79a..40c086c3a73 100644 --- a/pennylane/devices/qubit/apply_operation.py +++ b/pennylane/devices/qubit/apply_operation.py @@ -334,23 +334,26 @@ def apply_parametrized_evolution( ): """Apply ParametrizedEvolution by evolving the state rather than the operator matrix if we are operating on more than half of the subsystem""" - if is_state_batched and op.batch_size is None: - raise RuntimeError( - "ParameterizedEvolution does not support standard broadcasting, but received a batched state" - ) # shape(state) is static (not a tracer), we can use an if statement - num_wires = len(qml.math.shape(state)) + num_wires = len(qml.math.shape(state)) - is_state_batched state = qml.math.cast(state, complex) - if 2 * len(op.wires) > num_wires and not op.hyperparameters["complementary"]: - # the subsystem operated is more than half of the system based on the state vector --> evolve state - return _evolve_state_vector_under_parametrized_evolution(op, state, num_wires) - # otherwise --> evolve matrix - return _apply_operation_default(op, state, is_state_batched, debugger) + if ( + 2 * len(op.wires) <= num_wires + or op.hyperparameters["complementary"] + or (is_state_batched and op.hyperparameters["return_intermediate"]) + ): + # the subsystem operated on is half as big as the total system, or less + # or we want complementary time evolution + # or both the state and the operation have a batch dimension + # --> evolve matrix + return _apply_operation_default(op, state, is_state_batched, debugger) + # otherwise --> evolve state + return _evolve_state_vector_under_parametrized_evolution(op, state, num_wires, is_state_batched) def _evolve_state_vector_under_parametrized_evolution( - operation: qml.pulse.ParametrizedEvolution, state, num_wires + operation: qml.pulse.ParametrizedEvolution, state, num_wires, is_state_batched ): """Uses an odeint solver to compute the evolution of the input ``state`` under the given ``ParametrizedEvolution`` operation. @@ -385,7 +388,13 @@ def _evolve_state_vector_under_parametrized_evolution( "You can update these values by calling the ParametrizedEvolution class: EV(params, t)." ) - state = state.flatten() + if is_state_batched: + batch_dim = state.shape[0] + state = qml.math.moveaxis(state.reshape((batch_dim, 2**num_wires)), 1, 0) + out_shape = [2] * num_wires + [batch_dim] # this shape is before moving the batch_dim back + else: + state = state.flatten() + out_shape = [2] * num_wires with jax.ensure_compile_time_eval(): H_jax = ParametrizedHamiltonianPytree.from_hamiltonian( # pragma: no cover @@ -399,7 +408,9 @@ def fun(y, t): return (-1j * H_jax(operation.data, t=t)) @ y result = odeint(fun, state, operation.t, **operation.odeint_kwargs) - out_shape = [2] * num_wires if operation.hyperparameters["return_intermediate"]: return qml.math.reshape(result, [-1] + out_shape) - return qml.math.reshape(result[-1], out_shape) + result = qml.math.reshape(result[-1], out_shape) + if is_state_batched: + return qml.math.moveaxis(result, -1, 0) + return result diff --git a/tests/devices/qubit/test_apply_operation.py b/tests/devices/qubit/test_apply_operation.py index 7b44f57c2fc..56fa5106984 100644 --- a/tests/devices/qubit/test_apply_operation.py +++ b/tests/devices/qubit/test_apply_operation.py @@ -242,18 +242,11 @@ def test_globalphase(self, method, wire, ml_framework): assert qml.math.allclose(shift * initial_state, new_state_no_wire) -# pylint:disable = unused-argument def time_independent_hamiltonian(): """Create a time-independent Hamiltonian on two qubits.""" ops = [qml.PauliX(0), qml.PauliZ(1), qml.PauliY(0), qml.PauliX(1)] - def f1(params, t): - return params # constant - - def f2(params, t): - return params # constant - - coeffs = [f1, f2, 4, 9] + coeffs = [qml.pulse.constant, qml.pulse.constant, 0.4, 0.9] return qml.pulse.ParametrizedHamiltonian(coeffs, ops) @@ -275,10 +268,10 @@ def f2(params, t): @pytest.mark.jax -class TestApplyParameterizedEvolution: +class TestApplyParametrizedEvolution: @pytest.mark.parametrize("method", methods) def test_parameterized_evolution_time_independent(self, method): - """Test that applying a ParameterizedEvolution gives the expected state + """Test that applying a ParametrizedEvolution gives the expected state for a time-independent hamiltonian""" import jax.numpy as jnp @@ -292,7 +285,7 @@ def test_parameterized_evolution_time_independent(self, method): H = time_independent_hamiltonian() params = jnp.array([1.0, 2.0]) - t = 4 + t = 0.4 op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t) @@ -306,7 +299,7 @@ def test_parameterized_evolution_time_independent(self, method): @pytest.mark.parametrize("method", methods) def test_parameterized_evolution_time_dependent(self, method): - """Test that applying a ParameterizedEvolution gives the expected state + """Test that applying a ParametrizedEvolution gives the expected state for a time dependent Hamiltonian""" import jax @@ -321,7 +314,7 @@ def test_parameterized_evolution_time_dependent(self, method): H = time_dependent_hamiltonian() params = jnp.array([1.0, 2.0]) - t = 4 + t = 0.4 op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t) @@ -340,7 +333,7 @@ def generator(params): assert np.allclose(new_state, new_state_expected, atol=0.002) def test_large_state_small_matrix_evolves_matrix(self, mocker): - """Test that applying a ParameterizedEvolution operating on less + """Test that applying a ParametrizedEvolution operating on less than half of the wires in the state uses the default function to evolve the matrix""" @@ -357,7 +350,7 @@ def test_large_state_small_matrix_evolves_matrix(self, mocker): H = time_independent_hamiltonian() params = jnp.array([1.0, 2.0]) - t = 4 + t = 0.4 op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t) @@ -374,7 +367,7 @@ def test_large_state_small_matrix_evolves_matrix(self, mocker): assert spy.call_count == 1 def test_small_evolves_state(self, mocker): - """Test that applying a ParameterizedEvolution operating on less + """Test that applying a ParametrizedEvolution operating on less than half of the wires in the state uses the default function to evolve the matrix""" @@ -433,7 +426,7 @@ def test_small_evolves_state(self, mocker): H = time_independent_hamiltonian() params = jnp.array([1.0, 2.0]) - t = 4 + t = 0.4 op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t) @@ -480,11 +473,16 @@ def test_parametrized_evolution_state_vector_return_intermediate(self, mocker): assert spy.call_count == 2 assert qml.math.allclose(state_ev, state_rx, atol=1e-6) - def test_batched_state_raises_an_error(self): - """Test that if is_state_batche=True, an error is raised""" + @pytest.mark.parametrize("num_state_wires", [2, 4]) + def test_with_batched_state(self, num_state_wires, mocker): + """Test that a ParametrizedEvolution is applied correctly to a batched state. + Note that the branching logic is different for batched input states, because + evolving the state vector does not support batching of the state. Instead, + the evolved matrix is used always.""" + spy_einsum = mocker.spy(qml.math, "einsum") H = time_independent_hamiltonian() params = np.array([1.0, 2.0]) - t = 4 + t = 0.1 op = qml.pulse.ParametrizedEvolution(H=H, params=params, t=t) @@ -492,11 +490,24 @@ def test_batched_state_raises_an_error(self): [ [[0.81677345 + 0.0j, 0.0 + 0.0j], [0.0 - 0.57695852j, 0.0 + 0.0j]], [[0.33894597 + 0.0j, 0.0 + 0.0j], [0.0 - 0.94080584j, 0.0 + 0.0j]], + [[0.33894597 + 0.0j, 0.0 + 0.0j], [0.0 - 0.94080584j, 0.0 + 0.0j]], ] ) + if num_state_wires == 4: + zero_state_two_wires = np.eye(4)[0].reshape((2, 2)) + initial_state = np.tensordot(initial_state, zero_state_two_wires, axes=0) - with pytest.raises(RuntimeError, match="does not support standard broadcasting"): - _ = apply_operation(op, initial_state, is_state_batched=True) + true_mat = qml.math.expm(-1j * qml.matrix(H(params, t=t)) * t) + U = qml.QubitUnitary(U=true_mat, wires=[0, 1]) + + new_state = apply_operation(op, initial_state, is_state_batched=True) + new_state_expected = apply_operation(U, initial_state, is_state_batched=True) + assert np.allclose(new_state, new_state_expected, atol=0.002) + + if num_state_wires == 4: + assert spy_einsum.call_count == 2 + else: + assert spy_einsum.call_count == 1 @pytest.mark.parametrize("ml_framework", ml_frameworks_list) diff --git a/tests/interfaces/default_qubit_2_integration/test_jax_default_qubit_2.py b/tests/interfaces/default_qubit_2_integration/test_jax_default_qubit_2.py index 15f0913b843..4ee9464e754 100644 --- a/tests/interfaces/default_qubit_2_integration/test_jax_default_qubit_2.py +++ b/tests/interfaces/default_qubit_2_integration/test_jax_default_qubit_2.py @@ -134,7 +134,7 @@ def cost(x, cache): def atol_for_shots(shots): """Return higher tolerance if finite shots.""" - return 2e-2 if shots else 1e-6 + return 3e-2 if shots else 1e-6 @pytest.mark.parametrize("execute_kwargs, shots, device", test_matrix) diff --git a/tests/pulse/test_parametrized_evolution.py b/tests/pulse/test_parametrized_evolution.py index 397f57a567b..fd5803fcf87 100644 --- a/tests/pulse/test_parametrized_evolution.py +++ b/tests/pulse/test_parametrized_evolution.py @@ -776,7 +776,7 @@ def U(params): @pytest.mark.jax def test_map_wires(): - """Test that map wires returns a new ParameterizedEvolution, with wires updated on + """Test that map wires returns a new ParametrizedEvolution, with wires updated on both the operator and the corresponding Hamiltonian""" def f1(p, t):