Skip to content

Commit

Permalink
Make qml.QutritBasisStatePreparation JIT compatible (#6308)
Browse files Browse the repository at this point in the history
This PR makes `qml.QutritBasisStatePreparation` JIT compatible.
Previously, the template used the `qml.TShift` operator in the
decomposition. However, this approach requires non-jittable control flow
on the input. To make the template jittable the decomposition is changed
to return a `qml.ops.QutritUnitary` representing a `TShift` operator
raised to some power. This change requires adding `matrix_power` and
`eigh` to the multi dispatch. This PR addresses sc-70863.

---------

Co-authored-by: Guillermo Alonso-Linaje <[email protected]>
  • Loading branch information
willjmax and KetpuntoG authored Oct 11, 2024
1 parent 0c87b9a commit 814e991
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 25 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@
* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)

* `qml.QutritBasisStatePreparation` is now JIT compatible.
[(#6308)](https://github.com/PennyLaneAI/pennylane/pull/6308)

* `qml.AmplitudeAmplification` is now compatible with QJIT.
[(#6306)](https://github.com/PennyLaneAI/pennylane/pull/6306)

Expand Down
29 changes: 24 additions & 5 deletions pennylane/templates/state_preparations/basis_qutrit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Contains the BasisStatePreparation template.
"""

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation

Expand Down Expand Up @@ -77,10 +79,11 @@ def __init__(self, basis_state, wires, id=None):
f"Basis states must be of length {len(wires)}; state {i} has length {n_bits}."
)

if any(bit not in [0, 1, 2] for bit in state):
raise ValueError(
f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}"
)
if not qml.math.is_abstract(basis_state):
if any(bit not in [0, 1, 2] for bit in state):
raise ValueError(
f"Basis states must only consist of 0s, 1s, and 2s; state {i} is {state}"
)

# TODO: basis_state should be a hyperparameter, not a trainable parameter.
# However, this breaks a test that ensures compatibility with batch_transform.
Expand Down Expand Up @@ -112,7 +115,23 @@ def compute_decomposition(basis_state, wires): # pylint: disable=arguments-diff
"""

op_list = []

if qml.math.is_abstract(basis_state):
for wire, state in zip(wires, basis_state):
op_list.extend(
[
qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 1)),
qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 2)),
qml.TRZ((-2 * state + 3) * state * np.pi, wires=wire, subspace=(0, 2)),
qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 2)),
qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 1)),
qml.TRZ(-(7 * state - 10) * state * np.pi, wires=wire, subspace=(0, 2)),
]
)
return op_list

for wire, state in zip(wires, basis_state):
for _ in range(0, state):
for _ in range(state):
op_list.append(qml.TShift(wire))

return op_list
Original file line number Diff line number Diff line change
Expand Up @@ -102,33 +102,81 @@ def circuit(obs):
assert np.allclose(output_state, target_state, atol=tol, rtol=0)

@pytest.mark.jax
@pytest.mark.parametrize(
"basis_state,wires,target_state",
[
([0, 1], [0, 1], [0, 1, 0]),
([1, 1, 0], [0, 1, 2], [1, 1, 0]),
([1, 0, 1], [2, 0, 1], [0, 1, 1]),
],
)
@pytest.mark.xfail(reason="JIT comptability not yet implemented")
def test_state_preparation_jax_jit(
self, tol, qutrit_device_3_wires, basis_state, wires, target_state
):
"""Tests that the template produces the correct expectation values."""
def test_state_preparation_jax_jit(self):
"""Tests that the template can be JIT compiled."""
import jax

@qml.qnode(qutrit_device_3_wires, interface="jax")
def circuit(state, obs):
qml.QutritBasisStatePreparation(state, wires)
dev = qml.device("default.qutrit", wires=1)

return [qml.expval(qml.THermitian(A=obs, wires=i)) for i in range(3)]
@qml.qnode(dev)
def circuit(state):
qml.QutritBasisStatePreparation(state, [0])
return qml.state()

circuit = jax.jit(circuit)

obs = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
output_state = [x - 1 for x in circuit(basis_state, obs)]
basis_state = qml.math.array([2], like="jax")
output_state = circuit(basis_state)

assert np.allclose(output_state, target_state, atol=tol, rtol=0)
assert qml.math.allclose(output_state, [0, 0, 1])

@pytest.mark.jax
def test_state_preparation_with_simpling_jax_jit(self):
"""Tests that the template can be compiled with JIT when returning
a sampled measurement."""
import jax

n = 2

@jax.jit
@qml.qnode(qml.device("default.qutrit", wires=n, shots=1))
def circuit(state):
qml.QutritBasisStatePreparation(state, wires=range(n))
return qml.sample(wires=range(n))

state = jax.numpy.array([1, 1])
circuit(state)

@pytest.mark.jax
@pytest.mark.parametrize("state", [0, 1, 2])
def test_decomposition_matrix_jax_jit(self, state):
"""Tests that the decomposition matrix is correct when JIT compiled."""
import jax
import jax.numpy as jnp

tshift = jnp.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
jit_decomp = jax.jit(qml.QutritBasisStatePreparation.compute_decomposition)

decomp = jit_decomp(jnp.array([state]), wires=[0])
matrix = qml.matrix(qml.prod(*decomp[::-1]))
assert qml.math.allclose(matrix, jnp.linalg.matrix_power(tshift, state))

@pytest.mark.jax
@pytest.mark.parametrize("state", [0, 1, 2])
def test_decomposition_pl_gates_jax_jit(self, state):
"""Tests that the decomposition gates are correct when JIT compiled."""
import jax
import jax.numpy as jnp

jit_decomp = jax.jit(
qml.QutritBasisStatePreparation.compute_decomposition, static_argnames="wires"
)

wire = (0,)
state = jnp.array([state])
decomp = jit_decomp(jnp.array([state]), wires=wire)

op_list = [
qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 1)),
qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 2)),
qml.TRZ((-2 * state + 3) * state * np.pi, wires=wire, subspace=(0, 2)),
qml.TRY(state * (2 - state) * np.pi, wires=wire, subspace=(0, 2)),
qml.TRY(state * (1 - state) * np.pi / 2, wires=wire, subspace=(0, 1)),
qml.TRZ(-(7 * state - 10) * state * np.pi, wires=wire, subspace=(0, 2)),
]

for op1, op2 in zip(decomp, op_list):
qml.assert_equal(op1, op2)

@pytest.mark.tf
@pytest.mark.parametrize(
Expand Down

0 comments on commit 814e991

Please sign in to comment.