From d6de6b4ea88025a11e166df34d2d273cf26c32ad Mon Sep 17 00:00:00 2001 From: kaushalprasadhial Date: Mon, 23 Dec 2024 04:11:42 +0000 Subject: [PATCH] parallel read with threads --- src/scanpy/readwrite.py | 229 ++++++++++++++++++++++++++++++---------- 1 file changed, 171 insertions(+), 58 deletions(-) diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py index 5e9242106a..7c4904cf10 100644 --- a/src/scanpy/readwrite.py +++ b/src/scanpy/readwrite.py @@ -5,6 +5,7 @@ import json from pathlib import Path, PurePath from typing import TYPE_CHECKING + import anndata import anndata.utils import h5py @@ -32,11 +33,13 @@ read_mtx, read_text, ) -from anndata import AnnData import multiprocessing as mp -import time +import threading +from dataclasses import dataclass + import numba import scipy +from anndata import AnnData from matplotlib.image import imread from . import logging as logg @@ -55,7 +58,7 @@ semDataLoaded = None # will be initialized later semDataCopied = None # will be initialized later -thread_workload = 4000000 # experimented value +thread_workload = 4000000 # experimented value # .gz and .bz2 suffixes are also allowed for text formats text_exts = { @@ -77,111 +80,219 @@ } | text_exts """Available file formats for reading data. """ -def _load_helper(fname, i, k, datalen, dataArray, indicesArray, startsArray, endsArray): - f = h5py.File(fname,'r') - dataA = np.frombuffer(dataArray,dtype=np.float32) - indicesA = np.frombuffer(indicesArray,dtype=indices_type) - startsA = np.frombuffer(startsArray,dtype=np.int64) - endsA = np.frombuffer(endsArray,dtype=np.int64) - for j in range(datalen//(k*thread_workload)+1): + +def get_1d_index(row: int, col: int, num_cols: int) -> int: + """ + Convert 2D coordinates to 1D index. + + Parameters: + row (int): Row index in the 2D array. + col (int): Column index in the 2D array. + num_cols (int): Number of columns in the 2D array. + + Returns: + int: Corresponding 1D index. + """ + return row * num_cols + col + + +@dataclass +class LoadHelperData: + i: int + k: int + datalen: int + dataArray: mp.Array + indicesArray: mp.Array + startsArray: mp.Array + endsArray: mp.Array + + +def _load_helper(fname: str, helper_data: LoadHelperData): + i = helper_data.i + k = helper_data.k + datalen = helper_data.datalen + dataArray = helper_data.dataArray + indicesArray = helper_data.indicesArray + startsArray = helper_data.startsArray + endsArray = helper_data.endsArray + + f = h5py.File(fname, "r") + dataA = np.frombuffer(dataArray, dtype=np.float32) + indicesA = np.frombuffer(indicesArray, dtype=indices_type) + startsA = np.frombuffer(startsArray, dtype=np.int64) + endsA = np.frombuffer(endsArray, dtype=np.int64) + for j in range(datalen // (k * thread_workload) + 1): # compute start, end - s = i*datalen//k + j*thread_workload - e = min(s+thread_workload, (i+1)*datalen//k) - length = e-s - startsA[i]=s - endsA[i]=e + s = i * datalen // k + j * thread_workload + e = min(s + thread_workload, (i + 1) * datalen // k) + length = e - s + startsA[i] = s + endsA[i] = e # read direct - f['X']['data'].read_direct(dataA, np.s_[s:e], np.s_[i*thread_workload:i*thread_workload+length]) - f['X']['indices'].read_direct(indicesA, np.s_[s:e], np.s_[i*thread_workload:i*thread_workload+length]) - + f["X"]["data"].read_direct( + dataA, np.s_[s:e], np.s_[i * thread_workload : i * thread_workload + length] + ) + f["X"]["indices"].read_direct( + indicesA, + np.s_[s:e], + np.s_[i * thread_workload : i * thread_workload + length], + ) + # coordinate with copy threads semDataLoaded[i].release() # done data load semDataCopied[i].acquire() # wait until data copied + def _waitload(i): semDataLoaded[i].acquire() + def _signalcopy(i): semDataCopied[i].release() -@numba.njit(parallel=True) -def _fast_copy(data,dataA,indices,indicesA,starts,ends,k,m): - for i in numba.prange(k): - for _ in range(m): + +@dataclass +class CopyData: + data: np.ndarray + dataA: np.ndarray + indices: np.ndarray + indicesA: np.ndarray + startsA: np.ndarray + endsA: np.ndarray + + +def _fast_copy(copy_data: CopyData, k: int, m: int): + # Access the arrays through copy_data + data = copy_data.data + dataA = copy_data.dataA + indices = copy_data.indices + indicesA = copy_data.indicesA + starts = copy_data.startsA + ends = copy_data.endsA + + def thread_fun(i, m): + for j in range(m): with numba.objmode(): _waitload(i) - length = ends[i]-starts[i] - data[starts[i]:ends[i]] = dataA[i*thread_workload:i*thread_workload+length] - indices[starts[i]:ends[i]] = indicesA[i*thread_workload:i*thread_workload+length] + length = ends[i] - starts[i] + data[starts[i] : ends[i]] = dataA[ + i * thread_workload : i * thread_workload + length + ] + indices[starts[i] : ends[i]] = indicesA[ + i * thread_workload : i * thread_workload + length + ] with numba.objmode(): _signalcopy(i) -def fastload(fname, backed): #, firstn=1): - t0 = time.time() - f = h5py.File(fname,backed) - assert ('X' in f.keys() and 'var' in f.keys() and 'obs' in f.keys()) + threads = [threading.Thread(target=thread_fun, args=(i, m)) for i in range(k)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + +def fastload(fname, backed): + f = h5py.File(fname, backed) + assert "X" in f, "'X' is missing from f" + assert "var" in f, "'var' is missing from f" + assert "obs" in f, "'obs' is missing from f" # get obs dataframe - rows = f['obs'][ list(f['obs'].keys())[0] ].size + rows = f["obs"][list(f["obs"].keys())[0]].size # load index pointers, prepare shared arrays - indptr = f['X']['indptr'][0:rows+1] + indptr = f["X"]["indptr"][0 : rows + 1] datalen = int(indptr[-1]) - if datalen