Skip to content

Commit

Permalink
added test in Python Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisAlfredoNu committed Dec 17, 2024
1 parent b4d10d8 commit de866a0
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 3 deletions.
10 changes: 7 additions & 3 deletions pennylane_lightning/lightning_tensor/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,13 @@ def apply_operations(self, operations):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]
elif isinstance(operations[0], MPSPrep):
mps = operations[0].mps
self._tensornet.updateMPSSitesData(mps)
operations = operations[1:]
if self.method == "tn":
raise DeviceError("Exact Tensor Network does not support MPSPrep")

if self.method == "mps":
mps = operations[0].mps
self._tensornet.updateMPSSitesData(mps)
operations = operations[1:]

self._apply_lightning(operations)

Expand Down
85 changes: 85 additions & 0 deletions tests/lightning_tensor/test_lightning_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
pytest.skip("Skipping tests for the LightningTensor class.", allow_module_level=True)
else:
from pennylane_lightning.lightning_tensor import LightningTensor
from pennylane_lightning.lightning_tensor_ops import LightningException


if not LightningDevice._CPP_BINARY_AVAILABLE: # pylint: disable=protected-access
pytest.skip("Device doesn't have C++ support yet.", allow_module_level=True)
Expand Down Expand Up @@ -157,3 +159,86 @@ def test_execute_and_compute_vjp(self, method):
match="The computation of vector-Jacobian product has yet to be implemented for the lightning.tensor device.",
):
dev.execute_and_compute_vjp(circuits=None, cotangents=None)

@pytest.mark.parametrize(
"wires,max_bond,MPS_shape",
[
(2, 128, [[2, 2], [2, 2]]),
(
8,
128,
[[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 16], [16, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]],
),
(8, 8, [[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 8], [8, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]]),
(15, 2, [[2, 2]] + [[2, 2, 2] for _ in range(13)] + [[2, 2]]),
],
)
def test_MPSPrep_check_pass(wires, max_bond, MPS_shape):
"""Test the correct behavior regarding MPS shape of MPSPrep."""
MPS = [np.zeros(i) for i in MPS_shape]
dev = LightningTensor(wires=wires, method='mps', max_bond_dim=max_bond)
dev_wires = dev.wires.tolist()

def circuit(MPS):
qml.MPSPrep(mps=MPS, wires=dev_wires)
return qml.state()

qnode_ltensor = qml.QNode(circuit, dev)

try:
_ = qnode_ltensor(MPS)
except Exception as excinfo:
pytest.fail(f"Unexpected exception raised: {excinfo}")

@pytest.mark.parametrize(
"wires,max_bond,MPS_shape",
[
(
8,
8,
[[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 16], [16, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]],
), # Incorrect max bond dim.
(15, 2, [[2, 2]] + [[2, 2, 2] for _ in range(14)] + [[2, 2]]), # Incorrect amount of sites
],
)
def test_MPSPrep_check_fail(wires, max_bond, MPS_shape):
"""Test the exceptions regarding MPS shape of MPSPrep."""

MPS = [np.zeros(i) for i in MPS_shape]
dev = LightningTensor(wires=wires, method='mps', max_bond_dim=max_bond)
dev_wires = dev.wires.tolist()

def circuit(MPS):
qml.MPSPrep(mps=MPS, wires=dev_wires)
return qml.state()

qnode_ltensor = qml.QNode(circuit, dev)

with pytest.raises(
LightningException, match="The incoming MPS does not have the correct layout for lightning.tensor"
):
_ = qnode_ltensor(MPS)

@pytest.mark.parametrize(
"wires, MPS_shape",
[
(2, [[2, 2], [2, 2]]),
],
)
def test_MPSPrep_with_tn(wires, MPS_shape):
"""Test the exception of MPSPrep with the method exact tensor network (tn)."""

MPS = [np.zeros(i) for i in MPS_shape]
dev = LightningTensor(wires=wires, method='tn')
dev_wires = dev.wires.tolist()

def circuit(MPS):
qml.MPSPrep(mps=MPS, wires=dev_wires)
return qml.state()

qnode_ltensor = qml.QNode(circuit, dev)

with pytest.raises(
qml.DeviceError, match="Exact Tensor Network does not support MPSPrep"
):
_ = qnode_ltensor(MPS)

0 comments on commit de866a0

Please sign in to comment.