Skip to content

Commit

Permalink
Add the state_vector, measurement class and simulate method for…
Browse files Browse the repository at this point in the history
… the LightningGPU with the new device API (#892)

### Before submitting

Please complete the following checklist when submitting a PR:

- [X] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      [`tests`](../tests) directory!

- [X] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [X] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `.github/CHANGELOG.md` file, summarizing
the
      change, and including a link back to the PR.

- [X] Ensure that code is properly formatted by running `make format`. 

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
Migrate LightningGPU to the new device API

**Description of the Change:**
Create the `state_vector`, and `measurement` class for the new device
API to achieve the `simulate` method

**Benefits:**
Integration of LGPU with the new device API

**Possible Drawbacks:**

**Related GitHub Issues:**
## **Freezzed PR** ⚠️ ❄️ 
To make a smooth integration of LightningGPU with the new device API, we
set the branch `gpuNewAPI_backend` as the base branch target for future
developments related to this big task.

The branch `gpuNewAPI_backend` has the mock of all classes and methods
necessary for the new API. Also, several tests were disabled with
``` python
if device_name == "lightning.gpu":
    pytest.skip("LGPU new API in WIP.  Skipping.",allow_module_level=True)
```
However, these tests will unblocked as the implementation progresses.

After all the developments for integrating LightningGPU with the new API
have been completed then the PR will be open to merge to `master`

[sc-70932]

---------

Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Ali Asadi <[email protected]>
  • Loading branch information
3 people authored Sep 16, 2024
1 parent b375dd3 commit da75534
Show file tree
Hide file tree
Showing 16 changed files with 717 additions and 75 deletions.
19 changes: 16 additions & 3 deletions pennylane_lightning/core/_measurements_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,37 @@ def expval(self, measurementprocess: MeasurementProcess):
measurementprocess.obs.name, measurementprocess.obs.wires
)

def _probs_retval_conversion(self, probs_results: Any) -> np.ndarray:
"""Convert the data structure from the C++ backend to a common structure through lightning devices.
Args:
probs_result (Any): Result provided by C++ backend.
Returns:
np.ndarray with probabilities of the supplied observable or wires.
"""
return probs_results

def probs(self, measurementprocess: MeasurementProcess):
"""Probabilities of the supplied observable or wires contained in the MeasurementProcess.
Args:
measurementprocess (StateMeasurement): measurement to apply to the state
measurementprocess (StateMeasurement): measurement to apply to the state.
Returns:
Probabilities of the supplied observable or wires
Probabilities of the supplied observable or wires.
"""
diagonalizing_gates = measurementprocess.diagonalizing_gates()

if diagonalizing_gates:
self._qubit_state.apply_operations(diagonalizing_gates)

results = self._measurement_lightning.probs(measurementprocess.wires.tolist())

if diagonalizing_gates:
self._qubit_state.apply_operations(
[qml.adjoint(g, lazy=False) for g in reversed(diagonalizing_gates)]
)
return results

return self._probs_retval_conversion(results)

def var(self, measurementprocess: MeasurementProcess):
"""Variance of the supplied observable contained in the MeasurementProcess.
Expand Down
30 changes: 22 additions & 8 deletions pennylane_lightning/core/_state_vector_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from abc import ABC, abstractmethod
from typing import Union
from typing import Optional, Union

import numpy as np
from pennylane import BasisState, StatePrep
Expand All @@ -35,16 +35,20 @@ class LightningBaseStateVector(ABC):
num_wires(int): the number of wires to initialize the device with
dtype: Datatypes for state-vector representation. Must be one of
``np.complex64`` or ``np.complex128``. Default is ``np.complex128``
sync Optional(bool): immediately sync with host-sv after applying operation.
"""

def __init__(self, num_wires: int, dtype: Union[np.complex128, np.complex64]):
def __init__(
self, num_wires: int, dtype: Union[np.complex128, np.complex64], sync: Optional[bool] = None
):

if dtype not in [np.complex64, np.complex128]:
raise TypeError(f"Unsupported complex type: {dtype}")

self._num_wires = num_wires
self._wires = Wires(range(num_wires))
self._dtype = dtype
self._base_sync = sync

# Dummy for the device name
self._device_name = None
Expand Down Expand Up @@ -96,28 +100,32 @@ def _state_dtype(self):
Returns: the state vector class
"""

def reset_state(self):
def reset_state(self, sync: Optional[bool] = None):
"""Reset the device's state"""
# init the state vector to |00..0>
self._qubit_state.resetStateVector()
if sync == None:
self._qubit_state.resetStateVector()
else:
self._qubit_state.resetStateVector(sync)

@abstractmethod
def _apply_state_vector(self, state, device_wires: Wires):
def _apply_state_vector(self, state, device_wires: Wires, sync: Optional[bool] = None):
"""Initialize the internal state vector in a specified state.
Args:
state (array[complex]): normalized input state of length ``2**len(wires)``
or broadcasted state of shape ``(batch_size, 2**len(wires))``
device_wires (Wires): wires that get initialized in the state
"""

def _apply_basis_state(self, state, wires):
def _apply_basis_state(self, state, wires, use_async: Optional[bool] = None):
"""Initialize the state vector in a specified computational basis state.
Args:
state (array[int]): computational basis state of shape ``(wires,)``
consisting of 0s and 1s.
wires (Wires): wires that the provided computational state should be
initialized on
use_async(Optional[bool]): immediately sync with host-sv after applying operation.
Note: This function does not support broadcasted inputs yet.
"""
Expand All @@ -128,7 +136,11 @@ def _apply_basis_state(self, state, wires):
raise ValueError("BasisState parameter and wires must be of equal length.")

# Return a computational basis state over all wires.
self._qubit_state.setBasisState(list(state), list(wires))
print("FSX:", use_async)
if use_async == None:
self._qubit_state.setBasisState(list(state), list(wires))
else:
self._qubit_state.setBasisState(list(state), list(wires), use_async)

@abstractmethod
def _apply_lightning_controlled(self, operation):
Expand Down Expand Up @@ -185,7 +197,9 @@ def apply_operations(
self._apply_state_vector(operations[0].parameters[0].copy(), operations[0].wires)
operations = operations[1:]
elif isinstance(operations[0], BasisState):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
self._apply_basis_state(
operations[0].parameters[0], operations[0].wires, self._base_sync
)
operations = operations[1:]
self._apply_lightning(
operations, mid_measurements=mid_measurements, postselect_mode=postselect_mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,42 @@ class StateVectorCudaManaged
stream_id);
}

/**
* @brief Prepare a single computational basis state.
*
* @param state Binary number representing the index
* @param wires Wires.
* @param use_async(Optional[bool]): immediately sync with host-sv after
applying operation.
*/
void setBasisState(const std::vector<std::size_t> &state,
const std::vector<std::size_t> &wires,
const bool use_async) {
PL_ABORT_IF_NOT(state.size() == wires.size(),
"state and wires must have equal dimensions.");
const auto num_qubits = BaseType::getNumQubits();
PL_ABORT_IF_NOT(
std::find_if(wires.begin(), wires.end(),
[&num_qubits](const auto i) {
return i >= num_qubits;
}) == wires.end(),
"wires must take values lower than the number of qubits.");
const auto n_wires = wires.size();
std::size_t index{0U};
for (std::size_t k = 0; k < n_wires; k++) {
const auto bit = static_cast<std::size_t>(state[k]);
index |= bit << (num_qubits - 1 - wires[k]);
}

BaseType::getDataBuffer().zeroInit();
const std::complex<PrecisionT> value(1, 0);
CFP_t value_cu = cuUtil::complexToCu<std::complex<Precision>>(value);
auto stream_id = BaseType::getDataBuffer().getDevTag().getStreamID();
setBasisState_CUDA(BaseType::getData(), value_cu, index, use_async,
stream_id);
}

/**
* @brief Set values for a batch of elements of the state-vector. This
* method is implemented by the customized CUDA kernel defined in the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,21 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
static_cast<std::size_t>(arr.size()));
}))
.def(
"setBasisState",
[](StateVectorT &sv, const std::size_t index,
const bool use_async) {
"setBasisStateZero",
[](StateVectorT &sv, const bool use_async) {
const std::complex<PrecisionT> value(1, 0);
sv.setBasisState(value, index, use_async);
std::size_t zero{0U};
sv.setBasisState(value, zero, use_async);
},
"Create Basis State on GPU.")
"Create Basis State to zero on GPU.")
.def(
"setBasisState",
[](StateVectorT &sv, const std::vector<std::size_t> &state,
const std::vector<std::size_t> &wires, const bool use_async) {
sv.setBasisState(state, wires, use_async);
},
"Set the state vector to a basis state on GPU.")

.def(
"setStateVector",
[](StateVectorT &sv, const np_arr_sparse_ind &indices,
Expand Down Expand Up @@ -152,7 +160,7 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
"Get the GPU index for the statevector data.")
.def("numQubits", &StateVectorT::getNumQubits)
.def("dataLength", &StateVectorT::getLength)
.def("resetGPU", &StateVectorT::initSV)
.def("resetStateVector", &StateVectorT::initSV)
.def(
"apply",
[](StateVectorT &sv, const std::string &str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void registerBackendClassSpecificBindingsMPI(PyClass &pyclass) {
"Get the GPU index for the statevector data.")
.def("numQubits", &StateVectorT::getNumQubits)
.def("dataLength", &StateVectorT::getLength)
.def("resetGPU", &StateVectorT::initSV)
.def("resetStateVector", &StateVectorT::initSV)
.def(
"apply",
[](StateVectorT &sv, const std::string &str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ class Measurements final
PL_CUSTATEVEC_IS_SUCCESS(custatevecSamplerSample(
this->_statevector.getCusvHandle(), sampler, bitStrings.data(),
bitOrdering.data(), bitStringLen, rand_nums.data(), num_samples,
CUSTATEVEC_SAMPLER_OUTPUT_ASCENDING_ORDER));
CUSTATEVEC_SAMPLER_OUTPUT_RANDNUM_ORDER));
PL_CUDA_IS_SUCCESS(cudaStreamSynchronize(
this->_statevector.getDataBuffer().getDevTag().getStreamID()));

Expand Down
106 changes: 106 additions & 0 deletions pennylane_lightning/lightning_gpu/_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,30 @@
Class implementation for state vector measurements.
"""

from warnings import warn

try:
from pennylane_lightning.lightning_gpu_ops import MeasurementsC64, MeasurementsC128

try:
from pennylane_lightning.lightning_gpu_ops import MeasurementsMPIC64, MeasurementsMPIC128

Check notice on line 25 in pennylane_lightning/lightning_gpu/_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_gpu/_measurements.py#L25

Too few public methods (0/2) (too-few-public-methods)
MPI_SUPPORT = True
except ImportError as ex:
warn(str(ex), UserWarning)

MPI_SUPPORT = False

except ImportError as ex:
warn(str(ex), UserWarning)

pass

Check notice on line 35 in pennylane_lightning/lightning_gpu/_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_gpu/_measurements.py#L35

Unnecessary pass statement (unnecessary-pass)

from typing import Any, List

import numpy as np
import pennylane as qml
from pennylane.measurements import CountsMP, MeasurementProcess, SampleMeasurement, Shots
from pennylane.typing import TensorLike

from pennylane_lightning.core._measurements_base import LightningBaseMeasurements

Check notice on line 44 in pennylane_lightning/lightning_gpu/_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_gpu/_measurements.py#L44

Imports from package pennylane_lightning are not grouped (ungrouped-imports)
Expand All @@ -37,3 +59,87 @@ def __init__(
) -> TensorLike:

super().__init__(lgpu_state)

self._measurement_lightning = self._measurement_dtype()(lgpu_state.state_vector)

def _measurement_dtype(self):
"""Binding to Lightning GPU Measurements C++ class.
Returns: the Measurements class
"""
return MeasurementsC64 if self.dtype == np.complex64 else MeasurementsC128

def _measure_with_samples_diagonalizing_gates(
self,
mps: List[SampleMeasurement],
shots: Shots,
) -> TensorLike:
"""
Returns the samples of the measurement process performed on the given state,
by rotating the state into the measurement basis using the diagonalizing gates
given by the measurement process.
Args:
mps (~.measurements.SampleMeasurement): The sample measurements to perform
shots (~.measurements.Shots): The number of samples to take
Returns:
TensorLike[Any]: Sample measurement results
"""
# apply diagonalizing gates
self._apply_diagonalizing_gates(mps)

# Specific for LGPU:
total_indices = self._qubit_state.num_wires
wires = qml.wires.Wires(range(total_indices))

def _process_single_shot(samples):
processed = []
for mp in mps:
res = mp.process_samples(samples, wires)
if not isinstance(mp, CountsMP):
res = qml.math.squeeze(res)

processed.append(res)

return tuple(processed)

try:
samples = self._measurement_lightning.generate_samples(
len(wires), shots.total_shots
).astype(int, copy=False)

except ValueError as ex:

Check notice on line 112 in pennylane_lightning/lightning_gpu/_measurements.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane_lightning/lightning_gpu/_measurements.py#L112

Redefining name 'ex' from outer scope (line 27) (redefined-outer-name)
if str(ex) != "probabilities contain NaN":
raise ex
samples = qml.math.full((shots.total_shots, len(wires)), 0)

self._apply_diagonalizing_gates(mps, adjoint=True)

# if there is a shot vector, use the shots.bins generator to
# split samples w.r.t. the shots
processed_samples = []
for lower, upper in shots.bins():
result = _process_single_shot(samples[..., lower:upper, :])
processed_samples.append(result)

return (
tuple(zip(*processed_samples)) if shots.has_partitioned_shots else processed_samples[0]
)

def _probs_retval_conversion(self, probs_results: Any) -> np.ndarray:
"""Convert the data structure from the C++ backend to a common structure through lightning devices.
Args:
probs_result (Any): Result provided by C++ backend.
Returns:
np.ndarray with probabilities of the supplied observable or wires.
"""

# Device returns as col-major orderings, so perform transpose on data for bit-index shuffle for now.
if len(probs_results) > 0:
num_local_wires = len(probs_results).bit_length() - 1 if len(probs_results) > 0 else 0
return probs_results.reshape([2] * num_local_wires).transpose().reshape(-1)

return probs_results
Loading

0 comments on commit da75534

Please sign in to comment.