Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature](mlu-ops): adapt scatter,gather #1170

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions kernels/box_iou_rotated/box_iou_rotated_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define KERNELS_BOX_IOU_ROTATED_BOX_IOU_ROTATED_UTILS_H_

#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

#define FIILED_ONES (int)0xffffffff
#define HALF_FILLED_ONES (int16_t)0xffff
Expand Down Expand Up @@ -590,21 +591,22 @@ __mlu_func__ void convexHullGraham(
sizeof(T), actual_compute_box_num);

// get the ordered points according to the angle value
__gather(ordered_pts_x + (i + 1) * actual_compute_box_num, intersect_pts_x,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__gather(ordered_pts_y + (i + 1) * actual_compute_box_num, intersect_pts_y,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__gather(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts,
(unsigned int *)temp_offset, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__mluop_gather<T>(ordered_pts_x + (i + 1) * actual_compute_box_num,
intersect_pts_x, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_gather<T>(ordered_pts_y + (i + 1) * actual_compute_box_num,
intersect_pts_y, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_gather<T>(temp_long_1 + (i + 1) * actual_compute_box_num, valid_pts,
(unsigned int *)temp_offset, NULL, sizeof(T), NRAM2NRAM,
sizeof(T), actual_compute_box_num);

// assign a invalid value to the point which has been get ordered
__scatter(temp_long_2, temp1_ram, (unsigned int *)temp_offset, sizeof(T),
NRAM2NRAM, sizeof(T), actual_compute_box_num);
__scatter(valid_pts, temp2_ram, (unsigned int *)temp_offset, sizeof(T),
NRAM2NRAM, sizeof(T), actual_compute_box_num);
__mluop_scatter<T>(temp_long_2, temp1_ram, (unsigned int *)temp_offset,
NULL, sizeof(T), NRAM2NRAM, sizeof(T),
actual_compute_box_num);
__mluop_scatter<T>(valid_pts, temp2_ram, (unsigned int *)temp_offset, NULL,
sizeof(T), NRAM2NRAM, sizeof(T), actual_compute_box_num);
}
__bang_move(valid_pts, temp_long_1, total_points * sizeof(T));
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ __mlu_func__ void proposalBoxesDecode(
// gather offset (byte).
__bang_mul_scalar(anchors_index_nram, anchors_index_nram, sizeof(int32_t),
deal_num);
// deal_num <= 5163
__gather(temp_nram, anchors, (unsigned int *)anchors_index_nram,
sizeof(T) * 4, GDRAM2NRAM, sizeof(T) * 4, deal_num);
__bang_transpose(anchors_nram, temp_nram, deal_num, 4);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,14 @@ __mlu_func__ void backwardStageTwoLoop(
for (int j = 0; j < 5; j++) {
T* tmp_wp = weight_polation_nram + (j - 1) * nq_nl_np;
if (j < 4) {
gatherAsync(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + j * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync(v_ping, data_value_gdram,
(unsigned int*)offset_nram + j * nq_nl_np,
bit_cond_nram + j * bit_cond_stride, channels * sizeof(T),
GDRAM2NRAM, channels * sizeof(T), nq_nl_np);
gatherAsync<T>(v_ping, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + j * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync<T>(
v_ping, data_value_gdram, (unsigned int*)offset_nram + j * nq_nl_np,
bit_cond_nram + j * bit_cond_stride, channels * sizeof(T),
GDRAM2NRAM, channels * sizeof(T), nq_nl_np);
}

if (j == 0) {
Expand All @@ -249,10 +249,10 @@ __mlu_func__ void backwardStageTwoLoop(
NRAM2NRAM, channels * sizeof(T), num_levels_points - 1,
num_levels_points * channels * sizeof(T), deal_n - 1, 0,
num_levels_points - 1, channels * sizeof(T), deal_n - 1);
gatherAsync(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + 4 * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
gatherAsync<T>(buffer, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
bit_cond_reverse_nram + 4 * bit_cond_stride,
channels * sizeof(T), NRAM2NRAM, channels * sizeof(T),
nq_nl_np);
__bang_write_value(value_wp, nq_nl_np_c, (T)0); // clear value*wp
__sync_move();
// (n, nl, np, c) => (c, n, nl, np)
Expand Down Expand Up @@ -332,7 +332,7 @@ __mlu_func__ void backwardStageTwoLoop(
int32_t* dst_offset = (int32_t*)offset_zero_nram_stg2;
for (int i = 0; i < 4; i++) {
__bang_filter((T*)dst_offset + i * nq_nl_np,
(T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np);
(T*)offset_nram + i * nq_nl_np, cond_all_valid, nq_nl_np);
}
int32_t* src_offset = (int32_t*)inter_grad;
int32_t* stride_4_2 = dst_offset + 3 * nq_nl_np;
Expand Down Expand Up @@ -368,7 +368,7 @@ __mlu_func__ void backwardStageTwoLoop(
int32_t valid_count = __bang_sum(tmp_cond, nq_nl_np);
if (valid_count > 0) {
__bang_filter((T*)tmp_dst_offset, (T*)tmp_dst_offset, tmp_cond,
nq_nl_np);
nq_nl_np);
__bang_filter((T*)tmp_src_offset, (T*)seq_nram, tmp_cond, nq_nl_np);
__bang_mul_scalar(tmp_src_offset, tmp_src_offset, channels * sizeof(T),
valid_count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "core/logging.h"
#include "kernels/kernel.h"
#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

__nram__ int8_t nram_buffer[MAX_NRAM_SIZE];

Expand Down Expand Up @@ -313,24 +314,25 @@ void __mlu_func__ loadValue(
sizeof(int32_t), 4 * num_deal_grid);
__sync_io_move_compute();

__gather_async((void *)nram_grad_output_tl, (void *)data_value,
(unsigned int *)grad_temp3, deal_num_real * sizeof(float),
GDRAM2NRAM, deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_tr, (void *)data_value,
(unsigned int *)(grad_temp3 + num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_bl, (void *)data_value,
(unsigned int *)(grad_temp3 + 2 * num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__gather_async((void *)nram_grad_output_br, (void *)data_value,
(unsigned int *)(grad_temp3 + 3 * num_deal_grid),
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);
__mluop_gather<float>((float *)nram_grad_output_tl, (float *)data_value,
(unsigned int *)grad_temp3, NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_tr, (float *)data_value,
(unsigned int *)(grad_temp3 + num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_bl, (float *)data_value,
(unsigned int *)(grad_temp3 + 2 * num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);

__mluop_gather<float>((float *)nram_grad_output_br, (float *)data_value,
(unsigned int *)(grad_temp3 + 3 * num_deal_grid), NULL,
deal_num_real * sizeof(float), GDRAM2NRAM,
deal_num_real * sizeof(float), num_deal_grid);
__sync_io_move_compute();

#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include "kernels/kernel.h"
#include "kernels/utils/common.h"
#include "kernels/utils/scatter_gather.h"

#define BIT_COLLECT_PAD (8)
#define BACKWARD_MAX_NQ_NL_NP (1024)
Expand Down Expand Up @@ -377,19 +378,12 @@ __mlu_func__ void stageOneLoop(
#endif

#if (__BANG_ARCH__ >= 592)
template <typename T>
__mlu_func__ void gatherAsync(void* dst, void* src, unsigned int* offset,
void* mask, int transfer_size,
mluMemcpyDirection_t dir, int dst_stride,
int transfer_num) {
__gather_async(dst, src, offset, mask, transfer_size, dir, dst_stride,
transfer_num);
}

__mlu_func__ void gatherSync(void* dst, void* src, unsigned int* offset,
void* mask, int transfer_size,
mluMemcpyDirection_t dir, int dst_stride,
int transfer_num) {
__gather(dst, src, offset, mask, transfer_size, dir, dst_stride,
transfer_num);
__mluop_gather_async<T>((T*)dst, (T*)src, offset, (uint8_t*)mask,
transfer_size, dir, dst_stride, transfer_num);
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ __mlu_func__ void getConditionCoordWeight(
}
__bang_mul_scalar(buf_nram, weight_attn_nram, (T)1, total_points);
__bang_filter((float*)weight_attn_nram, (float*)buf_nram,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_float2int32((int32_t*)cond_point_polation_nram,
cond_point_polation_nram, total_points * 4, 0);
__bang_mul_scalar((int32_t*)cond_point_polation_nram,
Expand All @@ -300,16 +300,16 @@ __mlu_func__ void getConditionCoordWeight(
(int8_t*)cond_point_polation_nram,
total_points * 4 * sizeof(float));
__bang_filter((float*)weight_polation_nram, (float*)weight_polation_nram_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + total_points,
(float*)weight_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + 2 * total_points,
(float*)weight_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)weight_polation_nram + 3 * total_points,
(float*)weight_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
(float*)weight_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
//================================================================================================
// select cond_point_polation_nram if value_contain_infnan
if (value_contain_infnan) {
Expand All @@ -318,17 +318,17 @@ __mlu_func__ void getConditionCoordWeight(
(int32_t*)cond_point_polation_nram, (int32_t)1,
total_points * 4);
__bang_filter((float*)cond_point_polation_nram,
(float*)cond_point_polation_nram_tmp, cond_point_valid_nram,
total_points);
(float*)cond_point_polation_nram_tmp, cond_point_valid_nram,
total_points);
__bang_filter((float*)cond_point_polation_nram + total_points,
(float*)cond_point_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)cond_point_polation_nram + 2 * total_points,
(float*)cond_point_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + 2 * total_points,
cond_point_valid_nram, total_points);
__bang_filter((float*)cond_point_polation_nram + 3 * total_points,
(float*)cond_point_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
(float*)cond_point_polation_nram_tmp + 3 * total_points,
cond_point_valid_nram, total_points);
}
//================================================================================================
// compute and select offset and stride
Expand All @@ -348,11 +348,11 @@ __mlu_func__ void getConditionCoordWeight(
(int32_t*)data_offset_nram_tr_tmp,
(int32_t*)data_offset_nram_tl_tmp, total_points);
__bang_filter((float*)data_offset_nram_tl, (float*)data_offset_nram_tl_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)data_offset_nram_bl, (float*)data_offset_nram_bl_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
__bang_filter((float*)data_offset_nram_tr, (float*)data_offset_nram_tr_tmp,
cond_point_valid_nram, total_points);
cond_point_valid_nram, total_points);
}

/*
Expand Down Expand Up @@ -1068,12 +1068,12 @@ __mlu_func__ void forwardStageTwoLoop(
__sync_io_move_compute();

if (i < loop_num) {
gatherAsync(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2,
cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync<T>(v_load, zeros_nram, (unsigned int*)offset_zero_nram_stg2,
cond_nram_stg2_reverse, channels * sizeof(T), NRAM2NRAM,
channels * sizeof(T), load_point_num);
gatherAsync<T>(v_load, data_value_gdram, (unsigned int*)offset_nram_stg2,
cond_nram_stg2, channels * sizeof(T), GDRAM2NRAM,
channels * sizeof(T), load_point_num);
}

if (i > 0) {
Expand Down
10 changes: 4 additions & 6 deletions kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,12 @@ __mlu_func__ void mluopDivScalar(T *dst, T *src, T value, uint32_t num) {
__asm__ volatile(
"div.scalar.nram.f16 [%[dst]], [%[src0]], "
"%[src1], %[num];\n\t" ::[dst] "r"(dst),
[ src0 ] "r"(src), [ src1 ] "r"(value),
[ num ] "r"(num));
[ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num));
} else {
__asm__ volatile(
"div.scalar.nram.f32 [%[dst]], [%[src0]], "
"%[src1], %[num];\n\t" ::[dst] "r"(dst),
[ src0 ] "r"(src), [ src1 ] "r"(value),
[ num ] "r"(num));
[ src0 ] "r"(src), [ src1 ] "r"(value), [ num ] "r"(num));
}
}

Expand Down Expand Up @@ -314,6 +312,7 @@ __mlu_func__ void handleChannels(const T *input, uint32_t deal_channels,
}
uint32_t hwc_num = deal_channels * vec_num;

// vec_num <= 1024
__gather(val, input, pos, deal_channels * sizeof(T), GDRAM2NRAM,
deal_channels * sizeof(T), vec_num);
if (deal_channels != 1) {
Expand Down Expand Up @@ -521,8 +520,7 @@ __mlu_global__ void roiAlignRotatedForward(
}
}
}
mluopDivScalar(output_channels, output_channels, (T)count,
cur_cache_c);
mluopDivScalar(output_channels, output_channels, (T)count, cur_cache_c);
__memcpy(output_dram + bin_i * channels + c_cache_i, output_channels,
cur_cache_c * sizeof(T), NRAM2GDRAM);
}
Expand Down
Loading