diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp index 29c53d61f..d040d9c11 100644 --- a/kernels/fft/c2c_fft/c2c_fft_host.cpp +++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp @@ -1079,7 +1079,9 @@ static void configureFFT2dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += batch * in_c_dtype_size * _n0 * _n1 * 2; - if (fft_plan->is_input_contiguous) { + if ((fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1])) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -1180,9 +1182,11 @@ static mluOpStatus_t makeFFT2dContiguousInput(mluOpHandle_t handle, int64_t dims[in_dim_num] = {fft_plan->batch, std::min(fft_plan->n[0], fft_plan->inembed[0]), std::min(fft_plan->n[1], fft_plan->inembed[1])}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + + int64_t strides[3]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1818,9 +1822,10 @@ static mluOpStatus_t makeFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1]}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[3]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -2053,7 +2058,39 @@ mluOpStatus_t execFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; } - status = execFFTc2c2d(handle, fft_plan, scale_factor, direction); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, 2, dims); + status = mluOpSetTensorDescriptorOnchipDataType(c_desc, + fft_plan->execution_dtype); + + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle + + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + + size_t workspace_size = 0; + CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_output_desc, + cnnl_output_desc, &workspace_size)); + void *workspace = nullptr; + if (workspace_size > 0) { + CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); + } + + CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_output_desc, + fft_plan->mlu_addrs.input, cnnl_output_desc, + fft_plan->mlu_addrs.output, workspace, + workspace_size)); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } else { + status = execFFTc2c2d(handle, fft_plan, scale_factor, direction); + } INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); diff --git a/kernels/fft/common/fft_common_kernels.mlu b/kernels/fft/common/fft_common_kernels.mlu index 8cca3a697..d9e48157d 100644 --- a/kernels/fft/common/fft_common_kernels.mlu +++ b/kernels/fft/common/fft_common_kernels.mlu @@ -109,7 +109,8 @@ __mlu_func__ void selectVec(float *src_addr, int32_t *offset_int_addr, __asm__ volatile( "gather.clean.nram.nram.nram.b32.u32 " "[%[dst]], [%[src]], [%[offset]], %[data_num];\n\t" ::[dst] "r"(dst_addr), - [src] "r"(src_addr), [offset] "r"(offset_int_addr), [data_num] "r"(deal_size)); + [ src ] "r"(src_addr), [ offset ] "r"(offset_int_addr), + [ data_num ] "r"(deal_size)); #else for (auto i = 0; i < deal_size; i++) { dst_addr[i] = src_addr[offset_int_addr[i]]; diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp index 4d4ab9ef1..9152e3bdf 100644 --- a/kernels/fft/fft.cpp +++ b/kernels/fft/fft.cpp @@ -1657,7 +1657,7 @@ mluOpAllocateC2C1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, fft_plan->is_batch_contiguous) ? 0 : buffer_size; - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; } size_t twiddles_size = in_c_dtype_size * nfft * 2; @@ -1701,7 +1701,7 @@ mluOpAllocateR2C1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + twiddles_size * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; // input_pad_addr } fft_plan->workspace_size = workspace_size; @@ -1740,13 +1740,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateC2C2D( DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ workspace_size = buffer_size * 2; - workspace_size += (fft_plan->is_input_contiguous) ? 0 : buffer_size; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1]) + ? 0 + : buffer_size; workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; } fft_plan->workspace_size = workspace_size; - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { fft_plan->workspace_size = workspace_size + buffer_size; // input_pad_addr } fft_plan->reservespace_size = reservespace_size; @@ -1783,7 +1787,7 @@ mluOpAllocateC2R1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + twiddles_size * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; // input_pad_addr } fft_plan->workspace_size = workspace_size; @@ -1791,7 +1795,53 @@ mluOpAllocateC2R1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, return MLUOP_STATUS_SUCCESS; } +mluOpStatus_t MLUOP_WIN_API mluOpAllocateIRFFT2D( + mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, + mluOpTensorDescriptor_t input_desc, mluOpTensorDescriptor_t output_desc, + const int _n0, const int _n1) { + const std::string make_plan_api = "[mluOpAllocateIRFFT2D]"; + size_t workspace_size = 0, reservespace_size = 0; + + mluOpDataType_t out_c_dtype = fft_plan->output_dtype; + mluOpDataType_t in_c_dtype = fft_plan->input_dtype; + size_t complex_dtype_size = + (mluOpDataTypeBytes(out_c_dtype) > mluOpDataTypeBytes(in_c_dtype)) + ? mluOpDataTypeBytes(out_c_dtype) + : mluOpDataTypeBytes(in_c_dtype); + + int batch = fft_plan->batch; + size_t buffer_size = batch * complex_dtype_size * _n0 * _n1; + + size_t twiddles_size = complex_dtype_size * _n0; + size_t twiddles_size_2d = complex_dtype_size * _n1; + + if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { + reservespace_size = complex_dtype_size * _n0 * _n0 * 2 + + complex_dtype_size * _n1 * _n1 * 2; /* DFT matrix */ + workspace_size = complex_dtype_size * _n1 * _n0 * batch * 6; + } else if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { + reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + + sizeof(int) * (FFT_MAXFACTORS) + twiddles_size * 2 + + DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + + DFT_TABLE_SIZE * 2; /* twiddles */ + workspace_size = buffer_size * 2; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1] / 2 + 1) + ? 0 + : buffer_size; + workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; + } + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { + workspace_size += buffer_size; + } + fft_plan->workspace_size = workspace_size; + fft_plan->reservespace_size = reservespace_size; + + return MLUOP_STATUS_SUCCESS; +} mluOpStatus_t MLUOP_WIN_API mluOpAllocateRFFT2D( mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, mluOpTensorDescriptor_t input_desc, mluOpTensorDescriptor_t output_desc, @@ -1822,12 +1872,16 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateRFFT2D( DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ workspace_size = buffer_size * 2; - workspace_size += (fft_plan->is_input_contiguous) ? 0 : buffer_size; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1]) + ? 0 + : buffer_size; workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; } - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { workspace_size += buffer_size; } fft_plan->workspace_size = workspace_size; @@ -1846,6 +1900,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C1D( const int rank, const int *n) { fft_plan->is_batch_contiguous = (fft_plan->idist == 1 && fft_plan->odist == 1 && + fft_plan->inembed[0] == fft_plan->n[0] && + fft_plan->onembed[0] == fft_plan->n[0] && fft_plan->istride == fft_plan->batch && fft_plan->ostride == fft_plan->batch) && (fft_plan->n[0] == fft_plan->inembed[0]); @@ -2221,7 +2277,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2R2D( fft_plan->fft_strategy = CNFFT_FUNC_TWO_LEVEL_STOCKHAM; } - mluOpAllocateRFFT2D(handle, fft_plan, input_desc, output_desc, n[0], n[1]); + mluOpAllocateIRFFT2D(handle, fft_plan, input_desc, output_desc, n[0], n[1]); if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { switch (fft_plan->fft_type) { @@ -2394,6 +2450,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany( fft_plan->inembed[i] = input_desc->dims[fft_plan->idim - rank + i]; fft_plan->onembed[i] = output_desc->dims[fft_plan->odim - rank + i]; } + for (auto i = 0; i < fft_plan->idim; i++) { + fft_plan->in_stride[i] = input_desc->strides[i]; + } + for (auto i = 0; i < fft_plan->odim; i++) { + fft_plan->out_stride[i] = output_desc->strides[i]; + } if (fft_plan->idim == rank + 1) { fft_plan->idist = input_desc->strides[0]; fft_plan->odist = output_desc->strides[0]; diff --git a/kernels/fft/fft.h b/kernels/fft/fft.h index aa7ac0ba6..5eec60789 100644 --- a/kernels/fft/fft.h +++ b/kernels/fft/fft.h @@ -193,24 +193,26 @@ struct mluOpFFTStruct { int inum; // element num of input tensor int istride; // distance between two successive input elements in the // innermost dimension - int idist; // distance between the first element of two consecutive signals - // in a batch of the input data - int odim; // the dimension size of output tensor + int in_stride[FFT_DIM_MAX + 1]; + int idist; // distance between the first element of two consecutive signals + // in a batch of the input data + int odim; // the dimension size of output tensor int onembed[FFT_DIM_MAX]; // Pointer of size rank that indicates the storage // dimensions of the output data in memory int onum; // element num of output tensor int ostride; // distance between two successive output elements in the // innermost dimension - int odist; // distance between the first element of two consecutive signals - // in a batch of the output data - int batch; // batch size for this transform - int L; // n = L * 2^m, L size for this transform - int m; // n = L * 2^m, m size for this transform - int s; // The size that can be put down on NRAM: L * 2^s, only used by - // Cooley-Tukey algorithm - int L_sub; // The size that can be put down on NRAM: L_sub * 2^m, only used - // by Stockham algorithm - int prime; // wether fft1d'size contains a prime number > 64 + int out_stride[FFT_DIM_MAX + 1]; + int odist; // distance between the first element of two consecutive signals + // in a batch of the output data + int batch; // batch size for this transform + int L; // n = L * 2^m, L size for this transform + int m; // n = L * 2^m, m size for this transform + int s; // The size that can be put down on NRAM: L * 2^s, only used by + // Cooley-Tukey algorithm + int L_sub; // The size that can be put down on NRAM: L_sub * 2^m, only used + // by Stockham algorithm + int prime; // wether fft1d'size contains a prime number > 64 bool is_input_contiguous; bool is_output_contiguous; bool is_batch_contiguous; diff --git a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h index 547174631..07d31dea1 100644 --- a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h +++ b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h @@ -305,379 +305,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpong( } } -// Compute the large butterfly for the subsequent stages of the FFT -template -__mlu_func__ void computeLargeButterflyOtherstages( - DT *output, DT *input, const int large_radix, const DT *cur_large_twiddles, - const DT *_twiddles, const DT *dft_matrix, const int large_section_num, - const int large_butterfly_num, const int large_in_stride, void *nram_buf, - const int *small_factors, const int nfft, const int dir, - const int last_stage) { - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - const int K_num = 64 / sizeof(DT); - int align_K = 0; - int radix, small_in_stride, small_stage_count, _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - - const int large_out_stride = large_butterfly_num; - int tw_offset; - - _small_stage_count = small_factors[0]; - tw_offset = small_factors[1]; - - const DT *small_twiddles = _twiddles + tw_offset * 2; - - const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix; - - int nram_buf_offset = 0; - DT *nram_in_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_in_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - FFT_CPX_T
nram_para_load_in_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_in_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_tw_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_tw_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_store_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_store_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_transpose_temp; - nram_transpose_temp = { - (DT *)nram_in_r, - (DT *)nram_in_r + large_radix * ((int)last_stage) + - large_radix * (1 - (int)last_stage) * max_para_ldst_num}; - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; - - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; - - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - DT *CPX_MUL_RR = nram_scratch; - DT *CPX_MUL_RI = &CPX_MUL_RR[large_radix * max_para_ldst_num]; - DT *CPX_MUL_IR = &CPX_MUL_RI[large_radix * max_para_ldst_num]; - DT *CPX_MUL_II = &CPX_MUL_IR[large_radix * max_para_ldst_num]; - - nram_buf_offset += large_radix * max_para_ldst_num * 4; - - int Fin_stride = 0, Fout_stride = 0; - int sec_count; - int repeat_num = - (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num; - for (sec_count = 0; sec_count < large_section_num; ++sec_count) { - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - if (repeat_id < repeat_num) { - int i = max_para_ldst_num * repeat_id; - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 == 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 == 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - int para_load_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - __memcpy_async(nram_para_load_in.r, input + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), - large_radix - 1); - __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), - large_radix - 1); - __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, - sizeof(DT) * para_load_num, SRAM2NRAM, - sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - __memcpy_async( - nram_para_load_tw.i, - cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i, - sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - } - - if (repeat_id >= 2) { - int i = max_para_ldst_num * (repeat_id - 2); - - int para_store_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong; - - if (last_stage) { - __memcpy_async(output + (Fout_stride + i) * 2, nram_para_store.r, - sizeof(DT) * 2 * para_store_num, NRAM2GDRAM, - large_out_stride * 2 * sizeof(DT), - sizeof(DT) * 2 * para_store_num, large_radix - 1); - } else { - __memcpy_async(output + Fout_stride + i, nram_para_store.r, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - __memcpy_async(output + Fout_stride + i + nfft, nram_para_store.i, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - } - } - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - int i = max_para_ldst_num * (repeat_id - 1); - - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 != 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 != 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong; - - int para_ldst_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - - __bang_sub(nram_para_load_in.r + para_ldst_num, CPX_MUL_RR, CPX_MUL_II, - para_ldst_num * (large_radix - 1)); - __bang_add(nram_para_load_in.i + para_ldst_num, CPX_MUL_RI, CPX_MUL_IR, - para_ldst_num * (large_radix - 1)); - - { - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_para_load_in.r, nram_para_load_in.i, - nram_scratch, nram_dftmtx, small_section_num * para_ldst_num, - small_section_num * para_ldst_num, 1, dir, radix); - - small_stage_count--; - if (small_stage_count == 0) { - if (last_stage) { - __memcpy_async(nram_transpose_temp.r, nram_out_r, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - - __memcpy_async(nram_transpose_temp.i, nram_out_i, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - __sync_move(); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } - - } else { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_ldst_num, radix, DT) - DT *nram_tw = _nram_tw; - value_mul = 8; - - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - if (sec_count == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - small_twiddles += small_butterfly_num * (radix - 1) * 2; - } - - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - - nram_tw += small_butterfly_num * (radix - 1) * 2; - } - - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; - - if (sec_count == 0 && repeat_id == 1) { - __memcpy_async( - nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - __sync_move(); - } - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - - if (last_stage) { - __memcpy_async(nram_transpose_temp.r, nram_out_r, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - - __memcpy_async(nram_transpose_temp.i, nram_out_i, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - __sync_move(); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } - } - } - } - } - - __sync(); - } - Fin_stride += large_butterfly_num; - Fout_stride += large_radix * large_butterfly_num; - } -} - -template -__mlu_func__ void computeLargeButterflyLaststage( - DT *output, DT *input, const int large_radix, const DT *cur_large_twiddles, - const DT *_twiddles, const DT *dft_matrix, const int large_section_num, - const int large_butterfly_num, const int large_in_stride, void *nram_buf, - const int *small_factors, const int nfft, const int dir) { - computeLargeButterflyOtherstages( - output, input, large_radix, cur_large_twiddles, _twiddles, dft_matrix, - large_section_num, large_butterfly_num, large_in_stride, nram_buf, - small_factors, nfft, dir, 1); -} - // Compute the large butterfly for the last stage of the FFT template __mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong( diff --git a/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu b/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu index 9ca8f2a50..99ef014cd 100644 --- a/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu +++ b/kernels/fft/fft_optm_device/fft_stockham_u1_device.mlu @@ -481,9 +481,9 @@ __mlu_func__ void store(DT* output, DT* y_in_r, DT* x_out1_r, if (part == 0) { int dst_one_point_offset = b * out_n * 2 + n; int src_one_point_offset = pow_2_m * L_sub; - __memcpy_async((DT *)output + dst_one_point_offset, - (DT *)out_nram + src_one_point_offset, sizeof(DT) * 2, - NRAM2GDRAM); + *(output + dst_one_point_offset) = *(out_nram + src_one_point_offset); + *(output + dst_one_point_offset + 1) = + *(out_nram + src_one_point_offset + 1); } } else if (fft_flag == IRFFT) { int dst_offset = part * L_sub; diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu index 55b3a37b6..808c91df1 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu @@ -35,26 +35,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyRow( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, const int batch, const int fft_flag, const int direction, const int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchip((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, batch, - fft_flag, direction); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchip((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag, direction); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchip((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag, direction); } // Kernel function for 1D FFT butterfly operations on columns. @@ -62,26 +46,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyColumn( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, const int batch, const int fft_flag, const int direction, const int dtype_size, const int nb) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchipColumn( - (float *)input, (float *)output, factors, (float *)twiddles, - (float *)twiddles_end, (float *)dft_matrix, (float *)buffer, batch, - fft_flag, direction, nb); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchipColumn((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, - batch, fft_flag, direction, nb); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchipColumn((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, + batch, fft_flag, direction, nb); } // Launches a kernel for 2D FFT butterfly operations on columns. diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu index 31b3c3908..a76078f62 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu @@ -33,26 +33,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyRowC2R( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, int batch, int fft_flag, int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchipC2R((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, - batch, fft_flag); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchipC2R((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchipC2R((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag); } mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyRowC2R( diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu index 5dd5f9e8d..3e36946b7 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu @@ -33,28 +33,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyR2C( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, int batch, int fft_flag, int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - MLULOG("MLUOP_DTYPE_COMPLEX_FLOAT: MLUOP_DTYPE_FLOAT\n"); - computeMutiStageR2COnchip((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, - batch, fft_flag); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - MLULOG("MLUOP_DTYPE_COMPLEX_HALF: MLUOP_DTYPE_HALF\n"); - computeMutiStageR2COnchip((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageR2COnchip((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag); } mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyR2C(cnrtDim3_t k_dim, diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp index b065028e1..90e095e5c 100644 --- a/kernels/fft/irfft/irfft_host.cpp +++ b/kernels/fft/irfft/irfft_host.cpp @@ -1627,7 +1627,10 @@ static void configureIRFFT2dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += batch * in_c_dtype_size * _n0 * _n1 * 2; - if (fft_plan->is_input_contiguous) { + if (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1] / 2 + 1 || + fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -1643,7 +1646,7 @@ static void configureIRFFT2dWorkspaceAddrs(mluOpHandle_t handle, } if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + fft_plan->n[1] / 2 + 1 > fft_plan->inembed[1]) { fft_plan->mlu_addrs.input_pad_addr = (uint8_t *)workspace + offset; } } @@ -1866,8 +1869,7 @@ static mluOpStatus_t makeIRFFT2dContiguousInput(mluOpHandle_t handle, auto status = MLUOP_STATUS_SUCCESS; if ((!fft_plan->is_input_contiguous || (fft_plan->inembed[0] > fft_plan->n[0] || - fft_plan->inembed[1] > fft_plan->n[1] / 2 + 1) && - !fft_plan->prime) && + fft_plan->inembed[1] > fft_plan->n[1] / 2 + 1)) && fft_plan->fft_strategy != CNFFT_FUNC_MANY_DIST1_2D) { VLOG(5) << "launch mluOpContiguous for irfft2d input"; mluOpTensorDescriptor_t input_desc; @@ -1878,9 +1880,10 @@ static mluOpStatus_t makeIRFFT2dContiguousInput(mluOpHandle_t handle, int64_t dims[in_dim_num] = { fft_plan->batch, std::min(fft_plan->inembed[0], fft_plan->n[0]), std::min(FFT_HALF(fft_plan->n[1]), fft_plan->inembed[1])}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + int64_t strides[3]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1916,9 +1919,10 @@ static mluOpStatus_t makeIRFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1]}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[3]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1987,63 +1991,87 @@ mluOpStatus_t execIRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; } - for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { - status = kernelIRFFT2dButterflyColumn(k_dim, k_type, handle->queue, - fft_plan, FFT_IFFT); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t input_desc; + status = mluOpCreateTensorDescriptor(&input_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int in_dim_num = 2; + int64_t dims[in_dim_num] = { + fft_plan->batch * fft_plan->n[0] * fft_plan->n[1], 1}; + int64_t strides[in_dim_num] = {2, 1}; + status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, + MLUOP_DTYPE_FLOAT, in_dim_num, + dims, strides); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = kernelIRFFT2dButterflyRow(k_dim, k_type, handle->queue, fft_plan, - FFT_IFFT); + + status = mluOpContiguous(handle, input_desc, fft_plan->mlu_addrs.input, + fft_plan->mlu_addrs.output); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpDestroyTensorDescriptor(input_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } else { + for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { + status = kernelIRFFT2dButterflyColumn(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = kernelIRFFT2dButterflyRow(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + } fft_plan->mlu_addrs.input = - (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + (void *)((uint64_t)(fft_plan->mlu_addrs.input) - + fft_plan->batch * idist); fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + (void *)((uint64_t)(fft_plan->mlu_addrs.output) - + fft_plan->batch * odist); } - fft_plan->mlu_addrs.input = (void *)((uint64_t)(fft_plan->mlu_addrs.input) - - fft_plan->batch * idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) - - fft_plan->batch * odist); + } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { + status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - if (scale_factor != 1.0) { - const float alpha[2] = {scale_factor, 0.0}; - const float beta[2] = {0.0, 0.0}; - mluOpTensorDescriptor_t c_desc = nullptr; - status = mluOpCreateTensorDescriptor(&c_desc); - const int out_dim_num = 3; - int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], - fft_plan->n[1]}; - status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 3, dims); - status = mluOpSetTensorDescriptorOnchipDataType( - c_desc, fft_plan->execution_dtype); + status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } - DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, - cnnl_handle); // convert to cnnl_handle + if (scale_factor != 1.0) { + const float alpha[2] = {scale_factor, 0.0}; + const float beta[2] = {0.0, 0.0}; + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, 3, dims); + status = mluOpSetTensorDescriptorOnchipDataType( + c_desc, fft_plan->execution_dtype); - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle - CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, - cnnl_output_desc, fft_plan->mlu_addrs.output, - &beta, cnnl_output_desc, - fft_plan->mlu_addrs.output)); - DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); - DESTROY_CNNL_HANDLE(cnnl_handle); - } - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, + cnnl_output_desc, fft_plan->mlu_addrs.output, + &beta, cnnl_output_desc, + fft_plan->mlu_addrs.output)); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { status = makeIRFFT2dContiguousOutput(handle, fft_plan, output, fft_plan->mlu_addrs.output); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { - status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); } return status; } diff --git a/kernels/fft/rfft/rfft_host.cpp b/kernels/fft/rfft/rfft_host.cpp index d0755e8be..d0649669f 100644 --- a/kernels/fft/rfft/rfft_host.cpp +++ b/kernels/fft/rfft/rfft_host.cpp @@ -453,7 +453,9 @@ static void configureRFFT2dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += batch * out_c_dtype_size * _n0 * _n1 * 2; - if (fft_plan->is_input_contiguous) { + if ((fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1])) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -1138,8 +1140,7 @@ static mluOpStatus_t makeRFFT2dContiguousInput(mluOpHandle_t handle, auto status = MLUOP_STATUS_SUCCESS; if ((!fft_plan->is_input_contiguous || (fft_plan->inembed[0] > fft_plan->n[0] || - fft_plan->inembed[1] > fft_plan->n[1]) && - !fft_plan->prime) && + fft_plan->inembed[1] > fft_plan->n[1])) && fft_plan->fft_strategy != CNFFT_FUNC_MANY_DIST1_2D) { VLOG(5) << "launch mluOpContiguous for rfft2d input"; mluOpTensorDescriptor_t input_desc; @@ -1153,9 +1154,10 @@ static mluOpStatus_t makeRFFT2dContiguousInput(mluOpHandle_t handle, : fft_plan->n[0], fft_plan->n[1] > fft_plan->inembed[1] ? fft_plan->inembed[1] : fft_plan->n[1]}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + int64_t strides[3]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1191,9 +1193,10 @@ static mluOpStatus_t makeRFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1] / 2 + 1}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[3]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1561,70 +1564,104 @@ mluOpStatus_t execRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, status = makeRFFT2dContiguousInput(handle, fft_plan, input); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { - status = padRFFT2dContiguousInput(handle, fft_plan); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t input_desc, padded_output_desc; + status = mluOpCreateTensorDescriptor(&input_desc); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; - } - - for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { - status = kernelRFFT2dButterflyRow(k_dim, k_type, handle->queue, fft_plan, - RFFT); - + status = mluOpCreateTensorDescriptor(&padded_output_desc); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = kernelRFFT2dButterflyColumn(k_dim, k_type, handle->queue, - fft_plan, FFT_IFFT); + const int in_dim_num = 2; + int64_t dims[in_dim_num] = {fft_plan->batch, + fft_plan->n[0] * fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, + MLUOP_DTYPE_FLOAT, in_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - fft_plan->mlu_addrs.input = - (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); - } - fft_plan->mlu_addrs.input = (void *)((uint64_t)(fft_plan->mlu_addrs.input) - - fft_plan->batch * idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) - - fft_plan->batch * odist); - - if (scale_factor != 1.0) { - const float alpha[2] = {scale_factor, 0.0}; - const float beta[2] = {0.0, 0.0}; - mluOpTensorDescriptor_t c_desc = nullptr; - status = mluOpCreateTensorDescriptor(&c_desc); - const int out_dim_num = 3; - int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], - fft_plan->n[1] / 2 + 1}; - status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 3, dims); - status = mluOpSetTensorDescriptorOnchipDataType( - c_desc, fft_plan->execution_dtype); + int64_t padded_dims[in_dim_num] = {fft_plan->batch, + fft_plan->n[0] * fft_plan->n[1] * 2}; + status = mluOpSetTensorDescriptor_v2( + padded_output_desc, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_FLOAT, in_dim_num, + padded_dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int pad_dim_num = 4; + int paddings[pad_dim_num] = {0, 0, 0, 1}; + uint64_t padding_value = 0x00000000; DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(padded_output_desc, + cnnl_padded_output_desc); + CALL_CNNL(cnnlPad(cnnl_handle, cnnl_input_desc, fft_plan->mlu_addrs.input, + paddings, &padding_value, cnnl_padded_output_desc, + fft_plan->mlu_addrs.output)); + + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_padded_output_desc); - CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, - cnnl_output_desc, fft_plan->mlu_addrs.output, - &beta, cnnl_output_desc, - fft_plan->mlu_addrs.output)); - DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); - } - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } else { + for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { + status = kernelRFFT2dButterflyRow(k_dim, k_type, handle->queue, + fft_plan, RFFT); - status = makeRFFT2dContiguousOutput(handle, fft_plan, output); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + status = kernelRFFT2dButterflyColumn(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + } + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) - + fft_plan->batch * idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) - + fft_plan->batch * odist); + } } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { status = computeFFT2dMatMulRowR2C(handle, fft_plan, scale_factor); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = computeFFT2dMatMulColumnR2C(handle, fft_plan, scale_factor); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); } + + if (scale_factor != 1.0) { + const float alpha[2] = {scale_factor, 0.0}; + const float beta[2] = {0.0, 0.0}; + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1] / 2 + 1}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, 3, dims); + status = mluOpSetTensorDescriptorOnchipDataType( + c_desc, fft_plan->execution_dtype); + + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle + + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + + CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, + cnnl_output_desc, fft_plan->mlu_addrs.output, + &beta, cnnl_output_desc, + fft_plan->mlu_addrs.output)); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { + status = makeRFFT2dContiguousOutput(handle, fft_plan, output); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } return status; }