Skip to content

Commit

Permalink
[Fix](mlu-ops): modify common func. (#1169)
Browse files Browse the repository at this point in the history
  • Loading branch information
mahxn0 authored Dec 3, 2024
1 parent ba823ed commit 2540652
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion kernels/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
 * Macros for mluop kernels
 ******************************************************************************/
// in future, can be "__BANG_ARCH__ == 592 || __BANG_ARCH__ == xxx || ...)"
#define ARCH_SUPPORT_LARGE_TENSOR (__BANG_ARCH__ == 592)
#define ARCH_SUPPORT_LARGE_TENSOR (__BANG_ARCH__ >= 592)

#define MAX_WRAM_SIZE (__MLU_WRAM_SIZE__ * 1024)
#define WRAM_LT_STRIDE (__MLU_WRAM_SIZE__ * 1024 / 64)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

#include "core/logging.h"

#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)

#define MAX_MEMCPY_SEGNUM (65536)
#define NRAM_REMAIN_SIZE (48 * 1024)
Expand Down Expand Up @@ -454,7 +454,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardFastKernel(
const int32_t channels, const int32_t num_levels, const int32_t num_query,
const int32_t num_points, float* grad_value, float* grad_sampling_loc,
float* grad_attn_weight) {
#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)
using T = float;
const int32_t num_keys = spatial_size;
const int32_t input_stride_4 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ __mlu_func__ void stageOneLoop(
}
#endif

#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)
__mlu_func__ void gatherAsync(void* dst, void* src, unsigned int* offset,
void* mask, int transfer_size,
mluMemcpyDirection_t dir, int dst_stride,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ __mlu_func__ void MLUKernelMsDeformAttnForwardFastImpl(
}
}

#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)

/*
The shape of each tensor on nram:
Expand Down Expand Up @@ -1260,7 +1260,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForwardFast(
}
#endif

#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)
MLUKernelMsDeformAttnForwardFastImpl<float>(
data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram,
data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func: generate stage index from start_index
*/
__mlu_func__ void stepIndex(int32_t *dst_nram, int32_t start_index,
int32_t length) {
#if (__BANG_ARCH__ == 372 || __BANG_ARCH__ == 322 || __BANG_ARCH__ == 592)
#if __BANG_ARCH__ >= 372
int32_t align_num = 128;
int32_t repeat = (int32_t)(logf(length / align_num) / logf(2));
int32_t remain = length / align_num - powf(2, repeat);
Expand Down
2 changes: 1 addition & 1 deletion kernels/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ __mlu_func__ void __mluop_store_str_3D(T *dst, T *src, int size, int seg_num_in,
* dst_nram only support nram.
* ****************************************************************************/
__mlu_func__ void __mluop_get_stage_indices_tfuse(int *dst_nram, int length) {
#if (__BANG_ARCH__ == 372 || __BANG_ARCH__ == 592)
#if __BANG_ARCH__ >= 372
int align_num = 128;
int repeat = (int)(logf(length / align_num) / logf(2));
int remain = length / align_num - powf(2, repeat);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include "kernels/kernel.h" // MAX_NRAM_SIZE

__mlu_global__ void flushLLC(void* input, int fill_bytes) {
#if (__BANG_ARCH__ == 592)
#if (__BANG_ARCH__ >= 592)
if (coreId != 0) {
return;
}
Expand Down

0 comments on commit 2540652

Please sign in to comment.