From 5445c9def0090738377d72d87cb310bdc949cd48 Mon Sep 17 00:00:00 2001 From: mahxn0 <1262384588@qq.com> Date: Tue, 3 Dec 2024 18:52:12 +0800 Subject: [PATCH] [Fix](mlu-ops): modify common func. (#1167) --- kernels/kernel.h | 2 +- .../ms_deform_attn_backward_fast_union1.mlu | 4 ++-- .../ms_deform_attn_forward/ms_deform_attn_utils.h | 2 +- .../ms_deform_attn_forward/msda_forward_fast_union1.mlu | 4 ++-- kernels/sparse_conv/get_indice_pairs/get_indice_pairs_utils.h | 2 +- kernels/utils/common.h | 2 +- .../pb_gtest/src/internal_kernel/fill_llc/fill_llc_device.mlu | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/kernels/kernel.h b/kernels/kernel.h index d1a6e96fb..9378f0839 100644 --- a/kernels/kernel.h +++ b/kernels/kernel.h @@ -31,7 +31,7 @@  * Macros for mluop kernels  ******************************************************************************/ // in future, can be "__BANG_ARCH__ == 592 || __BANG_ARCH__ == xxx || ...)" -#define ARCH_SUPPORT_LARGE_TENSOR (__BANG_ARCH__ == 592) +#define ARCH_SUPPORT_LARGE_TENSOR (__BANG_ARCH__ >= 592) #define MAX_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024) #define WRAM_LT_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 64) diff --git a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu index 21ee0b40d..b72087481 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu +++ b/kernels/ms_deform_attn/ms_deform_attn_backward/ms_deform_attn_backward_fast_union1.mlu @@ -26,7 +26,7 @@ #include "core/logging.h" -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) #define MAX_MEMCPY_SEGNUM (65536) #define NRAM_REMAIN_SIZE (48 * 1024) @@ -454,7 +454,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardFastKernel( const int32_t channels, const int32_t num_levels, const int32_t num_query, const int32_t num_points, float* grad_value, float* grad_sampling_loc, float* grad_attn_weight) { -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) using T = float; const int32_t num_keys = spatial_size; const int32_t input_stride_4 = diff --git a/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h b/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h index 122d2d35e..0f4b4dd17 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h +++ b/kernels/ms_deform_attn/ms_deform_attn_forward/ms_deform_attn_utils.h @@ -376,7 +376,7 @@ __mlu_func__ void stageOneLoop( } #endif -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) __mlu_func__ void gatherAsync(void* dst, void* src, unsigned int* offset, void* mask, int transfer_size, mluMemcpyDirection_t dir, int dst_stride, diff --git a/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu b/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu index b21af0a0e..2d29981e2 100644 --- a/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu +++ b/kernels/ms_deform_attn/ms_deform_attn_forward/msda_forward_fast_union1.mlu @@ -906,7 +906,7 @@ __mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl( } } -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) /* The shape of each tensor on nram: @@ -1260,7 +1260,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardFast( } #endif -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) MLUKernelMsDeformAttnForwardFastImpl( data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, diff --git a/kernels/sparse_conv/get_indice_pairs/get_indice_pairs_utils.h b/kernels/sparse_conv/get_indice_pairs/get_indice_pairs_utils.h index 52a135c7f..259b67e8b 100644 --- a/kernels/sparse_conv/get_indice_pairs/get_indice_pairs_utils.h +++ b/kernels/sparse_conv/get_indice_pairs/get_indice_pairs_utils.h @@ -76,7 +76,7 @@ func: generate stage index from start_index */ __mlu_func__ void stepIndex(int32_t *dst_nram, int32_t start_index, int32_t length) { -#if (__BANG_ARCH__ == 372 || __BANG_ARCH__ == 322 || __BANG_ARCH__ == 592) +#if __BANG_ARCH__ >= 372 int32_t align_num = 128; int32_t repeat = (int32_t)(logf(length / align_num) / logf(2)); int32_t remain = length / align_num - powf(2, repeat); diff --git a/kernels/utils/common.h b/kernels/utils/common.h index bceb8ccd4..c6bd1aead 100644 --- a/kernels/utils/common.h +++ b/kernels/utils/common.h @@ -419,7 +419,7 @@ __mlu_func__ void __mluop_store_str_3D(T *dst, T *src, int size, int seg_num_in, * dst_nram only support nram. * ****************************************************************************/ __mlu_func__ void __mluop_get_stage_indices_tfuse(int *dst_nram, int length) { -#if (__BANG_ARCH__ == 372 || __BANG_ARCH__ == 592) +#if __BANG_ARCH__ >= 372 int align_num = 128; int repeat = (int)(logf(length / align_num) / logf(2)); int remain = length / align_num - powf(2, repeat); diff --git a/test/mlu_op_gtest/pb_gtest/src/internal_kernel/fill_llc/fill_llc_device.mlu b/test/mlu_op_gtest/pb_gtest/src/internal_kernel/fill_llc/fill_llc_device.mlu index c5b9077e7..fbb93ddb1 100644 --- a/test/mlu_op_gtest/pb_gtest/src/internal_kernel/fill_llc/fill_llc_device.mlu +++ b/test/mlu_op_gtest/pb_gtest/src/internal_kernel/fill_llc/fill_llc_device.mlu @@ -24,7 +24,7 @@ #include "kernels/kernel.h" // MAX_NRAM_SIZE __mlu_global__ void flushLLC(void* input, int fill_bytes) { -#if (__BANG_ARCH__ == 592) +#if (__BANG_ARCH__ >= 592) if (coreId != 0) { return; }