From ecdecc8f945247e3ffc6772375a6a76d4b70bdab Mon Sep 17 00:00:00 2001 From: niyuming Date: Fri, 29 Nov 2024 10:23:04 +0800 Subject: [PATCH] [Fix](mluOpExecFFT): fix core dump, scale factor and one point compute error --- kernels/fft/c2c_fft/c2c_fft_host.cpp | 151 ++++--- kernels/fft/common/fft_basic_ops.cpp | 16 +- kernels/fft/common/fft_common_kernels.mlu | 3 +- kernels/fft/fft.cpp | 108 +++-- kernels/fft/fft.h | 30 +- .../fft_optm_device/fft_c2c_stockham_nram.h | 373 ------------------ .../fft_two-level_network_c2c_device.mlu | 48 +-- .../fft_two-level_network_c2r_device.mlu | 24 +- .../fft_two-level_network_r2c_device.mlu | 26 +- kernels/fft/irfft/irfft_host.cpp | 188 +++++---- kernels/fft/rfft/rfft_host.cpp | 205 +++++----- .../tensor_stride_process_host.cpp | 3 +- 12 files changed, 438 insertions(+), 737 deletions(-) diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp index 29c53d61f..5b8bcf98f 100644 --- a/kernels/fft/c2c_fft/c2c_fft_host.cpp +++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp @@ -648,13 +648,13 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, } } - int _n0 = fft_plan->n[0]; - int _n1 = fft_plan->n[1]; + int n0_ori = fft_plan->n[0]; + int n1_ori = fft_plan->n[1]; if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { size_t factors_size = FFT_MAXFACTORS * sizeof(int); // bytes - size_t twiddles_size = CPX_TYPE_SIZE * _n1; - size_t twiddles_size_2d = CPX_TYPE_SIZE * _n0; + size_t twiddles_size = CPX_TYPE_SIZE * n1_ori; + size_t twiddles_size_2d = CPX_TYPE_SIZE * n0_ori; size_t reservespace_offset = 0; @@ -794,19 +794,19 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, size_t reservespace_offset = 0; fft_plan->mlu_addrs.dft_matrix = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1; + reservespace_offset += CPX_TYPE_SIZE * (n1_ori / 2 + 1) * n1_ori; fft_plan->mlu_addrs.dft_matrix_2d = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n0 * _n0; + reservespace_offset += CPX_TYPE_SIZE * n0_ori * n0_ori; CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix, fft_plan->dft_matrix, - CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1, + CPX_TYPE_SIZE * (n1_ori / 2 + 1) * n1_ori, handle->queue, cnrtMemcpyHostToDev)); CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d, - CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n0_ori * n0_ori, handle->queue, cnrtMemcpyHostToDev)); } break; case CNFFT_COMPLEX_HALF2COMPLEX_HALF: case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: { @@ -814,34 +814,34 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, size_t reservespace_offset = 0; fft_plan->mlu_addrs.dft_matrix = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n1 * _n1; + reservespace_offset += CPX_TYPE_SIZE * n1_ori * n1_ori; fft_plan->mlu_addrs.dft_matrix_2d = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n0 * _n0; + reservespace_offset += CPX_TYPE_SIZE * n0_ori * n0_ori; fft_plan->mlu_addrs.idft_matrix = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n1 * _n1; + reservespace_offset += CPX_TYPE_SIZE * n1_ori * n1_ori; fft_plan->mlu_addrs.idft_matrix_2d = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n0 * _n0; + reservespace_offset += CPX_TYPE_SIZE * n0_ori * n0_ori; CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.dft_matrix, fft_plan->dft_matrix, - CPX_TYPE_SIZE * _n1 * _n1, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n1_ori * n1_ori, handle->queue, cnrtMemcpyHostToDev)); CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d, - CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n0_ori * n0_ori, handle->queue, cnrtMemcpyHostToDev)); CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.idft_matrix, fft_plan->idft_matrix, - CPX_TYPE_SIZE * _n1 * _n1, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n1_ori * n1_ori, handle->queue, cnrtMemcpyHostToDev)); CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.idft_matrix_2d, fft_plan->idft_matrix_2d, - CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n0_ori * n0_ori, handle->queue, cnrtMemcpyHostToDev)); }; break; case CNFFT_COMPLEX_HALF2HALF: case CNFFT_COMPLEX_FLOAT2FLOAT: { @@ -849,19 +849,19 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, size_t reservespace_offset = 0; fft_plan->mlu_addrs.dft_matrix = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1; + reservespace_offset += CPX_TYPE_SIZE * (n1_ori / 2 + 1) * n1_ori; fft_plan->mlu_addrs.dft_matrix_2d = (uint8_t *)fft_plan->reservespace_addr + reservespace_offset; - reservespace_offset += CPX_TYPE_SIZE * _n0 * _n0; + reservespace_offset += CPX_TYPE_SIZE * n0_ori * n0_ori; CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix, fft_plan->dft_matrix, - CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1, + CPX_TYPE_SIZE * (n1_ori / 2 + 1) * n1_ori, handle->queue, cnrtMemcpyHostToDev)); CNRT_CHECK(cnrtMemcpyAsync( fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d, - CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev)); + CPX_TYPE_SIZE * n0_ori * n0_ori, handle->queue, cnrtMemcpyHostToDev)); }; break; default: { LOG(ERROR) << make_plan_api << ": invalid 2d fft type."; @@ -1060,13 +1060,13 @@ static void configureFFT2dWorkspaceAddrs(mluOpHandle_t handle, size_t out_c_dtype_size = mluOpDataTypeBytes(out_c_dtype); int batch = fft_plan->batch; - int _n0 = fft_plan->n[0]; - int _n1 = fft_plan->n[1]; + int n0_ori = fft_plan->n[0]; + int n1_ori = fft_plan->n[1]; size_t offset = 0; if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { // rr ri ir ii - size_t buffer_size = batch * in_c_dtype_size * _n0 * _n1 * 2; + size_t buffer_size = batch * in_c_dtype_size * n0_ori * n1_ori * 2; fft_plan->mlu_addrs.input = input; fft_plan->mlu_addrs.output = output; fft_plan->mlu_addrs.buffer_in = (uint8_t *)workspace + offset; @@ -1077,27 +1077,29 @@ static void configureFFT2dWorkspaceAddrs(mluOpHandle_t handle, if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1 * 2; + offset += batch * in_c_dtype_size * n0_ori * n1_ori * 2; - if (fft_plan->is_input_contiguous) { + if ((fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1])) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1; + offset += batch * in_c_dtype_size * n0_ori * n1_ori; } if (fft_plan->is_output_contiguous) { fft_plan->mlu_addrs.output = output; } else { fft_plan->mlu_addrs.output = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1; + offset += batch * in_c_dtype_size * n0_ori * n1_ori; } } if (fft_plan->n[0] > fft_plan->inembed[0] || fft_plan->n[1] > fft_plan->inembed[1]) { fft_plan->mlu_addrs.input_pad_addr = (uint8_t *)workspace + - offset; // batch * in_c_dtype_size * _n0 * _n1 * 2; // buffer_size; + offset; // batch * in_c_dtype_size * n0_ori * n1_ori * 2; // buffer_size; } } // input : in input @@ -1180,9 +1182,11 @@ static mluOpStatus_t makeFFT2dContiguousInput(mluOpHandle_t handle, int64_t dims[in_dim_num] = {fft_plan->batch, std::min(fft_plan->n[0], fft_plan->inembed[0]), std::min(fft_plan->n[1], fft_plan->inembed[1])}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + + int64_t strides[in_dim_num]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1779,17 +1783,8 @@ static mluOpStatus_t makeFFT1dContiguousOutput(mluOpHandle_t handle, cnnl_copy_src_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); DESTROY_CNNL_HANDLE(cnnl_handle); @@ -1818,9 +1813,10 @@ static mluOpStatus_t makeFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1]}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[out_dim_num]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1838,18 +1834,8 @@ static mluOpStatus_t makeFFT2dContiguousOutput(mluOpHandle_t handle, cnnl_copy_src_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); @@ -2003,12 +1989,15 @@ mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const float beta[2] = {0.0, 0.0}; mluOpTensorDescriptor_t c_desc = nullptr; status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int out_dim_num = 2; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0]}; status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 2, dims); + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType( c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle @@ -2019,6 +2008,8 @@ mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, cnnl_output_desc, fft_plan->mlu_addrs.output, &beta, cnnl_output_desc, fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); } @@ -2053,7 +2044,34 @@ mluOpStatus_t execFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; } - status = execFFTc2c2d(handle, fft_plan, scale_factor, direction); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptorOnchipDataType(c_desc, + fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle + + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_output_desc, + fft_plan->mlu_addrs.input, cnnl_output_desc, + fft_plan->mlu_addrs.output, NULL, 0)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } else { + status = execFFTc2c2d(handle, fft_plan, scale_factor, direction); + } INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -2062,13 +2080,16 @@ mluOpStatus_t execFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const float beta[2] = {0.0, 0.0}; mluOpTensorDescriptor_t c_desc = nullptr; status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1]}; status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 3, dims); + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType(c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle @@ -2079,6 +2100,8 @@ mluOpStatus_t execFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, cnnl_output_desc, fft_plan->mlu_addrs.output, &beta, cnnl_output_desc, fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); } @@ -2296,11 +2319,11 @@ mluOpStatus_t computeFFT2dMatMulRow(mluOpHandle_t handle, int requested_algo_count = 1, return_algo_count = 0; float *workspace; size_t workspace_size; - cnnlGetBatchMatMulAlgoHeuristic( + cnnlGetBatchMatMulExAlgoHeuristic( cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, requested_algo_count, &heuristic_result, &return_algo_count); - cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size); if (workspace_size > 0) { CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); @@ -2308,10 +2331,10 @@ mluOpStatus_t computeFFT2dMatMulRow(mluOpHandle_t handle, CNRT_CHECK(cnrtMalloc((void **)&workspace, m * n * sizeof(float))); } - CALL_CNNL(cnnlBatchMatMulBCast_v2(cnnl_handle, bmm_bcast_desc, algo, &alpha, - cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, - in_addr, &beta, cnnl_c_desc, out_addr, - (void *)workspace, workspace_size)); + CALL_CNNL(cnnlBatchMatMulEx(cnnl_handle, bmm_bcast_desc, algo, &alpha, + cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, + in_addr, &beta, cnnl_c_desc, out_addr, + (void *)workspace, workspace_size)); // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc); diff --git a/kernels/fft/common/fft_basic_ops.cpp b/kernels/fft/common/fft_basic_ops.cpp index fd56557d3..a39a94e3e 100644 --- a/kernels/fft/common/fft_basic_ops.cpp +++ b/kernels/fft/common/fft_basic_ops.cpp @@ -488,10 +488,10 @@ mluOpStatus_t fftGetBatchMatMulBcastWorkspaceSize( cnnlMatMulHeuristicResult_t heuristic_result; CALL_CNNL(cnnlCreateMatMulHeuristicResult(&heuristic_result)); int requested_algo_count = 1, return_algo_count = 0; - cnnlGetBatchMatMulAlgoHeuristic( + cnnlGetBatchMatMulExAlgoHeuristic( cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, requested_algo_count, &heuristic_result, &return_algo_count); - cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size); // destroy descriptor // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); @@ -585,20 +585,20 @@ mluOpStatus_t fftBatchMatMulBcast( alpha = 1.0; beta = 0.0; int requested_algo_count = 1, return_algo_count = 0; - cnnlGetBatchMatMulAlgoHeuristic( + cnnlGetBatchMatMulExAlgoHeuristic( cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, requested_algo_count, &heuristic_result, &return_algo_count); - cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size); if (workspace_size > 0) { CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); } else { CNRT_CHECK(cnrtMalloc((void **)&workspace, m * n * sizeof(float))); } - CALL_CNNL(cnnlBatchMatMulBCast_v2(cnnl_handle, bmm_bcast_desc, algo, &alpha, - cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, - &beta, cnnl_c_desc, c_ptr, - (void *)workspace, workspace_size)); + CALL_CNNL(cnnlBatchMatMulEx(cnnl_handle, bmm_bcast_desc, algo, &alpha, + cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, + &beta, cnnl_c_desc, c_ptr, + (void *)workspace, workspace_size)); // destroy descriptor // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); diff --git a/kernels/fft/common/fft_common_kernels.mlu b/kernels/fft/common/fft_common_kernels.mlu index 8cca3a697..d9e48157d 100644 --- a/kernels/fft/common/fft_common_kernels.mlu +++ b/kernels/fft/common/fft_common_kernels.mlu @@ -109,7 +109,8 @@ __mlu_func__ void selectVec(float *src_addr, int32_t *offset_int_addr, __asm__ volatile( "gather.clean.nram.nram.nram.b32.u32 " "[%[dst]], [%[src]], [%[offset]], %[data_num];\n\t" ::[dst] "r"(dst_addr), - [src] "r"(src_addr), [offset] "r"(offset_int_addr), [data_num] "r"(deal_size)); + [ src ] "r"(src_addr), [ offset ] "r"(offset_int_addr), + [ data_num ] "r"(deal_size)); #else for (auto i = 0; i < deal_size; i++) { dst_addr[i] = src_addr[offset_int_addr[i]]; diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp index 4d4ab9ef1..77da76794 100644 --- a/kernels/fft/fft.cpp +++ b/kernels/fft/fft.cpp @@ -1657,7 +1657,7 @@ mluOpAllocateC2C1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, fft_plan->is_batch_contiguous) ? 0 : buffer_size; - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; } size_t twiddles_size = in_c_dtype_size * nfft * 2; @@ -1701,7 +1701,7 @@ mluOpAllocateR2C1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + twiddles_size * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; // input_pad_addr } fft_plan->workspace_size = workspace_size; @@ -1721,17 +1721,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateC2C2D( size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype); int batch = fft_plan->batch; - const int _n0 = fft_plan->n[0]; - const int _n1 = fft_plan->n[1]; + const int n0_ori = fft_plan->n[0]; + const int n1_ori = fft_plan->n[1]; - size_t buffer_size = batch * in_c_dtype_size * _n0 * _n1; + size_t buffer_size = batch * in_c_dtype_size * n0_ori * n1_ori; - size_t twiddles_size = in_c_dtype_size * _n0; - size_t twiddles_size_2d = in_c_dtype_size * _n1; + size_t twiddles_size = in_c_dtype_size * n0_ori; + size_t twiddles_size_2d = in_c_dtype_size * n1_ori; if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { reservespace_size = - (in_c_dtype_size * _n0 * _n0 + in_c_dtype_size * _n1 * _n1) * + (in_c_dtype_size * n0_ori * n0_ori + in_c_dtype_size * n1_ori * n1_ori) * 2; /* DFT matrix */ workspace_size = buffer_size * 6; } else if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { @@ -1740,13 +1740,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateC2C2D( DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ workspace_size = buffer_size * 2; - workspace_size += (fft_plan->is_input_contiguous) ? 0 : buffer_size; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1]) + ? 0 + : buffer_size; workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; } fft_plan->workspace_size = workspace_size; - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { fft_plan->workspace_size = workspace_size + buffer_size; // input_pad_addr } fft_plan->reservespace_size = reservespace_size; @@ -1783,7 +1787,7 @@ mluOpAllocateC2R1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + twiddles_size * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ - if (fft_plan->n[0] > fft_plan->inembed[0]) { + if (fft_plan->n[0] != fft_plan->inembed[0]) { workspace_size += buffer_size; // input_pad_addr } fft_plan->workspace_size = workspace_size; @@ -1791,11 +1795,57 @@ mluOpAllocateC2R1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, return MLUOP_STATUS_SUCCESS; } +mluOpStatus_t MLUOP_WIN_API mluOpAllocateIRFFT2D( + mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, + mluOpTensorDescriptor_t input_desc, mluOpTensorDescriptor_t output_desc, + const int n0_ori, const int n1_ori) { + const std::string make_plan_api = "[mluOpAllocateIRFFT2D]"; + size_t workspace_size = 0, reservespace_size = 0; + mluOpDataType_t out_c_dtype = fft_plan->output_dtype; + mluOpDataType_t in_c_dtype = fft_plan->input_dtype; + size_t complex_dtype_size = + (mluOpDataTypeBytes(out_c_dtype) > mluOpDataTypeBytes(in_c_dtype)) + ? mluOpDataTypeBytes(out_c_dtype) + : mluOpDataTypeBytes(in_c_dtype); + + int batch = fft_plan->batch; + size_t buffer_size = batch * complex_dtype_size * n0_ori * n1_ori; + + size_t twiddles_size = complex_dtype_size * n0_ori; + size_t twiddles_size_2d = complex_dtype_size * n1_ori; + + if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { + reservespace_size = complex_dtype_size * n0_ori * n0_ori * 2 + + complex_dtype_size * n1_ori * n1_ori * 2; /* DFT matrix */ + workspace_size = complex_dtype_size * n1_ori * n0_ori * batch * 6; + } else if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { + reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + + sizeof(int) * (FFT_MAXFACTORS) + twiddles_size * 2 + + DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + + DFT_TABLE_SIZE * 2; /* twiddles */ + workspace_size = buffer_size * 2; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1] / 2 + 1) + ? 0 + : buffer_size; + workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; + } + + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { + workspace_size += buffer_size; + } + fft_plan->workspace_size = workspace_size; + fft_plan->reservespace_size = reservespace_size; + + return MLUOP_STATUS_SUCCESS; +} mluOpStatus_t MLUOP_WIN_API mluOpAllocateRFFT2D( mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, mluOpTensorDescriptor_t input_desc, mluOpTensorDescriptor_t output_desc, - const int _n0, const int _n1) { + const int n0_ori, const int n1_ori) { const std::string make_plan_api = "[mluOpAllocateRFFT2D]"; size_t workspace_size = 0, reservespace_size = 0; @@ -1807,27 +1857,31 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateRFFT2D( : mluOpDataTypeBytes(in_c_dtype); int batch = fft_plan->batch; - size_t buffer_size = batch * complex_dtype_size * _n0 * _n1; + size_t buffer_size = batch * complex_dtype_size * n0_ori * n1_ori; - size_t twiddles_size = complex_dtype_size * _n0; - size_t twiddles_size_2d = complex_dtype_size * _n1; + size_t twiddles_size = complex_dtype_size * n0_ori; + size_t twiddles_size_2d = complex_dtype_size * n1_ori; if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { - reservespace_size = complex_dtype_size * _n0 * _n0 * 2 + - complex_dtype_size * _n1 * _n1 * 2; /* DFT matrix */ - workspace_size = complex_dtype_size * _n1 * _n0 * batch * 6; + reservespace_size = complex_dtype_size * n0_ori * n0_ori * 2 + + complex_dtype_size * n1_ori * n1_ori * 2; /* DFT matrix */ + workspace_size = complex_dtype_size * n1_ori * n0_ori * batch * 6; } else if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */ + sizeof(int) * (FFT_MAXFACTORS) + twiddles_size * 2 + DFT_TABLE_SIZE * 2 + twiddles_size_2d * 2 + DFT_TABLE_SIZE * 2; /* twiddles */ workspace_size = buffer_size * 2; - workspace_size += (fft_plan->is_input_contiguous) ? 0 : buffer_size; + workspace_size += (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1]) + ? 0 + : buffer_size; workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size; } - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + if (fft_plan->n[0] != fft_plan->inembed[0] || + fft_plan->n[1] != fft_plan->inembed[1]) { workspace_size += buffer_size; } fft_plan->workspace_size = workspace_size; @@ -1846,6 +1900,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C1D( const int rank, const int *n) { fft_plan->is_batch_contiguous = (fft_plan->idist == 1 && fft_plan->odist == 1 && + fft_plan->inembed[0] == fft_plan->n[0] && + fft_plan->onembed[0] == fft_plan->n[0] && fft_plan->istride == fft_plan->batch && fft_plan->ostride == fft_plan->batch) && (fft_plan->n[0] == fft_plan->inembed[0]); @@ -2221,7 +2277,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2R2D( fft_plan->fft_strategy = CNFFT_FUNC_TWO_LEVEL_STOCKHAM; } - mluOpAllocateRFFT2D(handle, fft_plan, input_desc, output_desc, n[0], n[1]); + mluOpAllocateIRFFT2D(handle, fft_plan, input_desc, output_desc, n[0], n[1]); if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { switch (fft_plan->fft_type) { @@ -2394,6 +2450,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany( fft_plan->inembed[i] = input_desc->dims[fft_plan->idim - rank + i]; fft_plan->onembed[i] = output_desc->dims[fft_plan->odim - rank + i]; } + for (auto i = 0; i < fft_plan->idim; i++) { + fft_plan->in_stride[i] = input_desc->strides[i]; + } + for (auto i = 0; i < fft_plan->odim; i++) { + fft_plan->out_stride[i] = output_desc->strides[i]; + } if (fft_plan->idim == rank + 1) { fft_plan->idist = input_desc->strides[0]; fft_plan->odist = output_desc->strides[0]; diff --git a/kernels/fft/fft.h b/kernels/fft/fft.h index aa7ac0ba6..6f31a7751 100644 --- a/kernels/fft/fft.h +++ b/kernels/fft/fft.h @@ -180,6 +180,8 @@ struct cnfftButterflyAddrs { int *factors; int *factors_2d; void *input_pad_addr; + void *input_copy_workspace_addr; + void *output_copy_workspace_addr; }; struct mluOpFFTStruct { int rank; // rank of FFT @@ -193,24 +195,26 @@ struct mluOpFFTStruct { int inum; // element num of input tensor int istride; // distance between two successive input elements in the // innermost dimension - int idist; // distance between the first element of two consecutive signals - // in a batch of the input data - int odim; // the dimension size of output tensor + int in_stride[FFT_DIM_MAX + 1]; + int idist; // distance between the first element of two consecutive signals + // in a batch of the input data + int odim; // the dimension size of output tensor int onembed[FFT_DIM_MAX]; // Pointer of size rank that indicates the storage // dimensions of the output data in memory int onum; // element num of output tensor int ostride; // distance between two successive output elements in the // innermost dimension - int odist; // distance between the first element of two consecutive signals - // in a batch of the output data - int batch; // batch size for this transform - int L; // n = L * 2^m, L size for this transform - int m; // n = L * 2^m, m size for this transform - int s; // The size that can be put down on NRAM: L * 2^s, only used by - // Cooley-Tukey algorithm - int L_sub; // The size that can be put down on NRAM: L_sub * 2^m, only used - // by Stockham algorithm - int prime; // wether fft1d'size contains a prime number > 64 + int out_stride[FFT_DIM_MAX + 1]; + int odist; // distance between the first element of two consecutive signals + // in a batch of the output data + int batch; // batch size for this transform + int L; // n = L * 2^m, L size for this transform + int m; // n = L * 2^m, m size for this transform + int s; // The size that can be put down on NRAM: L * 2^s, only used by + // Cooley-Tukey algorithm + int L_sub; // The size that can be put down on NRAM: L_sub * 2^m, only used + // by Stockham algorithm + int prime; // wether fft1d'size contains a prime number > 64 bool is_input_contiguous; bool is_output_contiguous; bool is_batch_contiguous; diff --git a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h index 547174631..07d31dea1 100644 --- a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h +++ b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h @@ -305,379 +305,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpong( } } -// Compute the large butterfly for the subsequent stages of the FFT -template -__mlu_func__ void computeLargeButterflyOtherstages( - DT *output, DT *input, const int large_radix, const DT *cur_large_twiddles, - const DT *_twiddles, const DT *dft_matrix, const int large_section_num, - const int large_butterfly_num, const int large_in_stride, void *nram_buf, - const int *small_factors, const int nfft, const int dir, - const int last_stage) { - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - const int K_num = 64 / sizeof(DT); - int align_K = 0; - int radix, small_in_stride, small_stage_count, _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - - const int large_out_stride = large_butterfly_num; - int tw_offset; - - _small_stage_count = small_factors[0]; - tw_offset = small_factors[1]; - - const DT *small_twiddles = _twiddles + tw_offset * 2; - - const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix; - - int nram_buf_offset = 0; - DT *nram_in_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_in_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - FFT_CPX_T
nram_para_load_in_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_in_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_tw_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_load_tw_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_store_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_para_store_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; - - FFT_CPX_T
nram_transpose_temp; - nram_transpose_temp = { - (DT *)nram_in_r, - (DT *)nram_in_r + large_radix * ((int)last_stage) + - large_radix * (1 - (int)last_stage) * max_para_ldst_num}; - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; - - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; - - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - DT *CPX_MUL_RR = nram_scratch; - DT *CPX_MUL_RI = &CPX_MUL_RR[large_radix * max_para_ldst_num]; - DT *CPX_MUL_IR = &CPX_MUL_RI[large_radix * max_para_ldst_num]; - DT *CPX_MUL_II = &CPX_MUL_IR[large_radix * max_para_ldst_num]; - - nram_buf_offset += large_radix * max_para_ldst_num * 4; - - int Fin_stride = 0, Fout_stride = 0; - int sec_count; - int repeat_num = - (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num; - for (sec_count = 0; sec_count < large_section_num; ++sec_count) { - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - if (repeat_id < repeat_num) { - int i = max_para_ldst_num * repeat_id; - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 == 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 == 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - int para_load_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - __memcpy_async(nram_para_load_in.r, input + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), - large_radix - 1); - __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), - large_radix - 1); - __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, - sizeof(DT) * para_load_num, SRAM2NRAM, - sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - __memcpy_async( - nram_para_load_tw.i, - cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i, - sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - } - - if (repeat_id >= 2) { - int i = max_para_ldst_num * (repeat_id - 2); - - int para_store_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong; - - if (last_stage) { - __memcpy_async(output + (Fout_stride + i) * 2, nram_para_store.r, - sizeof(DT) * 2 * para_store_num, NRAM2GDRAM, - large_out_stride * 2 * sizeof(DT), - sizeof(DT) * 2 * para_store_num, large_radix - 1); - } else { - __memcpy_async(output + Fout_stride + i, nram_para_store.r, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - __memcpy_async(output + Fout_stride + i + nfft, nram_para_store.i, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - } - } - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - int i = max_para_ldst_num * (repeat_id - 1); - - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 != 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 != 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong; - - int para_ldst_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - - __bang_sub(nram_para_load_in.r + para_ldst_num, CPX_MUL_RR, CPX_MUL_II, - para_ldst_num * (large_radix - 1)); - __bang_add(nram_para_load_in.i + para_ldst_num, CPX_MUL_RI, CPX_MUL_IR, - para_ldst_num * (large_radix - 1)); - - { - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_para_load_in.r, nram_para_load_in.i, - nram_scratch, nram_dftmtx, small_section_num * para_ldst_num, - small_section_num * para_ldst_num, 1, dir, radix); - - small_stage_count--; - if (small_stage_count == 0) { - if (last_stage) { - __memcpy_async(nram_transpose_temp.r, nram_out_r, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - - __memcpy_async(nram_transpose_temp.i, nram_out_i, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - __sync_move(); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } - - } else { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_ldst_num, radix, DT) - DT *nram_tw = _nram_tw; - value_mul = 8; - - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - if (sec_count == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - small_twiddles += small_butterfly_num * (radix - 1) * 2; - } - - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - - nram_tw += small_butterfly_num * (radix - 1) * 2; - } - - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; - - if (sec_count == 0 && repeat_id == 1) { - __memcpy_async( - nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - __sync_move(); - } - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy_async( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - __sync_move(); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - - if (last_stage) { - __memcpy_async(nram_transpose_temp.r, nram_out_r, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - - __memcpy_async(nram_transpose_temp.i, nram_out_i, - sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, - sizeof(DT) * large_radix, para_ldst_num - 1); - __sync_move(); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } - } - } - } - } - - __sync(); - } - Fin_stride += large_butterfly_num; - Fout_stride += large_radix * large_butterfly_num; - } -} - -template -__mlu_func__ void computeLargeButterflyLaststage( - DT *output, DT *input, const int large_radix, const DT *cur_large_twiddles, - const DT *_twiddles, const DT *dft_matrix, const int large_section_num, - const int large_butterfly_num, const int large_in_stride, void *nram_buf, - const int *small_factors, const int nfft, const int dir) { - computeLargeButterflyOtherstages( - output, input, large_radix, cur_large_twiddles, _twiddles, dft_matrix, - large_section_num, large_butterfly_num, large_in_stride, nram_buf, - small_factors, nfft, dir, 1); -} - // Compute the large butterfly for the last stage of the FFT template __mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong( diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu index 55b3a37b6..808c91df1 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu @@ -35,26 +35,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyRow( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, const int batch, const int fft_flag, const int direction, const int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchip((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, batch, - fft_flag, direction); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchip((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag, direction); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchip((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag, direction); } // Kernel function for 1D FFT butterfly operations on columns. @@ -62,26 +46,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyColumn( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, const int batch, const int fft_flag, const int direction, const int dtype_size, const int nb) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchipColumn( - (float *)input, (float *)output, factors, (float *)twiddles, - (float *)twiddles_end, (float *)dft_matrix, (float *)buffer, batch, - fft_flag, direction, nb); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchipColumn((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, - batch, fft_flag, direction, nb); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchipColumn((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, + batch, fft_flag, direction, nb); } // Launches a kernel for 2D FFT butterfly operations on columns. diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu index 31b3c3908..a76078f62 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu @@ -33,26 +33,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyRowC2R( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, int batch, int fft_flag, int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - computeMutiStageOnchipC2R((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, - batch, fft_flag); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - computeMutiStageOnchipC2R((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageOnchipC2R((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag); } mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyRowC2R( diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu index 5dd5f9e8d..3e36946b7 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_r2c_device.mlu @@ -33,28 +33,10 @@ __mlu_global__ void MLUKernelFFT1dButterflyR2C( void *input, void *output, int *factors, void *twiddles, void *twiddles_end, void *dft_matrix, void *buffer, int batch, int fft_flag, int dtype_size) { - switch (dtype_size) { - case (MLUOP_DTYPE_COMPLEX_FLOAT): - case (MLUOP_DTYPE_FLOAT): { - MLULOG("MLUOP_DTYPE_COMPLEX_FLOAT: MLUOP_DTYPE_FLOAT\n"); - computeMutiStageR2COnchip((float *)input, (float *)output, factors, - (float *)twiddles, (float *)twiddles_end, - (float *)dft_matrix, (float *)buffer, - batch, fft_flag); - }; break; - case (MLUOP_DTYPE_COMPLEX_HALF): - case (MLUOP_DTYPE_HALF): { - MLULOG("MLUOP_DTYPE_COMPLEX_HALF: MLUOP_DTYPE_HALF\n"); - computeMutiStageR2COnchip((half *)input, (half *)output, factors, - (half *)twiddles, (half *)twiddles_end, - (half *)dft_matrix, (half *)buffer, batch, - fft_flag); - }; break; - - default: { - MLULOG("mluOpFFT Not Implemented."); - } - } + computeMutiStageR2COnchip((float *)input, (float *)output, factors, + (float *)twiddles, (float *)twiddles_end, + (float *)dft_matrix, (float *)buffer, batch, + fft_flag); } mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyR2C(cnrtDim3_t k_dim, diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp index b065028e1..690ab378b 100644 --- a/kernels/fft/irfft/irfft_host.cpp +++ b/kernels/fft/irfft/irfft_host.cpp @@ -1486,17 +1486,8 @@ static mluOpStatus_t makeIRFFT1dContiguousOutput(mluOpHandle_t handle, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); @@ -1567,12 +1558,15 @@ mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const float beta = 0.0; mluOpTensorDescriptor_t c_desc = nullptr; status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int out_dim_num = 2; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0]}; status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 2, dims); + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType( c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle @@ -1583,6 +1577,8 @@ mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, cnnl_output_desc, fft_plan->mlu_addrs.output, &beta, cnnl_output_desc, fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); } @@ -1606,13 +1602,13 @@ static void configureIRFFT2dWorkspaceAddrs(mluOpHandle_t handle, size_t out_c_dtype_size = mluOpDataTypeBytes(out_c_dtype); int batch = fft_plan->batch; - int _n0 = fft_plan->n[0]; - int _n1 = fft_plan->n[1]; + int n0_ori = fft_plan->n[0]; + int n1_ori = fft_plan->n[1]; size_t offset = 0; if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { // rr ri ir ii - size_t buffer_size = batch * in_c_dtype_size * _n0 * _n1 * 2; + size_t buffer_size = batch * in_c_dtype_size * n0_ori * n1_ori * 2; offset = 0; fft_plan->mlu_addrs.input = input; fft_plan->mlu_addrs.output = output; @@ -1625,25 +1621,28 @@ static void configureIRFFT2dWorkspaceAddrs(mluOpHandle_t handle, if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { offset = 0; fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1 * 2; + offset += batch * in_c_dtype_size * n0_ori * n1_ori * 2; - if (fft_plan->is_input_contiguous) { + if (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1] / 2 + 1 || + fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1; + offset += batch * in_c_dtype_size * n0_ori * n1_ori; } if (fft_plan->is_output_contiguous) { fft_plan->mlu_addrs.output = output; } else { fft_plan->mlu_addrs.output = (uint8_t *)workspace + offset; - offset += batch * in_c_dtype_size * _n0 * _n1; + offset += batch * in_c_dtype_size * n0_ori * n1_ori; } } if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { + fft_plan->n[1] / 2 + 1 > fft_plan->inembed[1]) { fft_plan->mlu_addrs.input_pad_addr = (uint8_t *)workspace + offset; } } @@ -1828,11 +1827,11 @@ mluOpStatus_t computeFFT2dMatMulRowC2R(mluOpHandle_t handle, int requested_algo_count = 1, return_algo_count = 0; float *workspace; size_t workspace_size; - cnnlGetBatchMatMulAlgoHeuristic( + cnnlGetBatchMatMulExAlgoHeuristic( cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, requested_algo_count, &heuristic_result, &return_algo_count); - cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size); if (workspace_size > 0) { CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); @@ -1840,10 +1839,10 @@ mluOpStatus_t computeFFT2dMatMulRowC2R(mluOpHandle_t handle, CNRT_CHECK(cnrtMalloc((void **)&workspace, m * n * sizeof(float))); } - CALL_CNNL(cnnlBatchMatMulBCast_v2(cnnl_handle, bmm_bcast_desc, algo, &alpha, - cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, - in_addr, &beta, cnnl_c_desc, out_addr, - (void *)workspace, workspace_size)); + CALL_CNNL(cnnlBatchMatMulEx(cnnl_handle, bmm_bcast_desc, algo, &alpha, + cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, + in_addr, &beta, cnnl_c_desc, out_addr, + (void *)workspace, workspace_size)); // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc); @@ -1866,8 +1865,7 @@ static mluOpStatus_t makeIRFFT2dContiguousInput(mluOpHandle_t handle, auto status = MLUOP_STATUS_SUCCESS; if ((!fft_plan->is_input_contiguous || (fft_plan->inembed[0] > fft_plan->n[0] || - fft_plan->inembed[1] > fft_plan->n[1] / 2 + 1) && - !fft_plan->prime) && + fft_plan->inembed[1] > fft_plan->n[1] / 2 + 1)) && fft_plan->fft_strategy != CNFFT_FUNC_MANY_DIST1_2D) { VLOG(5) << "launch mluOpContiguous for irfft2d input"; mluOpTensorDescriptor_t input_desc; @@ -1878,9 +1876,10 @@ static mluOpStatus_t makeIRFFT2dContiguousInput(mluOpHandle_t handle, int64_t dims[in_dim_num] = { fft_plan->batch, std::min(fft_plan->inembed[0], fft_plan->n[0]), std::min(FFT_HALF(fft_plan->n[1]), fft_plan->inembed[1])}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + int64_t strides[in_dim_num]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1916,9 +1915,10 @@ static mluOpStatus_t makeIRFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1]}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[out_dim_num]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1937,17 +1937,8 @@ static mluOpStatus_t makeIRFFT2dContiguousOutput(mluOpHandle_t handle, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); @@ -1987,63 +1978,92 @@ mluOpStatus_t execIRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; } - for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { - status = kernelIRFFT2dButterflyColumn(k_dim, k_type, handle->queue, - fft_plan, FFT_IFFT); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t input_desc; + status = mluOpCreateTensorDescriptor(&input_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int in_dim_num = 2; + int64_t dims[in_dim_num] = { + fft_plan->batch * fft_plan->n[0] * fft_plan->n[1], 1}; + int64_t strides[in_dim_num] = {2, 1}; + status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, + MLUOP_DTYPE_FLOAT, in_dim_num, + dims, strides); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = kernelIRFFT2dButterflyRow(k_dim, k_type, handle->queue, fft_plan, - FFT_IFFT); + + status = mluOpContiguous(handle, input_desc, fft_plan->mlu_addrs.input, + fft_plan->mlu_addrs.output); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpDestroyTensorDescriptor(input_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } else { + for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { + status = kernelIRFFT2dButterflyColumn(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = kernelIRFFT2dButterflyRow(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + } fft_plan->mlu_addrs.input = - (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + (void *)((uint64_t)(fft_plan->mlu_addrs.input) - + fft_plan->batch * idist); fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + (void *)((uint64_t)(fft_plan->mlu_addrs.output) - + fft_plan->batch * odist); } - fft_plan->mlu_addrs.input = (void *)((uint64_t)(fft_plan->mlu_addrs.input) - - fft_plan->batch * idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) - - fft_plan->batch * odist); + } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { + status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - if (scale_factor != 1.0) { - const float alpha[2] = {scale_factor, 0.0}; - const float beta[2] = {0.0, 0.0}; - mluOpTensorDescriptor_t c_desc = nullptr; - status = mluOpCreateTensorDescriptor(&c_desc); - const int out_dim_num = 3; - int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], - fft_plan->n[1]}; - status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 3, dims); - status = mluOpSetTensorDescriptorOnchipDataType( - c_desc, fft_plan->execution_dtype); + status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } - DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, - cnnl_handle); // convert to cnnl_handle + if (scale_factor != 1.0) { + const float alpha[2] = {scale_factor, 0.0}; + const float beta[2] = {0.0, 0.0}; + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptorOnchipDataType( + c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle - CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, - cnnl_output_desc, fft_plan->mlu_addrs.output, - &beta, cnnl_output_desc, - fft_plan->mlu_addrs.output)); - DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); - DESTROY_CNNL_HANDLE(cnnl_handle); - } + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + + CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, + cnnl_output_desc, fft_plan->mlu_addrs.output, + &beta, cnnl_output_desc, + fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { status = makeIRFFT2dContiguousOutput(handle, fft_plan, output, fft_plan->mlu_addrs.output); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { - status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); } return status; } diff --git a/kernels/fft/rfft/rfft_host.cpp b/kernels/fft/rfft/rfft_host.cpp index d0755e8be..109ca0fa0 100644 --- a/kernels/fft/rfft/rfft_host.cpp +++ b/kernels/fft/rfft/rfft_host.cpp @@ -434,13 +434,13 @@ static void configureRFFT2dWorkspaceAddrs(mluOpHandle_t handle, size_t out_c_dtype_size = mluOpDataTypeBytes(out_c_dtype); int batch = fft_plan->batch; - int _n0 = fft_plan->n[0]; - int _n1 = fft_plan->n[1]; + int n0_ori = fft_plan->n[0]; + int n1_ori = fft_plan->n[1]; size_t offset = 0; if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { // rr ri ir ii - size_t buffer_size = batch * out_c_dtype_size * _n0 * _n1 * 2; + size_t buffer_size = batch * out_c_dtype_size * n0_ori * n1_ori * 2; fft_plan->mlu_addrs.input = input; fft_plan->mlu_addrs.output = output; fft_plan->mlu_addrs.buffer_in = (uint8_t *)workspace + offset; @@ -451,20 +451,22 @@ static void configureRFFT2dWorkspaceAddrs(mluOpHandle_t handle, if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; - offset += batch * out_c_dtype_size * _n0 * _n1 * 2; + offset += batch * out_c_dtype_size * n0_ori * n1_ori * 2; - if (fft_plan->is_input_contiguous) { + if ((fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] && + fft_plan->inembed[1] <= fft_plan->n[1])) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; - offset += batch * out_c_dtype_size * _n0 * _n1; + offset += batch * out_c_dtype_size * n0_ori * n1_ori; } if (fft_plan->is_output_contiguous) { fft_plan->mlu_addrs.output = output; } else { fft_plan->mlu_addrs.output = (uint8_t *)workspace + offset; - offset += batch * out_c_dtype_size * _n0 * _n1; + offset += batch * out_c_dtype_size * n0_ori * n1_ori; } } @@ -1109,17 +1111,8 @@ static mluOpStatus_t makeRFFT1dContiguousOutput(mluOpHandle_t handle, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); @@ -1138,8 +1131,7 @@ static mluOpStatus_t makeRFFT2dContiguousInput(mluOpHandle_t handle, auto status = MLUOP_STATUS_SUCCESS; if ((!fft_plan->is_input_contiguous || (fft_plan->inembed[0] > fft_plan->n[0] || - fft_plan->inembed[1] > fft_plan->n[1]) && - !fft_plan->prime) && + fft_plan->inembed[1] > fft_plan->n[1])) && fft_plan->fft_strategy != CNFFT_FUNC_MANY_DIST1_2D) { VLOG(5) << "launch mluOpContiguous for rfft2d input"; mluOpTensorDescriptor_t input_desc; @@ -1153,9 +1145,10 @@ static mluOpStatus_t makeRFFT2dContiguousInput(mluOpHandle_t handle, : fft_plan->n[0], fft_plan->n[1] > fft_plan->inembed[1] ? fft_plan->inembed[1] : fft_plan->n[1]}; - int64_t strides[in_dim_num] = {fft_plan->idist, - (fft_plan->istride * fft_plan->inembed[1]), - fft_plan->istride}; + int64_t strides[in_dim_num]; // in_dim_num + for (int i = 0; i < in_dim_num; i++) { + strides[i] = fft_plan->in_stride[i]; + } status = mluOpSetTensorDescriptorEx_v2(input_desc, MLUOP_LAYOUT_ARRAY, fft_plan->input_dtype, in_dim_num, dims, strides); @@ -1191,9 +1184,10 @@ static mluOpStatus_t makeRFFT2dContiguousOutput(mluOpHandle_t handle, const int out_dim_num = 3; int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], fft_plan->n[1] / 2 + 1}; - int64_t strides[out_dim_num] = {fft_plan->odist, - fft_plan->ostride * fft_plan->onembed[1], - fft_plan->ostride}; + int64_t strides[out_dim_num]; // out_dim_num + for (int i = 0; i < out_dim_num; i++) { + strides[i] = fft_plan->out_stride[i]; + } status = mluOpSetTensorDescriptor_v2(copy_src_desc, MLUOP_LAYOUT_ARRAY, out_c_dtype, out_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1211,17 +1205,8 @@ static mluOpStatus_t makeRFFT2dContiguousOutput(mluOpHandle_t handle, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(copy_dst_desc, cnnl_copy_dst_desc); - size_t workspace_size = 0; - CALL_CNNL(cnnlGetCopyWorkspaceSize(cnnl_handle, cnnl_copy_src_desc, - cnnl_copy_dst_desc, &workspace_size)); - - void *workspace = nullptr; - if (workspace_size > 0) { - CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); - } CALL_CNNL(cnnlCopy_v2(cnnl_handle, cnnl_copy_src_desc, copy_src_addr, - cnnl_copy_dst_desc, output, workspace, - workspace_size)); + cnnl_copy_dst_desc, output, NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_src_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_copy_dst_desc); @@ -1287,12 +1272,15 @@ mluOpStatus_t execRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const float beta[2] = {0.0, 0.0}; mluOpTensorDescriptor_t c_desc = nullptr; status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int out_dim_num = 2; int64_t dims[out_dim_num] = {fft_plan->batch, (fft_plan->n[0] / 2 + 1)}; status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 2, dims); + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType( c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle @@ -1303,6 +1291,8 @@ mluOpStatus_t execRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, cnnl_output_desc, fft_plan->mlu_addrs.output, &beta, cnnl_output_desc, fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); } @@ -1511,21 +1501,21 @@ mluOpStatus_t computeFFT2dMatMulRowR2C(mluOpHandle_t handle, int requested_algo_count = 1, return_algo_count = 0; float *workspace; size_t workspace_size; - cnnlGetBatchMatMulAlgoHeuristic( + cnnlGetBatchMatMulExAlgoHeuristic( cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, requested_algo_count, &heuristic_result, &return_algo_count); - cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + cnnlGetBatchMatMulExHeuristicResult(heuristic_result, algo, &workspace_size); if (workspace_size > 0) { CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); } else { CNRT_CHECK(cnrtMalloc((void **)&workspace, m * n * sizeof(float))); } - CALL_CNNL(cnnlBatchMatMulBCast_v2(cnnl_handle, bmm_bcast_desc, algo, &alpha, - cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, - in_addr, &beta, cnnl_c_desc, out_addr, - (void *)workspace, workspace_size)); + CALL_CNNL(cnnlBatchMatMulEx(cnnl_handle, bmm_bcast_desc, algo, &alpha, + cnnl_a_desc, dft_matrix_addr, cnnl_b_desc, + in_addr, &beta, cnnl_c_desc, out_addr, + (void *)workspace, workspace_size)); // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc); @@ -1561,70 +1551,109 @@ mluOpStatus_t execRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, status = makeRFFT2dContiguousInput(handle, fft_plan, input); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - if (fft_plan->n[0] > fft_plan->inembed[0] || - fft_plan->n[1] > fft_plan->inembed[1]) { - status = padRFFT2dContiguousInput(handle, fft_plan); + if (fft_plan->n[0] == 1 && fft_plan->n[1] == 1) { + mluOpTensorDescriptor_t input_desc, padded_output_desc; + status = mluOpCreateTensorDescriptor(&input_desc); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - - fft_plan->mlu_addrs.input = fft_plan->mlu_addrs.input_pad_addr; - } - - for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { - status = kernelRFFT2dButterflyRow(k_dim, k_type, handle->queue, fft_plan, - RFFT); - + status = mluOpCreateTensorDescriptor(&padded_output_desc); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = kernelRFFT2dButterflyColumn(k_dim, k_type, handle->queue, - fft_plan, FFT_IFFT); + const int in_dim_num = 2; + int64_t dims[in_dim_num] = {fft_plan->batch, + fft_plan->n[0] * fft_plan->n[1]}; + status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, + MLUOP_DTYPE_FLOAT, in_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - fft_plan->mlu_addrs.input = - (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); - } - fft_plan->mlu_addrs.input = (void *)((uint64_t)(fft_plan->mlu_addrs.input) - - fft_plan->batch * idist); - fft_plan->mlu_addrs.output = - (void *)((uint64_t)(fft_plan->mlu_addrs.output) - - fft_plan->batch * odist); - - if (scale_factor != 1.0) { - const float alpha[2] = {scale_factor, 0.0}; - const float beta[2] = {0.0, 0.0}; - mluOpTensorDescriptor_t c_desc = nullptr; - status = mluOpCreateTensorDescriptor(&c_desc); - const int out_dim_num = 3; - int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], - fft_plan->n[1] / 2 + 1}; - status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, - fft_plan->output_dtype, 3, dims); - status = mluOpSetTensorDescriptorOnchipDataType( - c_desc, fft_plan->execution_dtype); + int64_t padded_dims[in_dim_num] = {fft_plan->batch, + fft_plan->n[0] * fft_plan->n[1] * 2}; + status = mluOpSetTensorDescriptor_v2( + padded_output_desc, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_FLOAT, in_dim_num, + padded_dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int pad_dim_num = 4; + int paddings[pad_dim_num] = {0, 0, 0, 1}; + uint64_t padding_value = 0x00000000; DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(padded_output_desc, + cnnl_padded_output_desc); + CALL_CNNL(cnnlPad(cnnl_handle, cnnl_input_desc, fft_plan->mlu_addrs.input, + paddings, &padding_value, cnnl_padded_output_desc, + fft_plan->mlu_addrs.output)); + + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_padded_output_desc); - CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, - cnnl_output_desc, fft_plan->mlu_addrs.output, - &beta, cnnl_output_desc, - fft_plan->mlu_addrs.output)); - DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); DESTROY_CNNL_HANDLE(cnnl_handle); - } - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } else { + for (int batch_id = 0; batch_id < fft_plan->batch; batch_id++) { + status = kernelRFFT2dButterflyRow(k_dim, k_type, handle->queue, + fft_plan, RFFT); - status = makeRFFT2dContiguousOutput(handle, fft_plan, output); - INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = kernelRFFT2dButterflyColumn(k_dim, k_type, handle->queue, + fft_plan, FFT_IFFT); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) + idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) + odist); + } + fft_plan->mlu_addrs.input = + (void *)((uint64_t)(fft_plan->mlu_addrs.input) - + fft_plan->batch * idist); + fft_plan->mlu_addrs.output = + (void *)((uint64_t)(fft_plan->mlu_addrs.output) - + fft_plan->batch * odist); + } } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) { status = computeFFT2dMatMulRowR2C(handle, fft_plan, scale_factor); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = computeFFT2dMatMulColumnR2C(handle, fft_plan, scale_factor); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); } + + if (scale_factor != 1.0) { + const float alpha[2] = {scale_factor, 0.0}; + const float beta[2] = {0.0, 0.0}; + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + const int out_dim_num = 3; + int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0], + fft_plan->n[1] / 2 + 1}; + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, + fft_plan->output_dtype, out_dim_num, dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptorOnchipDataType( + c_desc, fft_plan->execution_dtype); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle + + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc); + + CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha, + cnnl_output_desc, fft_plan->mlu_addrs.output, + &beta, cnnl_output_desc, + fft_plan->mlu_addrs.output)); + status = mluOpDestroyTensorDescriptor(c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc); + DESTROY_CNNL_HANDLE(cnnl_handle); + } + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) { + status = makeRFFT2dContiguousOutput(handle, fft_plan, output); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + } return status; } diff --git a/kernels/tensor_stride_process/tensor_stride_process_host.cpp b/kernels/tensor_stride_process/tensor_stride_process_host.cpp index 410112258..bcb9685a6 100644 --- a/kernels/tensor_stride_process/tensor_stride_process_host.cpp +++ b/kernels/tensor_stride_process/tensor_stride_process_host.cpp @@ -484,7 +484,8 @@ mluOpContiguous(mluOpHandle_t handle, const mluOpTensorDescriptor_t input_desc, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(temp_desc, cnnl_temp_desc); CALL_CNNL( - cnnlCopy(cnnl_handle, cnnl_input_desc, input, cnnl_temp_desc, output)); + cnnlCopy_v2(cnnl_handle, cnnl_input_desc, input, cnnl_temp_desc, output, + NULL, 0)); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_temp_desc); DESTROY_CNNL_HANDLE(cnnl_handle);