apply format
LuisAlfredoNu committed Sep 11, 2024
1 parent 6627913 commit 6290953
Showing 8 changed files with 107 additions and 105 deletions.
10 changes: 5 additions & 5 deletions pennylane_lightning/lightning_gpu/_measurements.py
@@ -17,9 +17,9 @@

try:
from pennylane_lightning.lightning_gpu_ops import MeasurementsC64, MeasurementsC128

try:
-        from pennylane_lightning.lightning_gpu_ops import MeasurementsMPIC64,MeasurementsMPIC128
+        from pennylane_lightning.lightning_gpu_ops import MeasurementsMPIC64, MeasurementsMPIC128

MPI_SUPPORT = True
except ImportError:
@@ -32,7 +32,7 @@

import numpy as np
import pennylane as qml
-from pennylane.measurements import CountsMP, SampleMeasurement, Shots, MeasurementProcess
+from pennylane.measurements import CountsMP, MeasurementProcess, SampleMeasurement, Shots
from pennylane.typing import TensorLike

from pennylane_lightning.core._measurements_base import LightningBaseMeasurements
@@ -136,12 +136,12 @@ def probs(self, measurementprocess: MeasurementProcess):
self._qubit_state.apply_operations(diagonalizing_gates)

results = self._measurement_lightning.probs(measurementprocess.wires.tolist())

if diagonalizing_gates:
self._qubit_state.apply_operations(
[qml.adjoint(g, lazy=False) for g in reversed(diagonalizing_gates)]
)

# Device returns as col-major orderings, so perform transpose on data for bit-index shuffle for now.
if len(results) > 0:
num_local_wires = len(results).bit_length() - 1 if len(results) > 0 else 0
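The probs hunk above reorders the device's column-major output via a bit-index shuffle. A minimal sketch of that reordering, assuming results holds 2**n probabilities with n >= 1 (the standalone function is illustrative, not part of this commit):

import numpy as np

def reorder_probs_colmajor_to_rowmajor(results):
    # results holds 2**n probabilities for n local wires.
    num_local_wires = len(results).bit_length() - 1
    # Reshape to an n-axis [2, 2, ..., 2] tensor, reverse the axis order
    # (plain transpose), and flatten back out in row-major bit order.
    return np.reshape(results, [2] * num_local_wires).transpose().reshape(-1)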
72 changes: 37 additions & 35 deletions pennylane_lightning/lightning_gpu/_mpi_handler.py
@@ -17,10 +17,8 @@

try:
# pylint: disable=no-name-in-module
-    from pennylane_lightning.lightning_gpu_ops import (
-        DevTag,
-        MPIManager,
-    )
+    from pennylane_lightning.lightning_gpu_ops import DevTag, MPIManager

MPI_SUPPORT = True
except ImportError:
MPI_SUPPORT = False
@@ -29,91 +27,95 @@

import numpy as np


# MPI options
-class LightningGPU_MPIHandler():
+class LightningGPU_MPIHandler:
    """MPI handler for PennyLane Lightning GPU device
MPI handler to use a GPU-backed Lightning device using NVIDIA cuQuantum SDK with parallel capabilities.
    The MPI library is used to initialize the variables and methods that handle data across nodes and to perform memory-allocation checks on each device.
Args:
        mpi (bool): declare whether the device will use MPI support.
        mpi_buf_size (int): size of GPU memory (in MiB) set aside for MPI operations (default: 64 MiB).
        dev_pool (Callable): method to handle the available GPU devices.
        num_wires (int): the number of wires to initialize the device with.
        c_dtype (np.complex64, np.complex128): datatype for the state-vector representation
"""

-    def __init__(self,
-                 mpi: bool,
-                 mpi_buf_size: int,
-                 dev_pool: Callable,
-                 num_wires: int,
-                 c_dtype: Union[np.complex64, np.complex128]) -> None:
+    def __init__(
+        self,
+        mpi: bool,
+        mpi_buf_size: int,
+        dev_pool: Callable,
+        num_wires: int,
+        c_dtype: Union[np.complex64, np.complex128],
+    ) -> None:

        self.use_mpi = mpi
        self.mpi_buf_size = mpi_buf_size
        self._dp = dev_pool
        if self.use_mpi:

if not MPI_SUPPORT:
raise ImportError("MPI related APIs are not found.")

if mpi_buf_size < 0:
raise TypeError(f"Unsupported mpi_buf_size value: {mpi_buf_size}, should be >= 0")

-            if (mpi_buf_size > 0
-                and (mpi_buf_size & (mpi_buf_size - 1))):
-                raise ValueError(f"Unsupported mpi_buf_size value: {mpi_buf_size}. mpi_buf_size should be power of 2.")
+            if mpi_buf_size > 0 and (mpi_buf_size & (mpi_buf_size - 1)):
+                raise ValueError(
+                    f"Unsupported mpi_buf_size value: {mpi_buf_size}. mpi_buf_size should be power of 2."
+                )

            # After these checks, all MPI parameters are valid
self.mpi_manager, self.devtag = self._mpi_init_helper(num_wires)

# set the number of global and local wires
            commSize = self.mpi_manager.getSize()
            self.num_global_wires = commSize.bit_length() - 1
            self.num_local_wires = num_wires - self.num_global_wires

# Memory size in bytes
sv_memsize = np.dtype(c_dtype).itemsize * (1 << self.num_local_wires)
if self._mebibytesToBytes(mpi_buf_size) > sv_memsize:
raise ValueError("The MPI buffer size is larger than the local state vector size.")

        if not self.use_mpi:
self.num_local_wires = num_wires
self.num_global_wires = num_wires

    def _mebibytesToBytes(self, mebibytes):
return mebibytes * 1024 * 1024
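A worked example of the buffer-size guard in __init__ above, with hypothetical numbers and assuming np.complex128 amplitudes (16 bytes each):

import numpy as np

num_local_wires = 20
# Local state vector: 2**20 amplitudes * 16 bytes = 16 MiB.
sv_memsize = np.dtype(np.complex128).itemsize * (1 << num_local_wires)

mpi_buf_size = 64  # MiB, the documented default
# 64 MiB > 16 MiB, so this combination would raise the ValueError above.
assert mpi_buf_size * 1024 * 1024 > sv_memsize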

def _mpi_init_helper(self, num_wires):
"""Set up MPI checks and initializations."""

# initialize MPIManager and config check in the MPIManager ctor
mpi_manager = MPIManager()

# check if number of GPUs per node is larger than number of processes per node
numDevices = self._dp.getTotalDevices()
numProcsNode = mpi_manager.getSizeNode()

if numDevices < numProcsNode:
raise ValueError(
"Number of devices should be larger than or equal to the number of processes on each node."
)

# check if the process number is larger than number of statevector elements
if mpi_manager.getSize() > (1 << (num_wires - 1)):
raise ValueError(
"Number of processes should be smaller than the number of statevector elements."
)

# set GPU device
        rank = mpi_manager.getRank()
deviceid = rank % numProcsNode
self._dp.setDeviceID(deviceid)
devtag = DevTag(deviceid)

return (mpi_manager, devtag)
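For reference, the mpi_buf_size & (mpi_buf_size - 1) expression in __init__ is the standard power-of-two bit trick; a small illustration (the helper name is hypothetical, not part of this commit):

def is_power_of_two(x: int) -> bool:
    # x - 1 flips the lowest set bit and all bits below it, so the AND
    # is zero exactly when x has a single set bit.
    return x > 0 and (x & (x - 1)) == 0

assert is_power_of_two(64)      # accepted buffer size
assert not is_power_of_two(96)  # rejected with ValueError by the handler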
65 changes: 30 additions & 35 deletions pennylane_lightning/lightning_gpu/_state_vector.py
@@ -15,22 +15,16 @@
Class implementation for lightning_gpu state-vector manipulation.
"""
try:
-    from pennylane_lightning.lightning_gpu_ops import (
-        StateVectorC64,
-        StateVectorC128,
-    )
+    from pennylane_lightning.lightning_gpu_ops import StateVectorC64, StateVectorC128

    try:  # Try to import the MPI modules
        # pylint: disable=no-name-in-module
-        from pennylane_lightning.lightning_gpu_ops import (
-            StateVectorMPIC64,
-            StateVectorMPIC128,
-        )
+        from pennylane_lightning.lightning_gpu_ops import StateVectorMPIC64, StateVectorMPIC128

MPI_SUPPORT = True
except ImportError:
MPI_SUPPORT = False

except ImportError:
pass

@@ -39,11 +33,10 @@
import numpy as np
import pennylane as qml
-from pennylane.ops.op_math import Adjoint
-from pennylane.wires import Wires
+from pennylane import DeviceError
from pennylane.measurements import MidMeasureMP
from pennylane.ops import Conditional
-from pennylane import DeviceError
+from pennylane.ops.op_math import Adjoint
+from pennylane.wires import Wires

from pennylane_lightning.core._serialize import global_phase_diagonal
from pennylane_lightning.core._state_vector_base import LightningBaseStateVector
@@ -60,6 +53,7 @@
qml.QubitUnitary,
)


class LightningGPUStateVector(LightningBaseStateVector):
"""Lightning GPU state-vector class.
@@ -72,18 +66,19 @@ class LightningGPUStateVector(LightningBaseStateVector):
device_name(string): state vector device name. Options: ["lightning.gpu"]
"""

-    def __init__(self,
-        num_wires,
-        dtype=np.complex128,
-        device_name="lightning.gpu",
-        mpi_handler = None,
-        sync=True,
-    ):
+    def __init__(
+        self,
+        num_wires,
+        dtype=np.complex128,
+        device_name="lightning.gpu",
+        mpi_handler=None,
+        sync=True,
+    ):

super().__init__(num_wires, dtype)

self._device_name = device_name

if mpi_handler is None:
mpi_handler = LightningGPU_MPIHandler(False, 0, None, num_wires, dtype)

@@ -105,7 +100,7 @@ def __init__(self,

if not self._mpi_handler.use_mpi:
self._qubit_state = self._state_dtype()(self.num_wires)

self._create_basis_state(0)

def _state_dtype(self):
@@ -139,7 +134,6 @@ def syncD2H(self, state_vector, use_async=False):
[0.+0.j 1.+0.j]
"""
self._qubit_state.DeviceToHost(state_vector.ravel(order="C"), use_async)


@property
def state(self):
@@ -158,7 +152,6 @@ def state(self):
self.syncD2H(state)
return state


def syncH2D(self, state_vector, use_async=False):
"""Copy the state vector data on host provided by the user to the state vector on the device
Args:
@@ -179,7 +172,7 @@ def syncH2D(self, state_vector, use_async=False):
1.0
"""
self._qubit_state.HostToDevice(state_vector.ravel(order="C"), use_async)

@staticmethod
def _asarray(arr, dtype=None):
arr = np.asarray(arr) # arr is not copied
@@ -211,10 +204,10 @@ def _apply_state_vector(self, state, device_wires, use_async=False):
use_async(bool): indicates whether to use asynchronous memory copy from host to device or not.
Note: This function only supports synchronized memory copy from host to device.
"""

if isinstance(state, self._qubit_state.__class__):
raise DeviceError("LightningGPU does not support allocate external state_vector.")

# TODO
# state_data = allocate_aligned_array(state.size, np.dtype(self.dtype), True)
# state.getState(state_data)
@@ -294,7 +287,7 @@ def _apply_lightning_controlled(self, operation):
wires = self.wires.indices(operation.wires)
matrix = global_phase_diagonal(param, self.wires, control_wires, control_values)
state.apply(name, wires, inv, [[param]], matrix)

def _apply_lightning_midmeasure(
self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str
):
@@ -311,7 +304,7 @@ def _apply_lightning_midmeasure(
None
"""
raise DeviceError("LightningGPU does not support Mid-circuit measurements.")

def _apply_lightning(
self, operations, mid_measurements: dict = None, postselect_mode: str = None
):
@@ -342,7 +335,7 @@ def _apply_lightning(
invert_param = False
method = getattr(state, name, None)
wires = list(operation.wires)

if isinstance(operation, Conditional):
if operation.meas_val.concretize(mid_measurements):
self._apply_lightning([operation.base])
@@ -364,10 +357,13 @@ def _apply_lightning(
except AttributeError: # pragma: no cover
# To support older versions of PL
                    mat = operation.matrix

                r_dtype = np.float32 if self.dtype == np.complex64 else np.float64
-                param = [[r_dtype(operation.hash)]] if isinstance(operation, gate_cache_needs_hash) else []
+                param = (
+                    [[r_dtype(operation.hash)]]
+                    if isinstance(operation, gate_cache_needs_hash)
+                    else []
+                )
if len(mat) == 0:
raise ValueError("Unsupported operation")

@@ -378,4 +374,3 @@ def _apply_lightning(
param,
mat.ravel(order="C"), # inv = False: Matrix already in correct form;
) # Parameters can be ignored for explicit matrices; F-order for cuQuantum
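Tying the syncD2H/syncH2D pair above together, a hypothetical host/device round trip on a 2-wire LightningGPUStateVector (assumes the lightning_gpu_ops backend is available; the sv instance and amplitude values are illustrative):

import numpy as np

sv = LightningGPUStateVector(num_wires=2)

state = np.zeros(2**2, dtype=np.complex128)
sv.syncD2H(state)  # pull device amplitudes to host; a fresh device holds |00>

state[:] = 0.0
state[0] = state[3] = 1.0 / np.sqrt(2)  # Bell-like amplitude vector
sv.syncH2D(state)  # push the host amplitudes back to the device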
