From 4c2a685712b84c18489e99d01d4b61166ec2fa3e Mon Sep 17 00:00:00 2001 From: Chengwei Dong Date: Wed, 11 Oct 2023 16:36:45 +0800 Subject: [PATCH] =?UTF-8?q?[Feature](bangc-ops):=20Fix=20bugs=20of=20ms=5F?= =?UTF-8?q?deform=5Fattn=5Fforward=20and=20align=20na=E2=80=A6=20(#852)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: dongchengwei --- .../msda_forward_small_channel_union1.mlu | 57 ++++++++++++++++++- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu b/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu index 3185728d4..e30899484 100644 --- a/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu +++ b/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu @@ -35,8 +35,10 @@ __mlu_func__ void genMask0101(float *mask_ram, int32_t size) { mask_ram[i] = i % 2; } __sync(); + // NOTE: when channel is 1, mask_ram may be overwritten, since we + // align size to CEIL_ALIGN(size, align_num) __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM, - NFU_ALIGN_SIZE, 0, size / align_num - 2); + NFU_ALIGN_SIZE, 0, (size / align_num + (size % align_num > 0)) - 2); __sync(); #endif } @@ -72,7 +74,12 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points); size_t id = taskId % skip_n; offset_g = id * deal_g; - deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1); + deal_g = offset_g > grid_total ? + 0 : ((id + 1) * deal_g > grid_total ? 
+ deal_g = grid_total - offset_g : deal_g); + } + if (deal_g == 0) { + return; } const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float); int32_t deal_num = 1; @@ -85,7 +92,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( int32_t mult; while (true) { deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) / - (8 * channel + 7) / sizeof(T); + (8 * channel + 8) / sizeof(T); deal_num = PAD_DOWN(deal_num, float_align); deal_num = PAD_DOWN(deal_num, num_levels * num_points); if (deal_num > 0) { @@ -118,6 +125,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( char *point_ram = mask_br + deal_num * sizeof(T); char *index_tl = point_ram + deal_num * sizeof(T); char *index_bl = index_tl + deal_num * sizeof(T); + char *valid_mask = index_bl + deal_num * sizeof(T); // nram space reuse char *grid_ram = weight_tl; char *mask_ram = weight_bl; @@ -204,6 +212,31 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( num_levels * num_points); __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5, deal_num); + // generate valid mask, which means the location is nan/inf or not + // condition coordx > -1 / coordy > -1 + __bang_gt_scalar((float *)auxiliary_a, (float *)coord_x, -1.0, deal_num); + __bang_move((char *)valid_mask, (char *)auxiliary_a, + deal_num * sizeof(float)); + __bang_gt_scalar((float *)auxiliary_a, (float *)coord_y, -1.0, deal_num); + __bang_add((float *)valid_mask, (float *)valid_mask, + (float *)auxiliary_a, deal_num); + + // condition coordx < spatial_x / coordy < spatial_y + __bang_cycle_le((float *)mask_bl, (float *)coord_x, + (float *)spatial_x_float, + deal_num, num_levels * num_points); + __bang_cycle_le((float *)mask_br, (float *)coord_y, + (float *)spatial_y_float, + deal_num, num_levels * num_points); + + __bang_add((float *)mask_bl, (float *)mask_bl, + (float *)mask_br, deal_num); + __bang_add((float *)valid_mask, (float *)valid_mask, + (float *)mask_bl, deal_num); + // all condition 
satisfied, value should be 4. + __bang_eq_scalar((float *)valid_mask, (float *)valid_mask, 4, deal_num); + + // get floor value of coord __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num); __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num); // calc index_tl @@ -301,6 +334,20 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( // 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < spatial_y __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, deal_num); + // if loc has nan/inf, fill invalid value with 0. + // Note, although nan joins the computation, the comparison returns + // normal value. + __bang_cycle_and((float *)mask_tl, (float *)mask_tl, + (float *)valid_mask, 4 * deal_num, deal_num); + + // switch valid_mask to bit-type mask. 1 to 0xffffffff, 0 to 0x00000000 + // first we cast float32 to int32. then multiply -1, + // whose hex is 0xffffffff + __bang_float2int32_rd((int32_t *)valid_mask, (float *)valid_mask, + deal_num, 0); + __bang_mul_scalar((int32_t *)valid_mask, (int32_t *)valid_mask, + -1, deal_num); + // calc inner point num __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0, deal_num); @@ -335,6 +382,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( (float *)weight_tr, deal_num); __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl, (float *)weight_tr, deal_num); + // if loc has nan/inf, fill all invalid positions with 0. + // Note that this operation handles in bit-scale. + __bang_cycle_band((char *)input_tl, (char *)input_tl, (char *)valid_mask, + 4 * deal_num * sizeof(float), deal_num * sizeof(float)); __sync(); // extend weight const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;