Skip to content

Commit

Permalink
[Fix](mluOpSetFFTReserveArea): fix fft bug
Browse files Browse the repository at this point in the history
  • Loading branch information
niyuming committed Oct 21, 2024
1 parent 95ce7c8 commit 121b4c7
Show file tree
Hide file tree
Showing 3 changed files with 499 additions and 97 deletions.
22 changes: 14 additions & 8 deletions kernels/fft/common/fft_common_kernels.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,8 @@ __mlu_func__ void genSinCosVec(float *src_addr, float *sin_addr,
*/
__mlu_func__ void genSelectOffsetVec(float *offset_addr,
                                     int32_t *offset_int_addr, int deal_size) {
  // Turns deal_size float-valued offsets into int32 offsets scaled by
  // sizeof(float).
  //   offset_addr:     in/out, float offsets; left holding the scaled values.
  //   offset_int_addr: out, the scaled offsets converted to int32.
  //   deal_size:       number of elements processed.
  // NOTE(review): the scaling presumably yields byte offsets for the gather
  // in selectVec -- confirm against the gather instruction documentation.
  //
  // The former element-by-element cast loop was dropped: every value it
  // stored into offset_int_addr was overwritten by the vector conversion
  // below, so it only cost scalar cycles.
  __bang_mul_scalar(offset_addr, offset_addr, (float)sizeof(float), deal_size);
  // NOTE(review): the trailing 0 is presumably the fixed-point position
  // (no extra scaling) -- confirm against the BANG C intrinsic reference.
  __bang_float2int32((int32_t *)offset_int_addr, offset_addr, deal_size, 0);
}

/*
Expand All @@ -106,9 +105,16 @@ __mlu_func__ void genSelectOffsetVec(float *offset_addr,
*/
__mlu_func__ void selectVec(float *src_addr, int32_t *offset_int_addr,
                            float *dst_addr, int deal_size) {
  // Gathers deal_size floats from src_addr into dst_addr, one per entry of
  // offset_int_addr.
  // Offsets are taken to be in bytes, matching genSelectOffsetVec, which
  // scales its output by sizeof(float).
  // NOTE(review): byte-offset semantics of the gather instruction are
  // assumed from that scaling -- confirm against the instruction reference.
#if __BANG_ARCH__ >= 372 && __BANG_ARCH__ != 520
  // Hardware gather path.
  __asm__ volatile(
      "gather.clean.nram.nram.nram.b32.u32 "
      "[%[dst]], [%[src]], [%[offset]], %[data_num];\n\t" ::[dst] "r"(dst_addr),
      [src] "r"(src_addr), [offset] "r"(offset_int_addr),
      [data_num] "r"(deal_size));
#else
  // Scalar fallback. The offsets are byte offsets, so convert them back to
  // element indices before indexing the float array; using them directly
  // would read 4x too far into src_addr.
  const int32_t elem_bytes = (int32_t)sizeof(float);
  for (int i = 0; i < deal_size; i++) {
    dst_addr[i] = src_addr[offset_int_addr[i] / elem_bytes];
  }
#endif
}

/*
Expand Down Expand Up @@ -143,7 +149,7 @@ __mlu_func__ void generateRFFTHalfDFTMatrixImpl(int n, void *output) {
float *row_addr = temp_addr;

// generate 0 to n indices
__mluop_get_indices(inc_addr, (float)0.0, deal_size);
__mlu_op_gen_stage_index(inc_addr, deal_size, 0.0f, 1.0f);

// generate sin and cos vectors
const float scale = -2.0 * M_PI / n;
Expand Down Expand Up @@ -227,7 +233,7 @@ __mlu_func__ void generateRFFTFullDFTMatrixImpl(int row, int n, void *output) {
float *row_addr = temp_addr;

// generate 0 to n indices
__mluop_get_indices(inc_addr, (float)0.0, deal_size);
__mlu_op_gen_stage_index(inc_addr, deal_size, 0.0f, 1.0f);

// generate sin and cos vectors
const float scale = -2.0 * M_PI / n;
Expand Down Expand Up @@ -316,7 +322,7 @@ __mlu_func__ void generateIRFFTHalfDFTMatrixImpl(int n, void *output) {
float *row_addr = temp_addr;

// generate 0 to n indices
__mluop_get_indices(inc_addr, (float)0.0, deal_size);
__mlu_op_gen_stage_index(inc_addr, deal_size, 0.0f, 1.0f);

// generate sin and cos coefficient vectors
__bang_write_value((float *)cos_coeff_addr, deal_size, (float)2.0);
Expand Down Expand Up @@ -411,7 +417,7 @@ __mlu_func__ void generateIRFFTFullDFTMatrixImpl(int n, void *output) {
float *row_addr = temp_addr;

// generate 0 to n indices
__mluop_get_indices(inc_addr, (float)0.0, deal_size);
__mlu_op_gen_stage_index(inc_addr, deal_size, 0.0f, 1.0f);

// generate sin and cos vectors
const float scale = 2.0 * M_PI / n;
Expand Down Expand Up @@ -507,7 +513,7 @@ __mlu_func__ void generateC2CFFTDFTMatrixImpl(int n, void *output) {
float *row_addr = temp_addr;

// generate 0 to n indices
__mluop_get_indices(inc_addr, (float)0.0, deal_size);
__mlu_op_gen_stage_index(inc_addr, deal_size, 0.0f, 1.0f);

// generate sin and cos vectors
const float forward_scale = -2.0 * M_PI / n;
Expand Down
Loading

0 comments on commit 121b4c7

Please sign in to comment.