From 59eae84eade1ed8f9928091e8ba5733c59086a84 Mon Sep 17 00:00:00 2001 From: PetrelYy <92866578+PetrelYy@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:12:40 +0800 Subject: [PATCH] [Feature](mlu-ops): adapt scatter,gather (#1168) --- .../box_iou_rotated/box_iou_rotated_utils.h | 28 +++--- .../generate_proposals_v2_union1_500.mlu | 1 + .../ms_deform_attn_backward_fast_union1.mlu | 28 +++--- ...rm_attn_backward_small_channels_union1.mlu | 38 ++++---- .../ms_deform_attn_utils.h | 14 +-- .../msda_forward_fast_union1.mlu | 50 +++++------ .../roi_align_rotated_forward_vector.mlu | 10 +-- kernels/utils/scatter_gather.h | 90 +++++++++++++++++++ .../voxel_pooling_forward_union1.mlu | 10 ++- kernels/voxelization/voxelization_kernel.mlu | 20 +++-- 10 files changed, 191 insertions(+), 98 deletions(-) create mode 100644 kernels/utils/scatter_gather.h diff --git a/kernels/box_iou_rotated/box_iou_rotated_utils.h b/kernels/box_iou_rotated/box_iou_rotated_utils.h index 7c3e8d270..22aa3e0ec 100644 --- a/kernels/box_iou_rotated/box_iou_rotated_utils.h +++ b/kernels/box_iou_rotated/box_iou_rotated_utils.h @@ -24,6 +24,7 @@ #define KERNELS_BOX_IOU_ROTATED_BOX_IOU_ROTATED_UTILS_H_ #include "kernels/utils/common.h" +#include "kernels/utils/scatter_gather.h" #define FIILED_ONES (int)0xffffffff #define HALF_FILLED_ONES (int16_t)0xffff @@ -590,21 +591,22 @@ __mlu_func__ void convexHullGraham( sizeof(T), actual_compute_box_num); // get the ordered points according to the angle value - __gather(ordered_pts_x + (i + 1) * actual_compute_box_num, intersect_pts_x, - (unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T), - actual_compute_box_num); - __gather(ordered_pts_y + (i + 1) * actual_compute_box_num, intersect_pts_y, - (unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T), - actual_compute_box_num); - __gather(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts, - (unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T), - actual_compute_box_num); + 
__mluop_gather(ordered_pts_x + (i + 1) * actual_compute_box_num, + intersect_pts_x, (unsigned int *)temp_offset, NULL, + sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num); + __mluop_gather(ordered_pts_y + (i + 1) * actual_compute_box_num, + intersect_pts_y, (unsigned int *)temp_offset, NULL, + sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num); + __mluop_gather(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts, + (unsigned int *)temp_offset, NULL, sizeof(T), NRAM2NRAM, + sizeof(T), actual_compute_box_num); // assign a invalid value to the point which has been get ordered - __scatter(temp_long_2, temp1_ram, (unsigned int *)temp_offset, sizeof(T), - NRAM2NRAM, sizeof(T), actual_compute_box_num); - __scatter(valid_pts, temp2_ram, (unsigned int *)temp_offset, sizeof(T), - NRAM2NRAM, sizeof(T), actual_compute_box_num); + __mluop_scatter(temp_long_2, temp1_ram, (unsigned int *)temp_offset, + NULL, sizeof(T), NRAM2NRAM, sizeof(T), + actual_compute_box_num); + __mluop_scatter(valid_pts, temp2_ram, (unsigned int *)temp_offset, NULL, + sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num); } __bang_move(valid_pts, temp_long_1, total_points * sizeof(T)); #else diff --git a/kernels/generate_proposals_v2/generate_proposals_v2_union1_500.mlu b/kernels/generate_proposals_v2/generate_proposals_v2_union1_500.mlu index bf5887b03..59e25153b 100644 --- a/kernels/generate_proposals_v2/generate_proposals_v2_union1_500.mlu +++ b/kernels/generate_proposals_v2/generate_proposals_v2_union1_500.mlu @@ -158,6 +158,7 @@ __mlu_func__ void proposalBoxesDecode( // gather offset (byte). 
__bang_mul_scalar(anchors_index_nram, anchors_index_nram, sizeof(int32_t), deal_num); + // deal_num <= 5163 __gather(temp_nram, anchors, (unsigned int *)anchors_index_nram, sizeof(T) * 4, GDRAM2NRAM, sizeof(T) * 4, deal_num); __bang_transpose(anchors_nram, temp_nram, deal_num, 4); diff --git a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu index b72087481..d94a2f021 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu +++ b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu @@ -233,14 +233,14 @@ __mlu_func__ void backwardStageTwoLoop( for (int j = 0; j < 5; j++) { T* tmp_wp = weight_polation_nram + (j - 1) * nq_nl_np; if (j < 4) { - gatherAsync(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2, - bit_cond_reverse_nram + j * bit_cond_stride, - channels * sizeof(T), NRAM2NRAM, channels * sizeof(T), - nq_nl_np); - gatherAsync(v_ping, data_value_gdram, - (unsigned int*)offset_nram + j * nq_nl_np, - bit_cond_nram + j * bit_cond_stride, channels * sizeof(T), - GDRAM2NRAM, channels * sizeof(T), nq_nl_np); + gatherAsync(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2, + bit_cond_reverse_nram + j * bit_cond_stride, + channels * sizeof(T), NRAM2NRAM, channels * sizeof(T), + nq_nl_np); + gatherAsync( + v_ping, data_value_gdram, (unsigned int*)offset_nram + j * nq_nl_np, + bit_cond_nram + j * bit_cond_stride, channels * sizeof(T), + GDRAM2NRAM, channels * sizeof(T), nq_nl_np); } if (j == 0) { @@ -249,10 +249,10 @@ __mlu_func__ void backwardStageTwoLoop( NRAM2NRAM, channels * sizeof(T), num_levels_points - 1, num_levels_points * channels * sizeof(T), deal_n - 1, 0, num_levels_points - 1, channels * sizeof(T), deal_n - 1); - gatherAsync(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2, - bit_cond_reverse_nram + 4 * bit_cond_stride, - channels * 
sizeof(T), NRAM2NRAM, channels * sizeof(T), - nq_nl_np); + gatherAsync(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2, + bit_cond_reverse_nram + 4 * bit_cond_stride, + channels * sizeof(T), NRAM2NRAM, channels * sizeof(T), + nq_nl_np); __bang_write_value(value_wp, nq_nl_np_c, (T)0); // clear value*wp __sync_move(); // (n, nl, np, c) => (c, n, nl, np) @@ -332,7 +332,7 @@ __mlu_func__ void backwardStageTwoLoop( int32_t* dst_offset = (int32_t*)offset_zero_nram_stg2; for (int i = 0; i < 4; i++) { __bang_filter((T*)dst_offset + i * nq_nl_np, - (T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np); + (T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np); } int32_t* src_offset = (int32_t*)inter_grad; int32_t* stride_4_2 = dst_offset + 3 * nq_nl_np; @@ -368,7 +368,7 @@ __mlu_func__ void backwardStageTwoLoop( int32_t valid_count = __bang_sum(tmp_cond, nq_nl_np); if (valid_count > 0) { __bang_filter((T*)tmp_dst_offset, (T*)tmp_dst_offset, tmp_cond, - nq_nl_np); + nq_nl_np); __bang_filter((T*)tmp_src_offset, (T*)seq_nram, tmp_cond, nq_nl_np); __bang_mul_scalar(tmp_src_offset, tmp_src_offset, channels * sizeof(T), valid_count); diff --git a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_small_channels_union1.mlu b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_small_channels_union1.mlu index 517c00a8c..9ff2a72e8 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_small_channels_union1.mlu +++ b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_small_channels_union1.mlu @@ -25,6 +25,7 @@ #include "core/logging.h" #include "kernels/kernel.h" #include "kernels/utils/common.h" +#include "kernels/utils/scatter_gather.h" __nram__ int8_t nram_buffer[MAX_NRAM_SIZE]; @@ -313,24 +314,25 @@ void __mlu_func__ loadValue( sizeof(int32_t), 4 * num_deal_grid); __sync_io_move_compute(); - __gather_async((void *)nram_grad_output_tl, (void *)data_value, - (unsigned int 
*)grad_temp3, deal_num_real * sizeof(float), - GDRAM2NRAM, deal_num_real * sizeof(float), num_deal_grid); - - __gather_async((void *)nram_grad_output_tr, (void *)data_value, - (unsigned int *)(grad_temp3 + num_deal_grid), - deal_num_real * sizeof(float), GDRAM2NRAM, - deal_num_real * sizeof(float), num_deal_grid); - - __gather_async((void *)nram_grad_output_bl, (void *)data_value, - (unsigned int *)(grad_temp3 + 2 * num_deal_grid), - deal_num_real * sizeof(float), GDRAM2NRAM, - deal_num_real * sizeof(float), num_deal_grid); - - __gather_async((void *)nram_grad_output_br, (void *)data_value, - (unsigned int *)(grad_temp3 + 3 * num_deal_grid), - deal_num_real * sizeof(float), GDRAM2NRAM, - deal_num_real * sizeof(float), num_deal_grid); + __mluop_gather((float *)nram_grad_output_tl, (float *)data_value, + (unsigned int *)grad_temp3, NULL, + deal_num_real * sizeof(float), GDRAM2NRAM, + deal_num_real * sizeof(float), num_deal_grid); + + __mluop_gather((float *)nram_grad_output_tr, (float *)data_value, + (unsigned int *)(grad_temp3 + num_deal_grid), NULL, + deal_num_real * sizeof(float), GDRAM2NRAM, + deal_num_real * sizeof(float), num_deal_grid); + + __mluop_gather((float *)nram_grad_output_bl, (float *)data_value, + (unsigned int *)(grad_temp3 + 2 * num_deal_grid), NULL, + deal_num_real * sizeof(float), GDRAM2NRAM, + deal_num_real * sizeof(float), num_deal_grid); + + __mluop_gather((float *)nram_grad_output_br, (float *)data_value, + (unsigned int *)(grad_temp3 + 3 * num_deal_grid), NULL, + deal_num_real * sizeof(float), GDRAM2NRAM, + deal_num_real * sizeof(float), num_deal_grid); __sync_io_move_compute(); #else diff --git a/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h b/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h index 0f4b4dd17..7ecb0b41f 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h +++ b/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h @@ -28,6 +28,7 @@ #include 
"kernels/kernel.h" #include "kernels/utils/common.h" +#include "kernels/utils/scatter_gather.h" #define BIT_COLLECT_PAD (8) #define BACKWARD_MAX_NQ_NL_NP (1024) @@ -377,19 +378,12 @@ __mlu_func__ void stageOneLoop( #endif #if (__BANG_ARCH__ >= 592) +template __mlu_func__ void gatherAsync(void* dst, void* src, unsigned int* offset, void* mask, int transfer_size, mluMemcpyDirection_t dir, int dst_stride, int transfer_num) { - __gather_async(dst, src, offset, mask, transfer_size, dir, dst_stride, - transfer_num); -} - -__mlu_func__ void gatherSync(void* dst, void* src, unsigned int* offset, - void* mask, int transfer_size, - mluMemcpyDirection_t dir, int dst_stride, - int transfer_num) { - __gather(dst, src, offset, mask, transfer_size, dir, dst_stride, - transfer_num); + __mluop_gather_async((T*)dst, (T*)src, offset, (uint8_t*)mask, + transfer_size, dir, dst_stride, transfer_num); } #endif diff --git a/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu b/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu index 2d29981e2..a4c61a979 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu +++ b/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu @@ -290,7 +290,7 @@ __mlu_func__ void getConditionCoordWeight( } __bang_mul_scalar(buf_nram, weight_attn_nram, (T)1, total_points); __bang_filter((float*)weight_attn_nram, (float*)buf_nram, - cond_point_valid_nram, total_points); + cond_point_valid_nram, total_points); __bang_float2int32((int32_t*)cond_point_polation_nram, cond_point_polation_nram, total_points * 4, 0); __bang_mul_scalar((int32_t*)cond_point_polation_nram, @@ -300,16 +300,16 @@ __mlu_func__ void getConditionCoordWeight( (int8_t*)cond_point_polation_nram, total_points * 4 * sizeof(float)); __bang_filter((float*)weight_polation_nram, (float*)weight_polation_nram_tmp, - cond_point_valid_nram, total_points); + cond_point_valid_nram, total_points); 
__bang_filter((float*)weight_polation_nram + total_points, - (float*)weight_polation_nram_tmp + total_points, - cond_point_valid_nram, total_points); + (float*)weight_polation_nram_tmp + total_points, + cond_point_valid_nram, total_points); __bang_filter((float*)weight_polation_nram + 2 * total_points, - (float*)weight_polation_nram_tmp + 2 * total_points, - cond_point_valid_nram, total_points); + (float*)weight_polation_nram_tmp + 2 * total_points, + cond_point_valid_nram, total_points); __bang_filter((float*)weight_polation_nram + 3 * total_points, - (float*)weight_polation_nram_tmp + 3 * total_points, - cond_point_valid_nram, total_points); + (float*)weight_polation_nram_tmp + 3 * total_points, + cond_point_valid_nram, total_points); //================================================================================================ // select cond_point_polation_nram if value_contain_infnan if (value_contain_infnan) { @@ -318,17 +318,17 @@ __mlu_func__ void getConditionCoordWeight( (int32_t*)cond_point_polation_nram, (int32_t)1, total_points * 4); __bang_filter((float*)cond_point_polation_nram, - (float*)cond_point_polation_nram_tmp, cond_point_valid_nram, - total_points); + (float*)cond_point_polation_nram_tmp, cond_point_valid_nram, + total_points); __bang_filter((float*)cond_point_polation_nram + total_points, - (float*)cond_point_polation_nram_tmp + total_points, - cond_point_valid_nram, total_points); + (float*)cond_point_polation_nram_tmp + total_points, + cond_point_valid_nram, total_points); __bang_filter((float*)cond_point_polation_nram + 2 * total_points, - (float*)cond_point_polation_nram_tmp + 2 * total_points, - cond_point_valid_nram, total_points); + (float*)cond_point_polation_nram_tmp + 2 * total_points, + cond_point_valid_nram, total_points); __bang_filter((float*)cond_point_polation_nram + 3 * total_points, - (float*)cond_point_polation_nram_tmp + 3 * total_points, - cond_point_valid_nram, total_points); + (float*)cond_point_polation_nram_tmp + 3 
* total_points, + cond_point_valid_nram, total_points); } //================================================================================================ // compute and select offset and stride @@ -348,11 +348,11 @@ __mlu_func__ void getConditionCoordWeight( (int32_t*)data_offset_nram_tr_tmp, (int32_t*)data_offset_nram_tl_tmp, total_points); __bang_filter((float*)data_offset_nram_tl, (float*)data_offset_nram_tl_tmp, - cond_point_valid_nram, total_points); + cond_point_valid_nram, total_points); __bang_filter((float*)data_offset_nram_bl, (float*)data_offset_nram_bl_tmp, - cond_point_valid_nram, total_points); + cond_point_valid_nram, total_points); __bang_filter((float*)data_offset_nram_tr, (float*)data_offset_nram_tr_tmp, - cond_point_valid_nram, total_points); + cond_point_valid_nram, total_points); } /* @@ -1068,12 +1068,12 @@ __mlu_func__ void forwardStageTwoLoop( __sync_io_move_compute(); if (i < loop_num) { - gatherAsync(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2, - cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM, - channels * sizeof(T), load_point_num); - gatherAsync(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2, - cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM, - channels * sizeof(T), load_point_num); + gatherAsync(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2, + cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM, + channels * sizeof(T), load_point_num); + gatherAsync(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2, + cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM, + channels * sizeof(T), load_point_num); } if (i > 0) { diff --git a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu index d226df82c..e8c545e04 100644 --- a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu +++ b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu @@ -38,14 +38,12 @@ __mlu_func__ void mluopDivScalar(T *dst, T 
*src, T value, uint32_t num) { __asm__ volatile( "div.scalar.nram.f16 [%[dst]], [%[src0]], " "%[src1], %[num];\n\t" ::[dst] "r"(dst), - [ src0 ] "r"(src), [ src1 ] "r"(value), - [ num ] "r"(num)); + [ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num)); } else { __asm__ volatile( "div.scalar.nram.f32 [%[dst]], [%[src0]], " "%[src1], %[num];\n\t" ::[dst] "r"(dst), - [ src0 ] "r"(src), [ src1 ] "r"(value), - [ num ] "r"(num)); + [ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num)); } } @@ -314,6 +312,7 @@ __mlu_func__ void handleChannels(const T *input, uint32_t deal_channels, } uint32_t hwc_num = deal_channels * vec_num; + // vec_num <= 1024 __gather(val, input, pos, deal_channels * sizeof(T), GDRAM2NRAM, deal_channels * sizeof(T), vec_num); if (deal_channels != 1) { @@ -521,8 +520,7 @@ __mlu_global__ void roiAlignRotatedForward( } } } - mluopDivScalar(output_channels, output_channels, (T)count, - cur_cache_c); + mluopDivScalar(output_channels, output_channels, (T)count, cur_cache_c); __memcpy(output_dram + bin_i * channels + c_cache_i, output_channels, cur_cache_c * sizeof(T), NRAM2GDRAM); } diff --git a/kernels/utils/scatter_gather.h b/kernels/utils/scatter_gather.h new file mode 100644 index 000000000..729fd9c9d --- /dev/null +++ b/kernels/utils/scatter_gather.h @@ -0,0 +1,90 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/ + +#include "kernels/kernel.h" + +#define SCATTER_GATHER_PARAMS \ + T *dst, const T *src, const uint32_t *offset, const uint8_t *mask, \ + const uint32_t transfer_size, const mluMemcpyDirection_t dir, \ + const uint32_t stride, const uint32_t data_num + +#if __BANG_ARCH__ > 592 +#define MLUOP_SCATTER_GATHER(func, is_scatter) \ + template \ + __mlu_func__ void __mluop_##func(SCATTER_GATHER_PARAMS) { \ + if (data_num <= UINT16_MAX) { \ + if (mask) { \ + __##func(dst, src, offset, (const void *)mask, transfer_size, dir, \ + stride, data_num); \ + } else { \ + __##func(dst, src, offset, transfer_size, dir, stride, data_num); \ + } \ + } else { \ + uint16_t data_num_new = PAD_DOWN(UINT16_MAX, 64); \ + uint32_t remain = data_num % data_num_new; \ + uint32_t repeat = data_num / data_num_new + uint32_t(remain > 0); \ + uint32_t dst_offset = is_scatter ? 0 : data_num_new; \ + uint32_t src_offset = is_scatter ? data_num_new : 0; \ + \ + for (uint32_t i = 0; i <= repeat; ++i) { \ + const uint16_t data_num_loop = i < repeat ? 
data_num_new : remain; \
+        if (mask) {                                                          \
+          __##func(dst + i * dst_offset, src + i * src_offset,               \
+                   mask + i * (data_num_new / 8), offset + i * data_num_new, \
+                   transfer_size, dir, stride, data_num_loop);               \
+        } else {                                                             \
+          __##func(dst + i * dst_offset, src + i * src_offset,               \
+                   offset + i * data_num_new, transfer_size, dir, stride,    \
+                   data_num_loop);                                           \
+        }                                                                    \
+      }                                                                      \
+    }                                                                        \
+  }
+
+/* __mluop_gather
+ * __mluop_gather_async
+ * __mluop_scatter
+ * __mluop_scatter_async
+ */
+MLUOP_SCATTER_GATHER(gather_async, false)
+MLUOP_SCATTER_GATHER(gather, false)
+MLUOP_SCATTER_GATHER(scatter_async, true)
+MLUOP_SCATTER_GATHER(scatter, true)
+
+#elif __BANG_ARCH__ == 592
+#define MLUOP_SCATTER_GATHER(func)                                            \
+  template <typename T>                                                       \
+  __mlu_func__ void __mluop_##func(SCATTER_GATHER_PARAMS) {                   \
+    if (mask) {                                                               \
+      __##func(dst, src, offset, mask, transfer_size, dir, stride, data_num); \
+    } else {                                                                  \
+      __##func(dst, src, offset, transfer_size, dir, stride, data_num);       \
+    }                                                                         \
+  }
+
+MLUOP_SCATTER_GATHER(gather_async)
+MLUOP_SCATTER_GATHER(gather)
+MLUOP_SCATTER_GATHER(scatter_async)
+MLUOP_SCATTER_GATHER(scatter)
+
+#endif  // __BANG_ARCH__ > 592
diff --git a/kernels/voxel_pooling_forward/voxel_pooling_forward_union1.mlu b/kernels/voxel_pooling_forward/voxel_pooling_forward_union1.mlu
index 90ecc8363..a7ff5fbb8 100644
--- a/kernels/voxel_pooling_forward/voxel_pooling_forward_union1.mlu
+++ b/kernels/voxel_pooling_forward/voxel_pooling_forward_union1.mlu
@@ -25,6 +25,7 @@
 #include "core/logging.h"
 #include "kernels/kernel.h"
 #include "kernels/utils/common.h"
+#include "kernels/utils/scatter_gather.h"
 
 __nram__ int8_t nram_buffer[MAX_NRAM_SIZE];
 
@@ -392,10 +393,11 @@ __mlu_func__ void MLUKernelVoxelPoolingStageTwoPerfKernel(
       __bang_ge_bitindex((float *)gather_mask,
                          (float *)nram_geom + point_idx_offset,
                          (float *)nram_geom_x, align_8_deal_num);
-      __gather((float *)gather_src, (float *)input_features,
-               (unsigned int *)gather_offset + point_idx_offset,
-               (void *)gather_mask, num_channels * 
sizeof(float), GDRAM2NRAM, - num_channels * sizeof(float), actual_load_num); + __mluop_gather((float *)gather_src, (float *)input_features, + (unsigned int *)gather_offset + point_idx_offset, + (uint8_t *)gather_mask, + num_channels * sizeof(float), GDRAM2NRAM, + num_channels * sizeof(float), actual_load_num); for (int index = 0; index < actual_load_num; index++) { int output_features_pt_offset = nram_geom[point_idx_offset + index]; if (output_features_pt_offset >= 0) { diff --git a/kernels/voxelization/voxelization_kernel.mlu b/kernels/voxelization/voxelization_kernel.mlu index 04f5580e7..9832ab4bf 100644 --- a/kernels/voxelization/voxelization_kernel.mlu +++ b/kernels/voxelization/voxelization_kernel.mlu @@ -28,6 +28,7 @@ #include "core/logging.h" #include "kernels/kernel.h" #include "kernels/utils/common.h" +#include "kernels/utils/scatter_gather.h" __nram__ int8_t nram_buffer[MAX_NRAM_SIZE]; @@ -547,9 +548,10 @@ __mlu_global__ void mluCalcPointsPerVoxel( // compute scatter src: voxel_idx __bang_add_scalar(nram_temp_mask, nram_base_offset, voxel_num_temp, deal_num); - __scatter(nram_scatter_output, nram_temp_mask, - (unsigned int *)nram_scatter_offset, nram_mask_bitindex, - sizeof(int32_t), NRAM2NRAM, sizeof(int32_t), reserve_voxels); + __mluop_scatter(nram_scatter_output, nram_temp_mask, + (unsigned int *)nram_scatter_offset, + (uint8_t *)nram_mask_bitindex, sizeof(int32_t), + NRAM2NRAM, sizeof(int32_t), reserve_voxels); __memcpy(num_points_per_voxel + voxel_num_temp, nram_scatter_mask, reserve_voxels * sizeof(int32_t), NRAM2GDRAM); voxel_num_temp += reserve_voxels; @@ -568,8 +570,9 @@ __mlu_global__ void mluCalcPointsPerVoxel( if (count > 0) { __bang_mul_scalar(nram_p2p_idx, nram_p2p_idx, sizeof(int32_t), count); // get repeated point real point_id - __gather(gather_output, coor_to_voxelidx, (unsigned int *)nram_p2p_idx, - sizeof(int32_t), GDRAM2NRAM, sizeof(int32_t), count); + __mluop_gather( + gather_output, coor_to_voxelidx, (unsigned int *)nram_p2p_idx, 
NULL, + sizeof(int32_t), GDRAM2NRAM, sizeof(int32_t), count); __bang_eq_scalar(nram_scatter_mask, gather_output, -1, count); __bang_not(nram_scatter_mask, nram_scatter_mask, count); __bang_gt_bitindex((float *)nram_mask_bitindex, @@ -582,9 +585,10 @@ __mlu_global__ void mluCalcPointsPerVoxel( gather_mask, deal_num); __bang_mul_scalar(nram_temp_mask, nram_temp_mask, sizeof(int32_t), deal_num); - __scatter(coor_to_voxelidx, gather_output, - (unsigned int *)nram_temp_mask, nram_mask_bitindex, - sizeof(int32_t), NRAM2GDRAM, sizeof(int32_t), count); + __mluop_scatter(coor_to_voxelidx, gather_output, + (unsigned int *)nram_temp_mask, + (uint8_t *)nram_mask_bitindex, sizeof(int32_t), + NRAM2GDRAM, sizeof(int32_t), count); // step4: compute num_points_per_voxel for (int32_t i = 0; i < count; i++) {