diff --git a/bangc-ops/kernels/ms_deform_attn_forward/ms_deform_attn_forward.h b/bangc-ops/kernels/ms_deform_attn_forward/ms_deform_attn_forward.h index 942601345a..f447f6b7e6 100644 --- a/bangc-ops/kernels/ms_deform_attn_forward/ms_deform_attn_forward.h +++ b/bangc-ops/kernels/ms_deform_attn_forward/ms_deform_attn_forward.h @@ -27,7 +27,7 @@ #include "mlu_op.h" #define MIN(a, b) (((a) < (b)) ? (a) : (b)) -#define MS_DEFORM_ATTN_FORWARD_HEADVECTOR 1 +#define MS_DEFORM_ATTN_FORWARD_HEADVECTOR 0 template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault( diff --git a/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu b/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu index 3185728d44..e30899484a 100644 --- a/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu +++ b/bangc-ops/kernels/ms_deform_attn_forward/msda_forward_small_channel_union1.mlu @@ -35,8 +35,10 @@ __mlu_func__ void genMask0101(float *mask_ram, int32_t size) { mask_ram[i] = i % 2; } __sync(); + // NOTE: when channel is 1, mask_ram may be overwritten, since we + // align size to CEIL_ALIGN(size, align_num) __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM, - NFU_ALIGN_SIZE, 0, size / align_num - 2); + NFU_ALIGN_SIZE, 0, (size / align_num + (size % align_num > 0)) - 2); __sync(); #endif } @@ -72,7 +74,12 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points); size_t id = taskId % skip_n; offset_g = id * deal_g; - deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1); + deal_g = offset_g > grid_total ? + 0 : ((id + 1) * deal_g > grid_total ? 
+ deal_g = grid_total - offset_g : deal_g); + } + if (deal_g == 0) { + return; } const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float); int32_t deal_num = 1; @@ -85,7 +92,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( int32_t mult; while (true) { deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) / - (8 * channel + 7) / sizeof(T); + (8 * channel + 8) / sizeof(T); deal_num = PAD_DOWN(deal_num, float_align); deal_num = PAD_DOWN(deal_num, num_levels * num_points); if (deal_num > 0) { @@ -118,6 +125,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( char *point_ram = mask_br + deal_num * sizeof(T); char *index_tl = point_ram + deal_num * sizeof(T); char *index_bl = index_tl + deal_num * sizeof(T); + char *valid_mask = index_bl + deal_num * sizeof(T); // nram space reuse char *grid_ram = weight_tl; char *mask_ram = weight_bl; @@ -204,6 +212,31 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( num_levels * num_points); __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5, deal_num); + // generate valid mask, which means the location is nan/inf or not + // condition coordx > -1 / coordy > -1 + __bang_gt_scalar((float *)auxiliary_a, (float *)coord_x, -1.0, deal_num); + __bang_move((char *)valid_mask, (char *)auxiliary_a, + deal_num * sizeof(float)); + __bang_gt_scalar((float *)auxiliary_a, (float *)coord_y, -1.0, deal_num); + __bang_add((float *)valid_mask, (float *)valid_mask, + (float *)auxiliary_a, deal_num); + + // condition coordx < spatial_x / coordy < spatial_y + __bang_cycle_le((float *)mask_bl, (float *)coord_x, + (float *)spatial_x_float, + deal_num, num_levels * num_points); + __bang_cycle_le((float *)mask_br, (float *)coord_y, + (float *)spatial_y_float, + deal_num, num_levels * num_points); + + __bang_add((float *)mask_bl, (float *)mask_bl, + (float *)mask_br, deal_num); + __bang_add((float *)valid_mask, (float *)valid_mask, + (float *)mask_bl, deal_num); + // all condition 
satisfied, value should be 4. + __bang_eq_scalar((float *)valid_mask, (float *)valid_mask, 4, deal_num); + + // get floor value of coord __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num); __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num); // calc index_tl @@ -301,6 +334,20 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( // 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < spatial_y __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, deal_num); + // if loc has nan/inf, fill invalid value with 0. + // Note, although nan joins the computation, the comparison returns + // normal value. + __bang_cycle_and((float *)mask_tl, (float *)mask_tl, + (float *)valid_mask, 4 * deal_num, deal_num); + + // switch valid_mask to bit-type mask. 1 to 0xffffffff, 0 to 0x00000000 + // first we cast float32 to int32. then multiply -1, + // whose hex is 0xffffffff + __bang_float2int32_rd((int32_t *)valid_mask, (float *)valid_mask, + deal_num, 0); + __bang_mul_scalar((int32_t *)valid_mask, (int32_t *)valid_mask, + -1, deal_num); + // calc inner point num __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0, deal_num); @@ -335,6 +382,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( (float *)weight_tr, deal_num); __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl, (float *)weight_tr, deal_num); + // if loc has nan/inf, fill all invalid positions with 0. + // Note that this operation handles in bit-scale. + __bang_cycle_band((char *)input_tl, (char *)input_tl, (char *)valid_mask, + 4 * deal_num * sizeof(float), deal_num * sizeof(float)); __sync(); // extend weight const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT;