Skip to content

Commit

Permalink
[Feature](bangc-ops): Fix bugs of ms_deform_attn_forward and align na… (
Browse files Browse the repository at this point in the history
#852)

Co-authored-by: dongchengwei <[email protected]>
  • Loading branch information
Unireverse and dongchengwei authored Oct 11, 2023
1 parent 0d3a1be commit 4c2a685
Showing 1 changed file with 54 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ __mlu_func__ void genMask0101(float *mask_ram, int32_t size) {
mask_ram[i] = i % 2;
}
__sync();
// NOTE: when channel is 1, mask_ram may be overwritten, since we
// align size to CEIL_ALIGN(size, align_num)
__memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM,
NFU_ALIGN_SIZE, 0, size / align_num - 2);
NFU_ALIGN_SIZE, 0, (size / align_num + (size % align_num > 0)) - 2);
__sync();
#endif
}
Expand Down Expand Up @@ -72,7 +74,12 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points);
size_t id = taskId % skip_n;
offset_g = id * deal_g;
deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1);
deal_g = offset_g > grid_total ?
0 : ((id + 1) * deal_g > grid_total ?
deal_g = grid_total - offset_g : deal_g);
}
if (deal_g == 0) {
return;
}
const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float);
int32_t deal_num = 1;
Expand All @@ -85,7 +92,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
int32_t mult;
while (true) {
deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) /
(8 * channel + 7) / sizeof(T);
(8 * channel + 8) / sizeof(T);
deal_num = PAD_DOWN(deal_num, float_align);
deal_num = PAD_DOWN(deal_num, num_levels * num_points);
if (deal_num > 0) {
Expand Down Expand Up @@ -118,6 +125,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
char *point_ram = mask_br + deal_num * sizeof(T);
char *index_tl = point_ram + deal_num * sizeof(T);
char *index_bl = index_tl + deal_num * sizeof(T);
char *valid_mask = index_bl + deal_num * sizeof(T);
// nram space reuse
char *grid_ram = weight_tl;
char *mask_ram = weight_bl;
Expand Down Expand Up @@ -204,6 +212,31 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
num_levels * num_points);
__bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5,
deal_num);
// generate valid mask, which means the location is nan/inf or not
// condition coordx > -1 / coordy > -1
__bang_gt_scalar((float *)auxiliary_a, (float *)coord_x, -1.0, deal_num);
__bang_move((char *)valid_mask, (char *)auxiliary_a,
deal_num * sizeof(float));
__bang_gt_scalar((float *)auxiliary_a, (float *)coord_y, -1.0, deal_num);
__bang_add((float *)valid_mask, (float *)valid_mask,
(float *)auxiliary_a, deal_num);

// condition coordx < spatial_x / coordy < spatial_y
__bang_cycle_le((float *)mask_bl, (float *)coord_x,
(float *)spatial_x_float,
deal_num, num_levels * num_points);
__bang_cycle_le((float *)mask_br, (float *)coord_y,
(float *)spatial_y_float,
deal_num, num_levels * num_points);

__bang_add((float *)mask_bl, (float *)mask_bl,
(float *)mask_br, deal_num);
__bang_add((float *)valid_mask, (float *)valid_mask,
(float *)mask_bl, deal_num);
// all condition satisfied, value should be 4.
__bang_eq_scalar((float *)valid_mask, (float *)valid_mask, 4, deal_num);

// get floor value of coord
__bang_floor((float *)coord_x_low, (float *)coord_x, deal_num);
__bang_floor((float *)coord_y_low, (float *)coord_y, deal_num);
// calc index_tl
Expand Down Expand Up @@ -301,6 +334,20 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
// 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < spatial_y
__bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a,
deal_num);
// if loc has nan/inf, fill invalid value with 0.
// Note, althrough nan joins the compatution, the comparison returns
// normal value.
__bang_cycle_and((float *)mask_tl, (float *)mask_tl,
(float *)valid_mask, 4 * deal_num, deal_num);

// switch valid_mask to bit-type mask. 1 to 0xffffffff, 0 to 0x00000000
// first we cast float32 to int32. then multiply -1,
// whose hex is 0xffffffff
__bang_float2int32_rd((int32_t *)valid_mask, (float *)valid_mask,
deal_num, 0);
__bang_mul_scalar((int32_t *)valid_mask, (int32_t *)valid_mask,
-1, deal_num);

// calc inner point num
__bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0,
deal_num);
Expand Down Expand Up @@ -335,6 +382,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel(
(float *)weight_tr, deal_num);
__bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl,
(float *)weight_tr, deal_num);
// if loc has nan/inf, fill all invalid potision with 0.
// Note that this operation handles in bit-scale.
__bang_cycle_band((char *)input_tl, (char *)input_tl, (char *)valid_mask,
4 * deal_num * sizeof(float), deal_num * sizeof(float));
__sync();
// extend weight
const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;
Expand Down

0 comments on commit 4c2a685

Please sign in to comment.