Skip to content

Commit

Permalink
[Feature](mlu-ops): adapt scatter,gather (#1168)
Browse files Browse the repository at this point in the history
  • Loading branch information
PetrelYy authored Dec 4, 2024
1 parent 3a674cc commit 59eae84
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 98 deletions.
28 changes: 15 additions & 13 deletions kernels/box_iou_rotated/box_iou_rotated_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define KERNELS_BOX_IOU_ROTATED_BOX_IOU_ROTATED_UTILS_H_

#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

#define FIILED_ONES (int)0xffffffff
#define HALF_FILLED_ONES (int16_t)0xffff
Expand Down Expand Up @@ -590,21 +591,22 @@ __mlu_func__ void convexHullGraham(
sizeof(T), actual_compute_box_num);

// get the ordered points according to the angle value
__gather(ordered_pts_x + (i + 1) * actual_compute_box_num, intersect_pts_x,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__gather(ordered_pts_y + (i + 1) * actual_compute_box_num, intersect_pts_y,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__gather(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__mluop_gather<T>(ordered_pts_x + (i + 1) * actual_compute_box_num,
intersect_pts_x, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_gather<T>(ordered_pts_y + (i + 1) * actual_compute_box_num,
intersect_pts_y, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_gather<T>(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts,
(unsigned int *)temp_offset, NULL, sizeof(T), NRAM2NRAM,
sizeof(T), actual_compute_box_num);

// assign a invalid value to the point which has been get ordered
__scatter(temp_long_2, temp1_ram, (unsigned int *)temp_offset, sizeof(T),
NRAM2NRAM, sizeof(T), actual_compute_box_num);
__scatter(valid_pts, temp2_ram, (unsigned int *)temp_offset, sizeof(T),
NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_scatter<T>(temp_long_2, temp1_ram, (unsigned int *)temp_offset,
NULL, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__mluop_scatter<T>(valid_pts, temp2_ram, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
}
__bang_move(valid_pts, temp_long_1, total_points * sizeof(T));
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ __mlu_func__ void proposalBoxesDecode(
// gather offset (byte).
__bang_mul_scalar(anchors_index_nram, anchors_index_nram, sizeof(int32_t),
deal_num);
// deal_num <= 5163
__gather(temp_nram, anchors, (unsigned int *)anchors_index_nram,
sizeof(T) * 4, GDRAM2NRAM, sizeof(T) * 4, deal_num);
__bang_transpose(anchors_nram, temp_nram, deal_num, 4);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,14 @@ __mlu_func__ void backwardStageTwoLoop(
for (int j = 0; j < 5; j++) {
T* tmp_wp = weight_polation_nram + (j - 1) * nq_nl_np;
if (j < 4) {
gatherAsync(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + j * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync(v_ping, data_value_gdram,
(unsigned int*)offset_nram + j * nq_nl_np,
bit_cond_nram + j * bit_cond_stride, channels * sizeof(T),
GDRAM2NRAM, channels * sizeof(T), nq_nl_np);
gatherAsync<T>(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + j * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync<T>(
v_ping, data_value_gdram, (unsigned int*)offset_nram + j * nq_nl_np,
bit_cond_nram + j * bit_cond_stride, channels * sizeof(T),
GDRAM2NRAM, channels * sizeof(T), nq_nl_np);
}

if (j == 0) {
Expand All @@ -249,10 +249,10 @@ __mlu_func__ void backwardStageTwoLoop(
NRAM2NRAM, channels * sizeof(T), num_levels_points - 1,
num_levels_points * channels * sizeof(T), deal_n - 1, 0,
num_levels_points - 1, channels * sizeof(T), deal_n - 1);
gatherAsync(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + 4 * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync<T>(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + 4 * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
__bang_write_value(value_wp, nq_nl_np_c, (T)0); // clear value*wp
__sync_move();
// (n, nl, np, c) => (c, n, nl, np)
Expand Down Expand Up @@ -332,7 +332,7 @@ __mlu_func__ void backwardStageTwoLoop(
int32_t* dst_offset = (int32_t*)offset_zero_nram_stg2;
for (int i = 0; i < 4; i++) {
__bang_filter((T*)dst_offset + i * nq_nl_np,
(T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np);
(T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np);
}
int32_t* src_offset = (int32_t*)inter_grad;
int32_t* stride_4_2 = dst_offset + 3 * nq_nl_np;
Expand Down Expand Up @@ -368,7 +368,7 @@ __mlu_func__ void backwardStageTwoLoop(
int32_t valid_count = __bang_sum(tmp_cond, nq_nl_np);
if (valid_count > 0) {
__bang_filter((T*)tmp_dst_offset, (T*)tmp_dst_offset, tmp_cond,
nq_nl_np);
nq_nl_np);
__bang_filter((T*)tmp_src_offset, (T*)seq_nram, tmp_cond, nq_nl_np);
__bang_mul_scalar(tmp_src_offset, tmp_src_offset, channels * sizeof(T),
valid_count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "core/logging.h"
#include "kernels/kernel.h"
#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

__nram__ int8_t nram_buffer[MAX_NRAM_SIZE];

Expand Down Expand Up @@ -313,24 +314,25 @@ void __mlu_func__ loadValue(
sizeof(int32_t), 4 * num_deal_grid);
__sync_io_move_compute();

__gather_async((void *)nram_grad_output_tl, (void *)data_value,
(unsigned int *)grad_temp3, deal_num_real * sizeof(float),
GDRAM2NRAM, deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_tr, (void *)data_value,
(unsigned int *)(grad_temp3 + num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_bl, (void *)data_value,
(unsigned int *)(grad_temp3 + 2 * num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_br, (void *)data_value,
(unsigned int *)(grad_temp3 + 3 * num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);
__mluop_gather<float>((float *)nram_grad_output_tl, (float *)data_value,
(unsigned int *)grad_temp3, NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_tr, (float *)data_value,
(unsigned int *)(grad_temp3 + num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_bl, (float *)data_value,
(unsigned int *)(grad_temp3 + 2 * num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_br, (float *)data_value,
(unsigned int *)(grad_temp3 + 3 * num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);
__sync_io_move_compute();

#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "kernels/kernel.h"
#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

#define BIT_COLLECT_PAD (8)
#define BACKWARD_MAX_NQ_NL_NP (1024)
Expand Down Expand Up @@ -377,19 +378,12 @@ __mlu_func__ void stageOneLoop(
#endif

#if (__BANG_ARCH__ >= 592)
template <typename T>
__mlu_func__ void gatherAsync(void* dst, void* src, unsigned int* offset,
void* mask, int transfer_size,
mluMemcpyDirection_t dir, int dst_stride,
int transfer_num) {
__gather_async(dst, src, offset, mask, transfer_size, dir, dst_stride,
transfer_num);
}

__mlu_func__ void gatherSync(void* dst, void* src, unsigned int* offset,
void* mask, int transfer_size,
mluMemcpyDirection_t dir, int dst_stride,
int transfer_num) {
__gather(dst, src, offset, mask, transfer_size, dir, dst_stride,
transfer_num);
__mluop_gather_async<T>((T*)dst, (T*)src, offset, (uint8_t*)mask,
transfer_size, dir, dst_stride, transfer_num);
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ __mlu_func__ void getConditionCoordWeight(
}
__bang_mul_scalar(buf_nram, weight_attn_nram, (T)1, total_points);
__bang_filter((float*)weight_attn_nram, (float*)buf_nram,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_float2int32((int32_t*)cond_point_polation_nram,
cond_point_polation_nram, total_points * 4, 0);
__bang_mul_scalar((int32_t*)cond_point_polation_nram,
Expand All @@ -300,16 +300,16 @@ __mlu_func__ void getConditionCoordWeight(
(int8_t*)cond_point_polation_nram,
total_points * 4 * sizeof(float));
__bang_filter((float*)weight_polation_nram, (float*)weight_polation_nram_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + total_points,
(float*)weight_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + 2 * total_points,
(float*)weight_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + 3 * total_points,
(float*)weight_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
//================================================================================================
// select cond_point_polation_nram if value_contain_infnan
if (value_contain_infnan) {
Expand All @@ -318,17 +318,17 @@ __mlu_func__ void getConditionCoordWeight(
(int32_t*)cond_point_polation_nram, (int32_t)1,
total_points * 4);
__bang_filter((float*)cond_point_polation_nram,
(float*)cond_point_polation_nram_tmp, cond_point_valid_nram,
total_points);
(float*)cond_point_polation_nram_tmp, cond_point_valid_nram,
total_points);
__bang_filter((float*)cond_point_polation_nram + total_points,
(float*)cond_point_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)cond_point_polation_nram + 2 * total_points,
(float*)cond_point_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)cond_point_polation_nram + 3 * total_points,
(float*)cond_point_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
}
//================================================================================================
// compute and select offset and stride
Expand All @@ -348,11 +348,11 @@ __mlu_func__ void getConditionCoordWeight(
(int32_t*)data_offset_nram_tr_tmp,
(int32_t*)data_offset_nram_tl_tmp, total_points);
__bang_filter((float*)data_offset_nram_tl, (float*)data_offset_nram_tl_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)data_offset_nram_bl, (float*)data_offset_nram_bl_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)data_offset_nram_tr, (float*)data_offset_nram_tr_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
}

/*
Expand Down Expand Up @@ -1068,12 +1068,12 @@ __mlu_func__ void forwardStageTwoLoop(
__sync_io_move_compute();

if (i < loop_num) {
gatherAsync(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2,
cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync<T>(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync<T>(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2,
cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM,
channels * sizeof(T), load_point_num);
}

if (i > 0) {
Expand Down
10 changes: 4 additions & 6 deletions kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ __mlu_func__ void mluopDivScalar(T *dst, T *src, T value, uint32_t num) {
__asm__ volatile(
"div.scalar.nram.f16 [%[dst]], [%[src0]], "
"%[src1], %[num];\n\t" ::[dst] "r"(dst),
[ src0 ] "r"(src), [ src1 ] "r"(value),
[ num ] "r"(num));
[ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num));
} else {
__asm__ volatile(
"div.scalar.nram.f32 [%[dst]], [%[src0]], "
"%[src1], %[num];\n\t" ::[dst] "r"(dst),
[ src0 ] "r"(src), [ src1 ] "r"(value),
[ num ] "r"(num));
[ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num));
}
}

Expand Down Expand Up @@ -314,6 +312,7 @@ __mlu_func__ void handleChannels(const T *input, uint32_t deal_channels,
}
uint32_t hwc_num = deal_channels * vec_num;

// vec_num <= 1024
__gather(val, input, pos, deal_channels * sizeof(T), GDRAM2NRAM,
deal_channels * sizeof(T), vec_num);
if (deal_channels != 1) {
Expand Down Expand Up @@ -521,8 +520,7 @@ __mlu_global__ void roiAlignRotatedForward(
}
}
}
mluopDivScalar(output_channels, output_channels, (T)count,
cur_cache_c);
mluopDivScalar(output_channels, output_channels, (T)count, cur_cache_c);
__memcpy(output_dram + bin_i * channels + c_cache_i, output_channels,
cur_cache_c * sizeof(T), NRAM2GDRAM);
}
Expand Down
Loading

0 comments on commit 59eae84

Please sign in to comment.