Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature](bangc-ops): Fix bugs of ms_deform_attn_forward and align na… #852

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ __mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
mask_ram[i] = i % 2;
}
__sync();
// NOTE: when channel is 1, mask_ram may be overwritten, since we
// align size to CEIL_ALIGN(size, align_num)
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
NFU_ALIGN_SIZE, 0, (size / align_num + (size % align_num > 0)) - 2);
__sync();
#endif
}
Expand Down Expand Up @@ -72,7 +74,12 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
deal_g = offset_g > grid_total ?
0 : ((id + 1) * deal_g > grid_total ?
deal_g = grid_total - offset_g : deal_g);
}
if (deal_g == 0) {
return;
}
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num = 1;
Expand All @@ -85,7 +92,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
(8 * channel + 8) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
Expand Down Expand Up @@ -118,6 +125,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
char *valid_mask = index_bl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
Expand Down Expand Up @@ -204,6 +212,31 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
// generate valid mask, which marks whether the location is nan/inf or not
// condition coordx > -1 / coordy > -1
__bang_gt_scalar((float *)auxiliary_a, (float *)coord_x, -1.0, deal_num);
__bang_move((char *)valid_mask, (char *)auxiliary_a,
deal_num * sizeof(float));
__bang_gt_scalar((float *)auxiliary_a, (float *)coord_y, -1.0, deal_num);
__bang_add((float *)valid_mask, (float *)valid_mask,
(float *)auxiliary_a, deal_num);

// condition coordx <= spatial_x / coordy <= spatial_y
__bang_cycle_le((float *)mask_bl, (float *)coord_x,
(float *)spatial_x_float,
deal_num, num_levels * num_points);
__bang_cycle_le((float *)mask_br, (float *)coord_y,
(float *)spatial_y_float,
deal_num, num_levels * num_points);

__bang_add((float *)mask_bl, (float *)mask_bl,
(float *)mask_br, deal_num);
__bang_add((float *)valid_mask, (float *)valid_mask,
(float *)mask_bl, deal_num);
// when all four conditions are satisfied, the summed value is 4.
__bang_eq_scalar((float *)valid_mask, (float *)valid_mask, 4, deal_num);

// get floor value of coord
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
Expand Down Expand Up @@ -301,6 +334,20 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
// 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// if loc has nan/inf, fill the invalid values with 0.
// Note: although nan takes part in the computation, the comparison
// returns a normal value.
__bang_cycle_and((float *)mask_tl, (float *)mask_tl,
(float *)valid_mask, 4 * deal_num, deal_num);

// convert valid_mask to a bitwise mask: 1 -> 0xffffffff, 0 -> 0x00000000.
// First cast float32 to int32, then multiply by -1,
// whose hex representation is 0xffffffff.
__bang_float2int32_rd((int32_t *)valid_mask, (float *)valid_mask,
deal_num, 0);
__bang_mul_scalar((int32_t *)valid_mask, (int32_t *)valid_mask,
-1, deal_num);

// calc inner point num
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
Expand Down Expand Up @@ -335,6 +382,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
// if loc has nan/inf, fill all invalid positions with 0.
// Note that this operation works at the bit level.
__bang_cycle_band((char *)input_tl, (char *)input_tl, (char *)valid_mask,
4 * deal_num * sizeof(float), deal_num * sizeof(float));
__sync();
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
Expand Down