parallel read with threads
kaushalprasadhial committed Dec 23, 2024
1 parent 7d6d62b commit d6de6b4
Showing 1 changed file with 171 additions and 58 deletions.
229 changes: 171 additions & 58 deletions src/scanpy/readwrite.py
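The diff below reworks the experimental parallel .h5ad reader: worker processes stream chunks of `X` into shared-memory staging buffers while per-worker threads copy each chunk into the final CSR arrays. A minimal sketch of how the new entry point might be called, assuming the module exposes `fastload` directly (the file path is hypothetical; small files fall back to `read_h5ad`):

    from scanpy.readwrite import fastload

    adata = fastload("large_dataset.h5ad", "r")  # hypothetical example file
    print(adata.shape)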
@@ -5,6 +5,7 @@
import json
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import anndata
import anndata.utils
import h5py
@@ -32,11 +33,13 @@
read_mtx,
read_text,
)
import multiprocessing as mp
import threading
from dataclasses import dataclass

import numba
import scipy
from anndata import AnnData
from matplotlib.image import imread

from . import logging as logg
@@ -55,7 +58,7 @@
semDataLoaded = None  # will be initialized later
semDataCopied = None  # will be initialized later

thread_workload = 4000000  # empirically determined value
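# Each loader process reads and hands off at most this many elements per chunk;
# the shared staging buffers in fastload() are sized to k * thread_workload.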

# .gz and .bz2 suffixes are also allowed for text formats
text_exts = {
@@ -77,111 +80,219 @@
} | text_exts
"""Available file formats for reading data. """


def get_1d_index(row: int, col: int, num_cols: int) -> int:
    """Convert 2D coordinates to a 1D index.

    Parameters:
        row (int): Row index in the 2D array.
        col (int): Column index in the 2D array.
        num_cols (int): Number of columns in the 2D array.

    Returns:
        int: Corresponding 1D index.
    """
    return row * num_cols + col
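# e.g. get_1d_index(2, 3, num_cols=5) == 2 * 5 + 3 == 13 (row-major layout)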


@dataclass
class LoadHelperData:
i: int
k: int
datalen: int
dataArray: mp.Array
indicesArray: mp.Array
startsArray: mp.Array
endsArray: mp.Array


def _load_helper(fname: str, helper_data: LoadHelperData):
i = helper_data.i
k = helper_data.k
datalen = helper_data.datalen
dataArray = helper_data.dataArray
indicesArray = helper_data.indicesArray
startsArray = helper_data.startsArray
endsArray = helper_data.endsArray

f = h5py.File(fname, "r")
dataA = np.frombuffer(dataArray, dtype=np.float32)
indicesA = np.frombuffer(indicesArray, dtype=indices_type)
startsA = np.frombuffer(startsArray, dtype=np.int64)
endsA = np.frombuffer(endsArray, dtype=np.int64)
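    # these are zero-copy NumPy views over the shared-memory staging arrays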
for j in range(datalen // (k * thread_workload) + 1):
# compute start, end
s = i * datalen // k + j * thread_workload
e = min(s + thread_workload, (i + 1) * datalen // k)
length = e - s
startsA[i] = s
endsA[i] = e
# read direct
f["X"]["data"].read_direct(
dataA, np.s_[s:e], np.s_[i * thread_workload : i * thread_workload + length]
)
f["X"]["indices"].read_direct(
indicesA,
np.s_[s:e],
np.s_[i * thread_workload : i * thread_workload + length],
)

# coordinate with copy threads
semDataLoaded[i].release() # done data load
semDataCopied[i].acquire() # wait until data copied
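        # each worker's staging slot is reused on the next iteration, so the
        # paired semaphores enforce strict load -> copy alternation per slot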


def _waitload(i):
semDataLoaded[i].acquire()


def _signalcopy(i):
semDataCopied[i].release()


@dataclass
class CopyData:
data: np.ndarray
dataA: np.ndarray
indices: np.ndarray
indicesA: np.ndarray
startsA: np.ndarray
endsA: np.ndarray


def _fast_copy(copy_data: CopyData, k: int, m: int):
# Access the arrays through copy_data
data = copy_data.data
dataA = copy_data.dataA
indices = copy_data.indices
indicesA = copy_data.indicesA
starts = copy_data.startsA
ends = copy_data.endsA

    def thread_fun(i, m):
        for _ in range(m):
            # numba.objmode is effectively a no-op when run uncompiled;
            # retained from the earlier njit implementation of this copy loop
            with numba.objmode():
                _waitload(i)
            length = ends[i] - starts[i]
            data[starts[i] : ends[i]] = dataA[
                i * thread_workload : i * thread_workload + length
            ]
            indices[starts[i] : ends[i]] = indicesA[
                i * thread_workload : i * thread_workload + length
            ]
            with numba.objmode():
                _signalcopy(i)

    # one copy thread per loader process; each drains its worker's chunks in order
    threads = [threading.Thread(target=thread_fun, args=(i, m)) for i in range(k)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def fastload(fname, backed):
f = h5py.File(fname, backed)
assert "X" in f, "'X' is missing from f"
assert "var" in f, "'var' is missing from f"
assert "obs" in f, "'obs' is missing from f"

# get obs dataframe
rows = f["obs"][list(f["obs"].keys())[0]].size
# load index pointers, prepare shared arrays
indptr = f["X"]["indptr"][0 : rows + 1]
datalen = int(indptr[-1])

if datalen<thread_workload:
if datalen < thread_workload:
f.close()
return read_h5ad(fname, backed=backed)
if "_index" in f["obs"]:
dfobsind = pd.Series(f["obs"]["_index"].asstr()[0:rows])
dfobs = pd.DataFrame(index=dfobsind)
else:
dfobs = pd.DataFrame()
for k in f["obs"]:
if k == "_index":
continue
dfobs[k] = f["obs"][k].asstr()[...]

# get var dataframe
if "_index" in f["var"]:
dfvarind = pd.Series(f["var"]["_index"].asstr()[...])
dfvar = pd.DataFrame(index=dfvarind)
else:
dfvar = pd.DataFrame()
for k in f["var"]:
if k == "_index":
continue
dfvar[k] = f["var"][k].asstr()[...]

f.close()
    k = numba.get_num_threads()  # number of parallel loader processes
dataArray = mp.Array(
"f", k * thread_workload, lock=False
) # should be in shared memory
indicesArray = mp.Array(
indices_shm_type, k * thread_workload, lock=False
) # should be in shared memory
startsArray = mp.Array("l", k, lock=False) # start index of data read
endsArray = mp.Array("l", k, lock=False) # end index (noninclusive) of data read
global semDataLoaded
global semDataCopied
semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
semDataCopied = [mp.Semaphore(0) for _ in range(k)]
dataA = np.frombuffer(dataArray,dtype=np.float32)
indicesA = np.frombuffer(indicesArray,dtype=indices_type)
dataA = np.frombuffer(dataArray, dtype=np.float32)
indicesA = np.frombuffer(indicesArray, dtype=indices_type)
startsA = np.frombuffer(startsArray, dtype=np.int64)
endsA = np.frombuffer(endsArray, dtype=np.int64)
data = np.empty(datalen, dtype=np.float32)
indices = np.empty(datalen, dtype=indices_type)
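    # data/indices are the final CSR buffers; dataA/indicesA view the shared staging memory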

procs = [
mp.Process(
target=_load_helper,
args=(
fname,
LoadHelperData(
i=i,
k=k,
datalen=datalen,
dataArray=dataArray,
indicesArray=indicesArray,
startsArray=startsArray,
endsArray=endsArray,
),
),
)
for i in range(k)
]

for p in procs:
p.start()

copy_data = CopyData(
data=data,
dataA=dataA,
indices=indices,
indicesA=indicesA,
startsA=startsA,
endsA=endsA,
)

_fast_copy(copy_data, k, datalen // (k * thread_workload) + 1)
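    # the chunk count here matches the loop bound in _load_helper, so every
    # semaphore release is paired with exactly one acquire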

for p in procs:
p.join()

    # assemble the CSR matrix in place; assigning the buffers directly avoids an extra copy
    X = scipy.sparse.csr_matrix((0, 0))
X.data = data
X.indices = indices
X.indptr = indptr
    X._shape = (rows, dfvar.shape[0])

# create AnnData
adata = anndata.AnnData(X, dfobs, dfvar)
    return adata


# --------------------------------------------------------------------------------
Expand All @@ -200,7 +311,7 @@ def fastload(fname, backed): #, firstn=1):
)
def read(
filename: Path | str,
backed: Literal["r", "r+"] | None = 'r+',
backed: Literal["r", "r+"] | None = "r+",
*,
sheet: str | None = None,
ext: str | None = None,
@@ -455,7 +566,9 @@ def _read_v3_10x_h5(filename, *, start=None):
(
feature_metadata_name,
dsets[feature_metadata_name].astype(
                    bool
                    if feature_metadata_item.dtype.kind == "b"
                    else str
),
)
for feature_metadata_name, feature_metadata_item in f["matrix"][
