Skip to content

Commit

Permalink
move the MPS shape chacker to Cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisAlfredoNu committed Dec 16, 2024
1 parent c1206e7 commit 93b2480
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,31 @@ class TNCuda : public TNCudaBase<PrecisionT, Derived> {
host_data_size);
}

/**
* @brief Check if the provided MPS has the correct dimension for C++
* backend.
*
* @param MPS_shape_source Dimension list of incoming MPS.
*/

void MPSShapeCheck(
const std::vector<std::vector<std::size_t>> &MPS_shape_source) {
bool sameShape = sitesExtents_ == MPS_shape_source;
if (!sameShape) {
auto MPS_shape_source_str =
Pennylane::Util::vector2DToString<std::size_t>(
MPS_shape_source);
auto MPS_shape_dest_str =
Pennylane::Util::vector2DToString<std::size_t>(sitesExtents_);

PL_ABORT_IF_NOT(
sitesExtents_ == MPS_shape_source,
"The incoming MPS does not have the correct layout for "
"lightning.tensor.\n Incoming MPS: " +
MPS_shape_source_str +
"\n Destination MPS: " + MPS_shape_dest_str)
}
}
/**
* @brief Append multiple gates to the compute graph.
* NOTE: This function does not update the quantum state but only appends
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ void registerBackendClassSpecificBindingsMPS(PyClass &pyclass) {
.def(
"updateMPSSitesData",
[](TensorNet &tensor_network, std::vector<np_arr_c> &tensors) {
// Extract the incoming MPS shape
std::vector<std::vector<std::size_t>> MPS_shape_source;
for (std::size_t idx = 0; idx < tensors.size(); idx++) {
py::buffer_info numpyArrayInfo = tensors[idx].request();
auto MPS_site_source_shape = numpyArrayInfo.shape;
std::vector<std::size_t> MPS_site_source(
MPS_site_source_shape.begin(),
MPS_site_source_shape.end());
MPS_shape_source.push_back(std::move(MPS_site_source));
}

tensor_network.MPSShapeCheck(MPS_shape_source);

for (std::size_t idx = 0; idx < tensors.size(); idx++) {
py::buffer_info numpyArrayInfo = tensors[idx].request();
auto *data_ptr = static_cast<std::complex<PrecisionT> *>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,54 @@ TEMPLATE_TEST_CASE("MPSTNCuda::setIthMPSSite", "[MPSTNCuda]", float, double) {
}
}

TEMPLATE_TEST_CASE("MPSTNCuda::MPSShapeCheck()", "[MPSTNCuda]", float, double) {
const std::size_t num_qubits = 4;
const std::size_t maxBondDim = 8;
SECTION("Correct incoming MPS shape") {

Check notice on line 127 in pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/tests/Tests_MPSTNCuda.cpp

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/core/src/simulators/lightning_tensor/tncuda/tests/Tests_MPSTNCuda.cpp#L127

Redundant blank line at the start of a code block should be deleted. (whitespace/blank_line)
MPSTNCuda<TestType> mps_state{num_qubits, maxBondDim};

std::vector<std::vector<std::size_t>> correct_shape{
{2, 2}, {2, 2, 4}, {4, 2, 2}, {2, 2}};

REQUIRE_NOTHROW(mps_state.MPSShapeCheck(correct_shape));
}

SECTION("Incorrect incoming MPS shape, bond dimension") {
MPSTNCuda<TestType> mps_state{num_qubits, maxBondDim};

std::vector<std::vector<std::size_t>> incorrect_shape{
{2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2}};

REQUIRE_THROWS_WITH(
mps_state.MPSShapeCheck(incorrect_shape),
Catch::Matchers::Contains("The incoming MPS does not have the "
"correct layout for lightning.tensor"));
}
SECTION("Incorrect incoming MPS shape, physical dimension") {
MPSTNCuda<TestType> mps_state{num_qubits, maxBondDim};

std::vector<std::vector<std::size_t>> incorrect_shape{
{4, 2}, {2, 4, 4}, {4, 4, 2}, {2, 4}};

REQUIRE_THROWS_WITH(
mps_state.MPSShapeCheck(incorrect_shape),
Catch::Matchers::Contains("The incoming MPS does not have the "
"correct layout for lightning.tensor"));
}
SECTION("Incorrect incoming MPS shape, number sites") {
MPSTNCuda<TestType> mps_state{num_qubits, maxBondDim};

std::vector<std::vector<std::size_t>> incorrect_shape{
{2, 2}, {2, 2, 2}, {2, 2}};

REQUIRE_THROWS_WITH(
mps_state.MPSShapeCheck(incorrect_shape),
Catch::Matchers::Contains("The incoming MPS does not have the "
"correct layout for lightning.tensor"));
}
}

TEMPLATE_TEST_CASE("MPSTNCuda::SetBasisStates() & reset()", "[MPSTNCuda]",
float, double) {
std::vector<std::vector<std::size_t>> basisStates = {
Expand Down
30 changes: 30 additions & 0 deletions pennylane_lightning/core/src/utils/Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,4 +592,34 @@ bool areVecsDisjoint(const std::vector<T> &v1, const std::vector<T> &v2) {
}
return true;
}

/**
* @brief Convert a 2D vector to string.
* @tparam T Data type.
* @param vec Vector to convert.
*
* @return std::string String with the vector values.
*/
template <typename T>
std::string vector2DToString(const std::vector<std::vector<T>> &vec) {
std::ostringstream oss;
oss << "[";

for (std::size_t i = 0; i < vec.size(); ++i) {
oss << "[";
for (std::size_t j = 0; j < vec[i].size(); ++j) {
oss << vec[i][j];
if (j != vec[i].size() - 1) {
oss << ", "; // Add a comma between elements in the inner vector
}
}
oss << "]";
if (i != vec.size() - 1) {
oss << ", "; // Add a comma between inner vectors
}
}
oss << "]";
return oss.str(); // Return the resulting string
}

} // namespace Pennylane::Util
12 changes: 12 additions & 0 deletions pennylane_lightning/core/src/utils/tests/Test_Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,15 @@ TEST_CASE("Util::areVecsDisjoint", "[Util][LinearAlgebra]") {
REQUIRE(areVecsDisjoint(vec0, vec1) == false);
}
}

TEST_CASE("Utils::vector2DToString", "[Utils]") {
SECTION("Test for convert 2Dvector to string") {
std::vector<std::vector<std::size_t>> vec{
{2, 2, 4}, {4, 2, 8}, {8, 2, 8}};
std::string ref_str{"[[2, 2, 4], [4, 2, 8], [8, 2, 8]]"};

std::string vec2str = vector2DToString<std::size_t>(vec);

REQUIRE(ref_str == vec2str);
}
}
71 changes: 2 additions & 69 deletions pennylane_lightning/lightning_tensor/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
except ImportError:
pass

from typing import List

import numpy as np
import pennylane as qml
from pennylane import BasisState, DeviceError, MPSPrep, StatePrep
Expand Down Expand Up @@ -125,56 +123,6 @@ def gate_matrix_decompose(gate_ops_matrix, wires, max_mpo_bond_dim, c_dtype):
return mpos, sorted_wires


def set_bond_dims(num_qubits: int, max_bond_dim: int) -> List:
"""Compute the MPS bond dimensions on base to the number of wires."""

log_max_bond_dim = np.log2(max_bond_dim)
limit_dimension = 2 ** int(log_max_bond_dim)
localBondDims = [limit_dimension] * (num_qubits - 1)

for i in range(len(localBondDims)):
bondDim = min(i + 1, num_qubits - i - 1)
if bondDim <= log_max_bond_dim:
localBondDims[i] = 2**bondDim

return localBondDims


def set_sites_extents(num_qubits: int, max_bond_dim: int) -> List:
"""Compute the MPS sites dimensions on base to the number of wires."""

bondDims = set_bond_dims(num_qubits, max_bond_dim)
qubitDims = [2 for _ in range(num_qubits)]

localSiteExtents = []
for i in range(num_qubits):
if i == 0:
localSite = [qubitDims[i], bondDims[i]]
elif i == num_qubits - 1:
localSite = [bondDims[i - 1], qubitDims[i]]
else:
localSite = [bondDims[i - 1], qubitDims[i], bondDims[i]]

localSiteExtents.append(localSite)

return localSiteExtents


def MPSPrep_check(MPS: List, num_wires: int, max_bond_dim: int) -> None:
"""Check if the provided MPS has the correct dimension for C++ backend."""

MPS_shape_dest = set_sites_extents(num_wires, max_bond_dim)

MPS_shape_source = [list(site.shape) for site in MPS]

same_shape = [s == d for s, d in zip(MPS_shape_source, MPS_shape_dest)]

if not all(same_shape):
raise ValueError(
f"The custom MPS does not have the correct layout for lightning.tensor.\n MPS source shape {MPS_shape_source}\n MPS destination shape {MPS_shape_dest}"
)


# pylint: disable=too-many-instance-attributes
class LightningTensorNet:
"""Lightning tensornet class.
Expand Down Expand Up @@ -376,22 +324,6 @@ def _apply_basis_state(self, state, wires):

self._tensornet.setBasisState(state)

def _load_mps_state(self, state: List):
"""Prepares an initial state using MPS.
Args:
state (List): A list of different numpy array with the MPS sites values. The structure should be as follows:
[ (2, 2), (2, 2, 4), (4, 2, 8), ...,
(8, 2, 4), (4, 2, 2), (2, 2) ]
wires (List): wires that the provided computational state should be
initialized on.
Note: The correct MPS sites format and layout are user responsible.
"""
mps = state.mps
MPSPrep_check(mps, self._num_wires, self._max_bond_dim)
self._tensornet.updateMPSSitesData(mps)

def _apply_MPO(self, gate_matrix, wires):
"""Apply a matrix product operator to the quantum state (MPS method only).
Expand Down Expand Up @@ -513,7 +445,8 @@ def apply_operations(self, operations):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]
elif isinstance(operations[0], MPSPrep):
self._load_mps_state(operations[0])
mps = operations[0].mps
self._tensornet.updateMPSSitesData(mps)
operations = operations[1:]

self._apply_lightning(operations)
Expand Down
88 changes: 1 addition & 87 deletions tests/lightning_tensor/test_tensornet_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@
else:
from pennylane_lightning.lightning_tensor._tensornet import (
LightningTensorNet,
MPSPrep_check,
decompose_dense,
gate_matrix_decompose,
set_bond_dims,
set_sites_extents,
)

if not LightningDevice._CPP_BINARY_AVAILABLE: # pylint: disable=protected-access
Expand Down Expand Up @@ -74,7 +71,7 @@ def test_wrong_device_name():


def test_wrong_method_name():
"""Test an invalid device name"""
"""Test an invalid method name"""
with pytest.raises(qml.DeviceError, match="The method "):
LightningTensorNet(3, max_bond_dim=5, device_name="lightning.tensor", method="spider_web")

Expand Down Expand Up @@ -138,86 +135,3 @@ def test_gate_matrix_decompose():

assert np.allclose(sorted_wired, sorted(wires), atol=1e-6)
assert np.allclose(unitary_f, original_gate, atol=1e-6)


@pytest.mark.parametrize(
"n_qubits,max_bond,expected",
[
(2, 128, [2]),
(8, 128, [2, 4, 8, 16, 8, 4, 2]),
(8, 8, [2, 4, 8, 8, 8, 4, 2]),
(15, 2, [2 for _ in range(14)]),
],
)
def test_set_bond_dims(n_qubits, max_bond, expected):

result = set_bond_dims(n_qubits, max_bond)

assert len(result) == len(expected)
assert all([a == b for a, b in zip(result, expected)])


@pytest.mark.parametrize(
"n_qubits,max_bond,expected",
[
(2, 128, [[2, 2], [2, 2]]),
(
8,
128,
[[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 16], [16, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]],
),
(8, 8, [[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 8], [8, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]]),
(15, 2, [[2, 2]] + [[2, 2, 2] for _ in range(13)] + [[2, 2]]),
],
)
def test_set_sites_extents(n_qubits, max_bond, expected):

result = set_sites_extents(n_qubits, max_bond)

assert len(result) == len(expected)
assert all([a == b for a, b in zip(result, expected)])


@pytest.mark.parametrize(
"wires,max_bond,MPS_shape",
[
(2, 128, [[2, 2], [2, 2]]),
(
8,
128,
[[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 16], [16, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]],
),
(8, 8, [[2, 2], [2, 2, 4], [4, 2, 8], [8, 2, 8], [8, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]]),
(15, 2, [[2, 2]] + [[2, 2, 2] for _ in range(13)] + [[2, 2]]),
],
)
def test_MPSPrep_check_pass(wires, max_bond, MPS_shape):
MPS = [np.zeros(i) for i in MPS_shape]

MPSPrep_check(MPS, wires, max_bond)


@pytest.mark.parametrize(
"wires,max_bond,MPS_shape",
[
(2, 128, [[2, 3], [3, 2]]), # Incorrect bond dim.
(
8,
128,
[[2, 2], [2, 4, 4], [4, 4, 8], [8, 4, 16], [16, 4, 8], [8, 4, 4], [4, 4, 2], [2, 2]],
), # Incorrect physical dim.
(
8,
8,
[[2, 2], [3, 2, 4], [4, 2, 8], [8, 2, 8], [8, 2, 8], [8, 2, 4], [4, 2, 2], [2, 2]],
), # Incorrect only one bond dim.
(15, 2, [[2, 2]] + [[2, 2, 2] for _ in range(14)] + [[2, 2]]), # Incorrect amount of sites
],
)
def test_MPSPrep_check_fail(wires, max_bond, MPS_shape):
MPS = [np.zeros(i) for i in MPS_shape]

with pytest.raises(
ValueError, match="The custom MPS does not have the correct layout for lightning.tensor"
):
MPSPrep_check(MPS, wires, max_bond)

0 comments on commit 93b2480

Please sign in to comment.