-
Notifications
You must be signed in to change notification settings - Fork 89
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add slicing CUDA kernels (#3140)
* feat add awkward_ListArray_getitem_jagged_apply kernel * fix: remove print statements * feat: add awkward_ListArray_getitem_jagged_shrink kernel * test: cuda integration tests * test: more slicing integration tests * fix: ndarray error for cupy array shape * fix: remove unused variable
- Loading branch information
1 parent
0b9f6f4
commit 81085be
Showing
9 changed files
with
2,433 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 120 additions & 0 deletions
120
src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_apply.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE | ||
|
||
// BEGIN PYTHON | ||
// def f(grid, block, args): | ||
// (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, invocation_index, err_code) = args | ||
// scan_in_array = cupy.zeros(sliceouterlen + 1, dtype=cupy.int64) | ||
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_a", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code)) | ||
// scan_in_array = cupy.cumsum(scan_in_array) | ||
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_apply_b", tooffsets.dtype, tocarry.dtype, slicestarts.dtype, slicestops.dtype, sliceindex.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tooffsets, tocarry, slicestarts, slicestops, sliceouterlen, sliceindex, sliceinnerlen, fromstarts, fromstops, contentlen, scan_in_array, invocation_index, err_code)) | ||
// out["awkward_ListArray_getitem_jagged_apply_a", {dtype_specializations}] = None | ||
// out["awkward_ListArray_getitem_jagged_apply_b", {dtype_specializations}] = None | ||
// END PYTHON | ||
|
||
enum class LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS { | ||
JAG_STOP_LT_START, // message: "jagged slice's stops[i] < starts[i]" | ||
OFF_GET_CON, // message: "jagged slice's offsets extend beyond its content" | ||
STOP_LT_START, // message: "stops[i] < starts[i]" | ||
STOP_GET_LEN, // message: "stops[i] > len(content)" | ||
IND_OUT_OF_RANGE, // message: "index out of range" | ||
}; | ||
|
||
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y> | ||
__global__ void | ||
awkward_ListArray_getitem_jagged_apply_a( | ||
T* tooffsets, | ||
C* tocarry, | ||
const U* slicestarts, | ||
const V* slicestops, | ||
int64_t sliceouterlen, | ||
const W* sliceindex, | ||
int64_t sliceinnerlen, | ||
const X* fromstarts, | ||
const Y* fromstops, | ||
int64_t contentlen, | ||
int64_t* scan_in_array, | ||
uint64_t invocation_index, | ||
uint64_t* err_code) { | ||
if (err_code[0] == NO_ERROR) { | ||
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; | ||
scan_in_array[0] = 0; | ||
|
||
if (thread_id < sliceouterlen) { | ||
U slicestart = slicestarts[thread_id]; | ||
V slicestop = slicestops[thread_id]; | ||
|
||
if (slicestart != slicestop) { | ||
if (slicestop < slicestart) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START) | ||
} | ||
if (slicestop > sliceinnerlen) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON) | ||
} | ||
int64_t start = (int64_t)fromstarts[thread_id]; | ||
int64_t stop = (int64_t)fromstops[thread_id]; | ||
if (stop < start) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START) | ||
} | ||
if (start != stop && stop > contentlen) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN) | ||
} | ||
scan_in_array[thread_id + 1] = slicestop - slicestart; | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <typename T, typename C, typename U, typename V, typename W, typename X, typename Y> | ||
__global__ void | ||
awkward_ListArray_getitem_jagged_apply_b( | ||
T* tooffsets, | ||
C* tocarry, | ||
const U* slicestarts, | ||
const V* slicestops, | ||
int64_t sliceouterlen, | ||
const W* sliceindex, | ||
int64_t sliceinnerlen, | ||
const X* fromstarts, | ||
const Y* fromstops, | ||
int64_t contentlen, | ||
int64_t* scan_in_array, | ||
uint64_t invocation_index, | ||
uint64_t* err_code) { | ||
if (err_code[0] == NO_ERROR) { | ||
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
if (thread_id < sliceouterlen) { | ||
U slicestart = slicestarts[thread_id]; | ||
V slicestop = slicestops[thread_id]; | ||
tooffsets[thread_id] = (T)(scan_in_array[thread_id]); | ||
if (slicestart != slicestop) { | ||
if (slicestop < slicestart) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::JAG_STOP_LT_START) | ||
} | ||
if (slicestop > sliceinnerlen) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::OFF_GET_CON) | ||
} | ||
int64_t start = (int64_t)fromstarts[thread_id]; | ||
int64_t stop = (int64_t)fromstops[thread_id]; | ||
if (stop < start) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_LT_START) | ||
} | ||
if (start != stop && stop > contentlen) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::STOP_GET_LEN) | ||
} | ||
int64_t count = stop - start; | ||
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) { | ||
int64_t index = (int64_t) sliceindex[j]; | ||
if (index < -count || index > count) { | ||
RAISE_ERROR(LISTARRAY_GETITEM_JAGGED_APPLY_ERRORS::IND_OUT_OF_RANGE) | ||
} | ||
if (index < 0) { | ||
index += count; | ||
} | ||
tocarry[scan_in_array[thread_id] + j - slicestart] = start + index; | ||
} | ||
} | ||
} | ||
tooffsets[sliceouterlen] = scan_in_array[sliceouterlen]; | ||
} | ||
} |
105 changes: 105 additions & 0 deletions
105
src/awkward/_connect/cuda/cuda_kernels/awkward_ListArray_getitem_jagged_shrink.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE | ||
|
||
// BEGIN PYTHON | ||
// def f(grid, block, args): | ||
// (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, invocation_index, err_code) = args | ||
// if length > 0 and length < int(slicestops[length - 1]): | ||
// len_array = int(slicestops[length - 1]) | ||
// else: | ||
// len_array = length | ||
// scan_in_array_k = cupy.zeros(len_array, dtype=cupy.int64) | ||
// scan_in_array_tosmalloffsets = cupy.zeros(length + 1, dtype=cupy.int64) | ||
// scan_in_array_tolargeoffsets = cupy.zeros(length + 1, dtype=cupy.int64) | ||
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_a", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code)) | ||
// scan_in_array_k = cupy.cumsum(scan_in_array_k) | ||
// scan_in_array_tosmalloffsets = cupy.cumsum(scan_in_array_tosmalloffsets) | ||
// scan_in_array_tolargeoffsets = cupy.cumsum(scan_in_array_tolargeoffsets) | ||
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_shrink_b", tocarry.dtype, tosmalloffsets.dtype, tolargeoffsets.dtype, slicestarts.dtype, slicestops.dtype, missing.dtype]))(grid, block, (tocarry, tosmalloffsets, tolargeoffsets, slicestarts, slicestops, length, missing, scan_in_array_k, scan_in_array_tosmalloffsets, scan_in_array_tolargeoffsets, invocation_index, err_code)) | ||
// out["awkward_ListArray_getitem_jagged_shrink_a", {dtype_specializations}] = None | ||
// out["awkward_ListArray_getitem_jagged_shrink_b", {dtype_specializations}] = None | ||
// END PYTHON | ||
|
||
template <typename T, typename C, typename U, typename V, typename W, typename X> | ||
__global__ void | ||
awkward_ListArray_getitem_jagged_shrink_a( | ||
T* tocarry, | ||
C* tosmalloffsets, | ||
U* tolargeoffsets, | ||
const V* slicestarts, | ||
const W* slicestops, | ||
int64_t length, | ||
const X* missing, | ||
int64_t* scan_in_array_k, | ||
int64_t* scan_in_array_tosmalloffsets, | ||
int64_t* scan_in_array_tolargeoffsets, | ||
uint64_t invocation_index, | ||
uint64_t* err_code) { | ||
if (err_code[0] == NO_ERROR) { | ||
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (thread_id < length) { | ||
if (thread_id == 0) { | ||
scan_in_array_tosmalloffsets[0] = slicestarts[0]; | ||
scan_in_array_tolargeoffsets[0] = slicestarts[0]; | ||
} | ||
V slicestart = slicestarts[thread_id]; | ||
W slicestop = slicestops[thread_id]; | ||
if (slicestart != slicestop) { | ||
C smallcount = 0; | ||
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) { | ||
if (missing[j] >= 0) { | ||
smallcount++; | ||
} | ||
} | ||
scan_in_array_k[thread_id + 1] = smallcount; | ||
scan_in_array_tosmalloffsets[thread_id + 1] = smallcount; | ||
} | ||
scan_in_array_tolargeoffsets[thread_id + 1] = slicestop - slicestart; | ||
} | ||
} | ||
} | ||
|
||
template <typename T, typename C, typename U, typename V, typename W, typename X> | ||
__global__ void | ||
awkward_ListArray_getitem_jagged_shrink_b( | ||
T* tocarry, | ||
C* tosmalloffsets, | ||
U* tolargeoffsets, | ||
const V* slicestarts, | ||
const W* slicestops, | ||
int64_t length, | ||
const X* missing, | ||
int64_t* scan_in_array_k, | ||
int64_t* scan_in_array_tosmalloffsets, | ||
int64_t* scan_in_array_tolargeoffsets, | ||
uint64_t invocation_index, | ||
uint64_t* err_code) { | ||
if (err_code[0] == NO_ERROR) { | ||
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (length == 0) { | ||
tosmalloffsets[0] = 0; | ||
tolargeoffsets[0] = 0; | ||
} | ||
else { | ||
tosmalloffsets[0] = slicestarts[0]; | ||
tolargeoffsets[0] = slicestarts[0]; | ||
} | ||
if (thread_id < length) { | ||
V slicestart = slicestarts[thread_id]; | ||
W slicestop = slicestops[thread_id]; | ||
int64_t k = scan_in_array_k[thread_id] - scan_in_array_k[0]; | ||
if (slicestart != slicestop) { | ||
for (int64_t j = slicestart + threadIdx.y; j < slicestop; j += blockDim.y) { | ||
if (missing[j] >= 0) { | ||
tocarry[k] = j; | ||
k++; | ||
} | ||
} | ||
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id + 1]; | ||
} | ||
else { | ||
tosmalloffsets[thread_id + 1] = scan_in_array_tosmalloffsets[thread_id]; | ||
} | ||
tolargeoffsets[thread_id + 1] = scan_in_array_tolargeoffsets[thread_id + 1]; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.