From 814e991b104e69d900c92388f0c132d013db452f Mon Sep 17 00:00:00 2001 From: Will Date: Fri, 11 Oct 2024 16:06:40 -0400 Subject: [PATCH] Make `qml.QutritBasisStatePreparation` JIT compatible (#6308) This PR makes `qml.QutritBasisStatePreparation` JIT compatible. Previously, the template used the `qml.TShift` operator in the decomposition. However, this approach requires non-jittable control flow on the input. To make the template jittable the decomposition is changed to return a `qml.ops.QutritUnitary` representing a `TShift` operator raised to some power. This change requires adding `matrix_power` and `eigh` to the multi dispatch. This PR addresses sc-70863. --------- Co-authored-by: Guillermo Alonso-Linaje <65235481+KetpuntoG@users.noreply.github.com> --- doc/releases/changelog-dev.md | 3 + .../state_preparations/basis_qutrit.py | 29 ++++-- .../test_qutrit_basis_state_prep.py | 88 ++++++++++++++----- 3 files changed, 95 insertions(+), 25 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 51b5f7f8220..db92dcf79b3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -69,6 +69,9 @@ * The `Hermitian` operator now has a `compute_sparse_matrix` implementation. [(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225) +* `qml.QutritBasisStatePreparation` is now JIT compatible. + [(#6308)](https://github.com/PennyLaneAI/pennylane/pull/6308) + * `qml.AmplitudeAmplification` is now compatible with QJIT. [(#6306)](https://github.com/PennyLaneAI/pennylane/pull/6306) diff --git a/pennylane/templates/state_preparations/basis_qutrit.py b/pennylane/templates/state_preparations/basis_qutrit.py index fb79eddb0f4..a0c17d8ba51 100644 --- a/pennylane/templates/state_preparations/basis_qutrit.py +++ b/pennylane/templates/state_preparations/basis_qutrit.py @@ -15,6 +15,8 @@ Contains the BasisStatePreparation template. """ +import numpy as np + import pennylane as qml from pennylane.operation import AnyWires, Operation @@ -77,10 +79,11 @@ def __init__(self, basis_state, wires, id=None): f"Basis states must be of length {len(wires)}; state {i} has length {n_bits}." ) - if any(bit not in [0, 1, 2] for bit in state): - raise ValueError( - f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" - ) + if not qml.math.is_abstract(basis_state): + if any(bit not in [0, 1, 2] for bit in state): + raise ValueError( + f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}" + ) # TODO: basis_state should be a hyperparameter, not a trainable parameter. # However, this breaks a test that ensures compatibility with batch_transform. @@ -112,7 +115,23 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff """ op_list = [] + + if qml.math.is_abstract(basis_state): + for wire, state in zip(wires, basis_state): + op_list.extend( + [ + qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 1)), + qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 2)), + qml.TRZ((-2 * state + 3) * state * np.pi, wires=wire, subspace=(0, 2)), + qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 2)), + qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 1)), + qml.TRZ(-(7 * state - 10) * state * np.pi, wires=wire, subspace=(0, 2)), + ] + ) + return op_list + for wire, state in zip(wires, basis_state): - for _ in range(0, state): + for _ in range(state): op_list.append(qml.TShift(wire)) + return op_list diff --git a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py index 47702382890..24325afeed8 100644 --- a/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py +++ b/tests/templates/test_state_preparations/test_qutrit_basis_state_prep.py @@ -102,33 +102,81 @@ def circuit(obs): assert np.allclose(output_state, target_state, atol=tol, rtol=0) @pytest.mark.jax - @pytest.mark.parametrize( - "basis_state,wires,target_state", - [ - ([0, 1], [0, 1], [0, 1, 0]), - ([1, 1, 0], [0, 1, 2], [1, 1, 0]), - ([1, 0, 1], [2, 0, 1], [0, 1, 1]), - ], - ) - @pytest.mark.xfail(reason="JIT comptability not yet implemented") - def test_state_preparation_jax_jit( - self, tol, qutrit_device_3_wires, basis_state, wires, target_state - ): - """Tests that the template produces the correct expectation values.""" + def test_state_preparation_jax_jit(self): + """Tests that the template can be JIT compiled.""" import jax - @qml.qnode(qutrit_device_3_wires, interface="jax") - def circuit(state, obs): - qml.QutritBasisStatePreparation(state, wires) + dev = qml.device("default.qutrit", wires=1) - return [qml.expval(qml.THermitian(A=obs, wires=i)) for i in range(3)] + @qml.qnode(dev) + def circuit(state): + qml.QutritBasisStatePreparation(state, [0]) + return qml.state() circuit = jax.jit(circuit) - obs = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]]) - output_state = [x - 1 for x in circuit(basis_state, obs)] + basis_state = qml.math.array([2], like="jax") + output_state = circuit(basis_state) - assert np.allclose(output_state, target_state, atol=tol, rtol=0) + assert qml.math.allclose(output_state, [0, 0, 1]) + + @pytest.mark.jax + def test_state_preparation_with_simpling_jax_jit(self): + """Tests that the template can be compiled with JIT when returning + a sampled measurement.""" + import jax + + n = 2 + + @jax.jit + @qml.qnode(qml.device("default.qutrit", wires=n, shots=1)) + def circuit(state): + qml.QutritBasisStatePreparation(state, wires=range(n)) + return qml.sample(wires=range(n)) + + state = jax.numpy.array([1, 1]) + circuit(state) + + @pytest.mark.jax + @pytest.mark.parametrize("state", [0, 1, 2]) + def test_decomposition_matrix_jax_jit(self, state): + """Tests that the decomposition matrix is correct when JIT compiled.""" + import jax + import jax.numpy as jnp + + tshift = jnp.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + jit_decomp = jax.jit(qml.QutritBasisStatePreparation.compute_decomposition) + + decomp = jit_decomp(jnp.array([state]), wires=[0]) + matrix = qml.matrix(qml.prod(*decomp[::-1])) + assert qml.math.allclose(matrix, jnp.linalg.matrix_power(tshift, state)) + + @pytest.mark.jax + @pytest.mark.parametrize("state", [0, 1, 2]) + def test_decomposition_pl_gates_jax_jit(self, state): + """Tests that the decomposition gates are correct when JIT compiled.""" + import jax + import jax.numpy as jnp + + jit_decomp = jax.jit( + qml.QutritBasisStatePreparation.compute_decomposition, static_argnames="wires" + ) + + wire = (0,) + state = jnp.array([state]) + decomp = jit_decomp(jnp.array([state]), wires=wire) + + op_list = [ + qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 1)), + qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 2)), + qml.TRZ((-2 * state + 3) * state * np.pi, wires=wire, subspace=(0, 2)), + qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 2)), + qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 1)), + qml.TRZ(-(7 * state - 10) * state * np.pi, wires=wire, subspace=(0, 2)), + ] + + for op1, op2 in zip(decomp, op_list): + qml.assert_equal(op1, op2) @pytest.mark.tf @pytest.mark.parametrize(