diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp index 2148696e2..bd8d12ce6 100644 --- a/kernels/fft/c2c_fft/c2c_fft_host.cpp +++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp @@ -157,7 +157,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += transed_input_size; // input trans workspace: batch * n * 2 --> 2 * batch * n const int in_trans_dim_num = 2; - int in_trans_input_dims[in_trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t in_trans_input_dims[in_trans_dim_num] = {padded_input_num, + COMPLEX}; int in_trans_permute[in_trans_dim_num] = {1, 0}; size_t in_trans_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -205,8 +206,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { // output trans workspace: 2 * batch * n --> batch * n * 2 const int out_trans_dim_num = 2; - int out_trans_input_dims[out_trans_dim_num] = {COMPLEX, - per_matmul_output_num}; + int64_t out_trans_input_dims[out_trans_dim_num] = {COMPLEX, + per_matmul_output_num}; int out_trans_permute[out_trans_dim_num] = {1, 0}; size_t out_trans_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -332,8 +333,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { // input trans workspace // 1st transpose: batch * n * 2 --> 2 * batch * n const int in_trans_1st_dim_num = 2; - int in_trans_1st_input_dims[in_trans_1st_dim_num] = {padded_input_num, - COMPLEX}; + int64_t in_trans_1st_input_dims[in_trans_1st_dim_num] = {padded_input_num, + COMPLEX}; int in_trans_1st_permute[in_trans_1st_dim_num] = {1, 0}; size_t in_trans_1st_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -344,8 +345,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { in_trans_1st_workspace_size); // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L const int in_trans_2nd_dim_num = 3; - int in_trans_2nd_input_dims[in_trans_2nd_dim_num] = {COMPLEX * batch, L, - m}; + int64_t in_trans_2nd_input_dims[in_trans_2nd_dim_num] = {COMPLEX * batch, + L, m}; int in_trans_2nd_permute[in_trans_2nd_dim_num] = {0, 2, 1}; size_t in_trans_2nd_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -375,12 +376,28 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += matmul_times * per_matmul_output_size; // matmul workspace size_t matmul_workspace_size = 0; - status = fftGetQuantizeMatMulWorkspaceSize( - handle, matmul_workspace_size, batch * m, L, L, false, true, - in_e_dtype, in_e_dtype, in_r_dtype, api); - fft_plan->matmul_addrs.internal_workspace_size = - std::max(fft_plan->matmul_addrs.internal_workspace_size, - matmul_workspace_size); + if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) { + status = fftGetQuantizeMatMulWorkspaceSize( + handle, matmul_workspace_size, batch * m, L, L, false, true, + in_e_dtype, in_e_dtype, in_r_dtype, api); + fft_plan->matmul_addrs.internal_workspace_size = + std::max(fft_plan->matmul_addrs.internal_workspace_size, + matmul_workspace_size); + } else { + status = fftGetBatchMatMulBcastWorkspaceSize( + handle, 2 * L, L, m * 2, batch, + fft_plan->matmul_addrs.dft_im_matrix_addr, + fft_plan->matmul_addrs.dft_pos_addr, + fft_plan->matmul_addrs.dft_scale_addr, + fft_plan->matmul_addrs.input_pad_addr, + fft_plan->matmul_addrs.input_pos_addr, + fft_plan->matmul_addrs.input_scale_addr, + fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false, 1.0, + 0.0, 
in_e_dtype, in_e_dtype, in_r_dtype, + fft_plan->matmul_addrs.internal_workspace_addr, + fft_plan->matmul_addrs.internal_workspace_size, api); + } + // optensor workspace size_t optensor_workspace_size = 0; status = @@ -1039,7 +1056,9 @@ static void configureFFT1dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += buffer_size * 2; - if (fft_plan->is_input_contiguous || fft_plan->is_batch_contiguous) { + if ((fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0]) || + fft_plan->is_batch_contiguous) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -1230,8 +1249,7 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle, INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int in_dim_num = 2; - int64_t dims[in_dim_num] = {batch, - std::min(fft_plan->inembed[0], n) * COMPLEX}; + int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0] * COMPLEX}; status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, in_r_dtype, in_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -1242,8 +1260,7 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle, INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int pad_dim_num = 4; - int paddings[pad_dim_num] = { - 0, 0, 0, std::max((n - fft_plan->inembed[0]), 0) * COMPLEX}; + int paddings[pad_dim_num] = {0, 0, 0, (n - fft_plan->inembed[0]) * COMPLEX}; uint64_t padding_value = 0x00000000; DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, @@ -1255,14 +1272,9 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle, cnnl_padded_input_desc); CALL_CNNL(cnnlPad( cnnl_handle, cnnl_input_desc, - // (fft_plan->n[0] > fft_plan->inembed[0]) ? fft_plan->mlu_addrs.input - // : fft_plan->matmul_addrs.input_contiguous_addr, fft_plan->prime ? fft_plan->matmul_addrs.input_contiguous_addr : fft_plan->mlu_addrs.input, paddings, &padding_value, cnnl_padded_input_desc, - // (fft_plan->n[0] > fft_plan->inembed[0]) ? - // fft_plan->mlu_addrs.input_pad_addr - // : fft_plan->matmul_addrs.input_pad_addr fft_plan->prime ? 
fft_plan->matmul_addrs.input_pad_addr : fft_plan->mlu_addrs.input_pad_addr)); @@ -1396,8 +1408,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle, VLOG(5) << "launch mluOpTranspose for input CNFFT_FUNC_MATMUL"; int padded_input_num = batch * n; const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; - int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; + int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; int trans_permute[trans_dim_num] = {1, 0}; status = @@ -1415,8 +1427,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle, // 1st transpose: batch * n * 2 --> 2 * batch * n int padded_input_num = batch * n; const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; - int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; + int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; int trans_permute[trans_dim_num] = {1, 0}; status = @@ -1428,8 +1440,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle, // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L const int trans_2nd_dim_num = 3; - int trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m}; - int trans_2nd_output_dims[trans_2nd_dim_num] = {COMPLEX * batch, m, L}; + int64_t trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m}; + int64_t trans_2nd_output_dims[trans_2nd_dim_num] = {COMPLEX * batch, m, L}; int trans_2nd_permute[trans_2nd_dim_num] = {0, 2, 1}; status = fftTranspose(handle, trans_2nd_dim_num, trans_2nd_input_dims, @@ -1739,8 +1751,8 @@ static mluOpStatus_t transposeFFT1dOutput(mluOpHandle_t handle, int output_num = batch * n; const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {COMPLEX, output_num}; - int trans_output_dims[trans_dim_num] = {output_num, COMPLEX}; + int64_t trans_input_dims[trans_dim_num] = {COMPLEX, output_num}; + int64_t trans_output_dims[trans_dim_num] = {output_num, COMPLEX}; int trans_permute[trans_dim_num] = {1, 0}; status = fftTranspose( diff --git a/kernels/fft/common/fft_basic_ops.cpp b/kernels/fft/common/fft_basic_ops.cpp index d58ab8c9b..24cb115ac 100644 --- a/kernels/fft/common/fft_basic_ops.cpp +++ b/kernels/fft/common/fft_basic_ops.cpp @@ -211,7 +211,35 @@ mluOpStatus_t fftGetQuantizeMatMulWorkspaceSize( CALL_CNNL(cnnlMatMulDescDestroy(matmul_desc)); CALL_CNNL(cnnlMatMulAlgoDestroy(matmul_algo)); } else { - workspace_size = 0; // mluOpMatmul doesn't need workspace. + // workspace_size = 0; // mluOpMatmul doesn't need workspace. 
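+    // NOTE: rather than assuming mluOpMatmul needs no workspace, the block below creates a cnnlMatMulDescriptor_t, picks an algorithm with cnnlGetMatMulAlgoHeuristic, and reads the required size back into the workspace_size output parameter through cnnlGetMatMulHeuristicResult.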
+ cnnlMatMulDescriptor_t matmul_desc; + cnnlMatMulAlgo_t matmul_algo; + cnnlMatMulHeuristicResult_t heuristic_result; + int32_t allow_tf32 = 0; + cnnlDataType_t cnnl_compute_type = CNNL_DTYPE_FLOAT; // (TODO) + + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_d_desc); + CALL_CNNL(cnnlMatMulDescCreate(&matmul_desc)); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSA, + &trans_a_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSB, + &trans_b_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_ALLOW_TF32, + &allow_tf32, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_COMPUTE_TYPE, + &cnnl_compute_type, + sizeof(cnnl_compute_type))); + CALL_CNNL(cnnlMatMulAlgoCreate(&matmul_algo)); + CALL_CNNL(cnnlCreateMatMulHeuristicResult(&heuristic_result)); + int32_t requested_algo_count = 1, return_algo_count = 0; + + CALL_CNNL(cnnlGetMatMulAlgoHeuristic( + cnnl_handle, matmul_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, + cnnl_d_desc, nullptr, requested_algo_count, &heuristic_result, + &return_algo_count)); + CALL_CNNL(cnnlGetMatMulHeuristicResult(heuristic_result, matmul_algo, + &workspace_size)); } // destroy cnnl descriptor @@ -265,6 +293,7 @@ mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n, b_dims[0] = k; b_dims[1] = n; } + status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2, a_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -298,6 +327,7 @@ mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n, DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_d_desc); // compute matmul result if (fftIsIntDtype(a_compute_type) && fftIsIntDtype(b_compute_type)) { @@ -334,11 +364,137 @@ mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n, CALL_CNNL(cnnlMatMulAlgoDestroy(matmul_algo)); } else { c_desc->onchip_dtype = MLUOP_DTYPE_FLOAT; - CALL_CNNL(cnnlMatMul(cnnl_handle, is_trans_a, is_trans_b, &alpha, - cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, &beta, - cnnl_c_desc, c_ptr)); + cnnlMatMulDescriptor_t matmul_desc; + cnnlMatMulAlgo_t matmul_algo; + cnnlMatMulHeuristicResult_t heuristic_result; + size_t workspace_size = 0; + int32_t allow_tf32 = 0; + cnnlDataType_t cnnl_compute_type = CNNL_DTYPE_FLOAT; // (TODO) + + CALL_CNNL(cnnlMatMulDescCreate(&matmul_desc)); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSA, + &trans_a_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_TRANSB, + &trans_b_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_ALLOW_TF32, + &allow_tf32, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(matmul_desc, CNNL_MATMUL_DESC_COMPUTE_TYPE, + &cnnl_compute_type, + sizeof(cnnl_compute_type))); + CALL_CNNL(cnnlMatMulAlgoCreate(&matmul_algo)); + CALL_CNNL(cnnlCreateMatMulHeuristicResult(&heuristic_result)); + int32_t requested_algo_count = 1, return_algo_count = 0; + + CALL_CNNL(cnnlGetMatMulAlgoHeuristic( + cnnl_handle, matmul_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, + cnnl_d_desc, nullptr, requested_algo_count, &heuristic_result, + &return_algo_count)); +
CALL_CNNL(cnnlGetMatMulHeuristicResult(heuristic_result, matmul_algo, + &workspace_size)); + float *workspace = nullptr; + if (workspace_size > 0) { + CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); + } + CALL_CNNL(cnnlMatMul_v2(cnnl_handle, matmul_desc, matmul_algo, &alpha, + cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, &beta, + cnnl_c_desc, c_ptr, workspace, workspace_size, + cnnl_d_desc, c_ptr)); + } + + // destroy cnnl descriptor + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc); + DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc); + + DESTROY_CNNL_HANDLE(cnnl_handle); + + return status; +} + +mluOpStatus_t fftGetBatchMatMulBcastWorkspaceSize( + mluOpHandle_t handle, + int m, // 2 * L = 750 + int k, // L = 375 + int n, // 2^m = 128 + int batch, void *a_ptr, void *a_pos, void *a_scale, void *b_ptr, + void *b_pos, void *b_scale, void *c_ptr, bool is_trans_a, bool is_trans_b, + float alpha, float beta, mluOpDataType_t a_compute_type, + mluOpDataType_t b_compute_type, mluOpDataType_t data_type, void *workspace, + size_t workspace_size, const std::string api) { + mluOpStatus_t status = MLUOP_STATUS_SUCCESS; + int trans_a_int = (int)is_trans_a; + int trans_b_int = (int)is_trans_b; + + // create descriptor + mluOpTensorDescriptor_t a_desc = nullptr; + mluOpTensorDescriptor_t b_desc = nullptr; + mluOpTensorDescriptor_t c_desc = nullptr; + status = mluOpCreateTensorDescriptor(&a_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpCreateTensorDescriptor(&b_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpCreateTensorDescriptor(&c_desc); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + + // set descriptor + int64_t a_dims[2]; + int64_t b_dims[3] = {batch, k, n}; + int64_t c_dims[3] = {batch, m, n}; + if (is_trans_a) { + a_dims[0] = k; + a_dims[1] = m; + } else { + a_dims[0] = m; + a_dims[1] = k; + } + if (is_trans_b) { + b_dims[1] = n; + b_dims[2] = k; + } else { + b_dims[1] = k; + b_dims[2] = n; } + status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2, + a_dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptorOnchipDataType(a_desc, a_compute_type); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, + b_dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptorOnchipDataType(b_desc, b_compute_type); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, + c_dims); + INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); + c_desc->onchip_dtype = MLUOP_DTYPE_FLOAT; + + DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, + cnnl_handle); // convert to cnnl_handle + // convert to cnnl_tensor_descriptor + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc); + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc); + cnnlMatMulAlgo_t algo; + CALL_CNNL(cnnlMatMulAlgoCreate(&algo)); + cnnlMatMulDescriptor_t bmm_bcast_desc; + bool use_stride = false; + auto cast_mode = CNNL_MATMUL_BYPASS_QUANTIZE; + CALL_CNNL(cnnlMatMulDescCreate(&bmm_bcast_desc)); + CALL_CNNL(cnnlSetMatMulDescAttr(bmm_bcast_desc, CNNL_MATMUL_DESC_TRANSA, + &trans_a_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(bmm_bcast_desc, CNNL_MATMUL_DESC_TRANSB, + &trans_b_int, 
sizeof(int32_t))); + + cnnlMatMulHeuristicResult_t heuristic_result; + CALL_CNNL(cnnlCreateMatMulHeuristicResult(&heuristic_result)); + int requested_algo_count = 1, return_algo_count = 0; + cnnlGetBatchMatMulAlgoHeuristic( + cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, + requested_algo_count, &heuristic_result, &return_algo_count); + cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + // destroy descriptor // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc); @@ -375,9 +531,9 @@ mluOpStatus_t fftBatchMatMulBcast( INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); // set descriptor - int a_dims[2]; - int b_dims[3] = {batch, k, n}; - int c_dims[3] = {batch, m, n}; + int64_t a_dims[2]; + int64_t b_dims[3] = {batch, k, n}; + int64_t c_dims[3] = {batch, m, n}; if (is_trans_a) { a_dims[0] = k; a_dims[1] = m; @@ -392,18 +548,18 @@ mluOpStatus_t fftBatchMatMulBcast( b_dims[1] = k; b_dims[2] = n; } - status = mluOpSetTensorDescriptor(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2, - a_dims); + status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, data_type, 2, + a_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType(a_desc, a_compute_type); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = mluOpSetTensorDescriptor(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, - b_dims); + status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, + b_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); status = mluOpSetTensorDescriptorOnchipDataType(b_desc, b_compute_type); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = mluOpSetTensorDescriptor(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, - c_dims); + status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, data_type, 3, + c_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); c_desc->onchip_dtype = MLUOP_DTYPE_FLOAT; @@ -415,10 +571,36 @@ mluOpStatus_t fftBatchMatMulBcast( DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc); - CALL_CNNL(cnnlBatchMatMulBCast(cnnl_handle, is_trans_a, is_trans_b, - cnnl_a_desc, a_ptr, cnnl_b_desc, b_ptr, NULL, - 0, cnnl_c_desc, c_ptr)); + cnnlMatMulAlgo_t algo; + CALL_CNNL(cnnlMatMulAlgoCreate(&algo)); + cnnlMatMulDescriptor_t bmm_bcast_desc; + bool use_stride = false; + auto cast_mode = CNNL_MATMUL_BYPASS_QUANTIZE; + CALL_CNNL(cnnlMatMulDescCreate(&bmm_bcast_desc)); + CALL_CNNL(cnnlSetMatMulDescAttr(bmm_bcast_desc, CNNL_MATMUL_DESC_TRANSA, + &trans_a_int, sizeof(int32_t))); + CALL_CNNL(cnnlSetMatMulDescAttr(bmm_bcast_desc, CNNL_MATMUL_DESC_TRANSB, + &trans_b_int, sizeof(int32_t))); + + cnnlMatMulHeuristicResult_t heuristic_result; + CALL_CNNL(cnnlCreateMatMulHeuristicResult(&heuristic_result)); + alpha = 1.0; + beta = 0.0; + int requested_algo_count = 1, return_algo_count = 0; + cnnlGetBatchMatMulAlgoHeuristic( + cnnl_handle, bmm_bcast_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc, NULL, + requested_algo_count, &heuristic_result, &return_algo_count); + cnnlGetBatchMatMulHeuristicResult(heuristic_result, algo, &workspace_size); + if (workspace_size > 0) { + CNRT_CHECK(cnrtMalloc((void **)&workspace, workspace_size)); + } else { + CNRT_CHECK(cnrtMalloc((void **)&workspace, m * n * sizeof(float))); + } + CALL_CNNL(cnnlBatchMatMulBCast_v2(cnnl_handle, bmm_bcast_desc, algo, &alpha, + cnnl_a_desc, a_ptr, cnnl_b_desc, 
b_ptr, + &beta, cnnl_c_desc, c_ptr, + (void *)workspace, workspace_size)); // destroy descriptor // destroy cnnl descriptor DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc); @@ -432,11 +614,10 @@ mluOpStatus_t fftBatchMatMulBcast( mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle, size_t &workspace_size, int dim_num, - int ori_dims[], int permute[], + int64_t ori_dims[], int permute[], mluOpDataType_t data_type, const std::string api) { mluOpStatus_t status = MLUOP_STATUS_SUCCESS; - // create descriptor mluOpTensorDescriptor_t input_desc = nullptr; status = mluOpCreateTensorDescriptor(&input_desc); @@ -448,13 +629,11 @@ mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle, DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle // set descriptor - status = mluOpSetTensorDescriptor(input_desc, MLUOP_LAYOUT_ARRAY, data_type, - dim_num, ori_dims); + status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, + data_type, dim_num, ori_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc); - + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR_v2(input_desc, cnnl_input_desc); CALL_CNNL(cnnlSetTransposeDescriptor(trans_desc, dim_num, permute)); - // get workspace CALL_CNNL(cnnlGetTransposeWorkspaceSize(cnnl_handle, cnnl_input_desc, trans_desc, &workspace_size)); @@ -463,15 +642,14 @@ mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle, DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_input_desc); CALL_CNNL(cnnlDestroyTransposeDescriptor(trans_desc)); DESTROY_CNNL_HANDLE(cnnl_handle); - return status; } -mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, int ori_dims[], - int transed_dims[], int permute[], void *ori_ptr, - void *transed_ptr, mluOpDataType_t data_type, - void *workspace, size_t workspace_size, - const std::string api) { +mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, + int64_t ori_dims[], int64_t transed_dims[], + int permute[], void *ori_ptr, void *transed_ptr, + mluOpDataType_t data_type, void *workspace, + size_t workspace_size, const std::string api) { mluOpStatus_t status = MLUOP_STATUS_SUCCESS; // create descriptor @@ -483,20 +661,24 @@ mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, int ori_dims[], INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); // set descriptor - status = mluOpSetTensorDescriptor(input_desc, MLUOP_LAYOUT_ARRAY, data_type, - dim_num, ori_dims); + status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, + data_type, dim_num, ori_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); - status = mluOpSetTensorDescriptor(transed_input_desc, MLUOP_LAYOUT_ARRAY, - data_type, dim_num, transed_dims); + status = mluOpSetTensorDescriptor_v2(transed_input_desc, MLUOP_LAYOUT_ARRAY, + data_type, dim_num, transed_dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle); // convert to cnnl_handle - // convert to cnnl_tensor_descriptor - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_input_desc); - DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(transed_input_desc, - cnnl_transed_input_desc); + // convert to cnnl_tensor_descriptor + DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR_v2(input_desc, cnnl_input_desc); + // cnnlTensorDescriptor_t cnnl_input_desc = NULL; + // CALL_CNNL(cnnlSetTensorDescriptor_v2(cnnl_input_desc, CNNL_LAYOUT_ARRAY, + // data_type, + // dim_num, ori_dims)); + 
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR_v2(transed_input_desc, + cnnl_transed_input_desc); // compute transpose cnnlTransposeDescriptor_t trans_desc = nullptr; CALL_CNNL(cnnlCreateTransposeDescriptor(&trans_desc)); diff --git a/kernels/fft/common/fft_basic_ops.h b/kernels/fft/common/fft_basic_ops.h index 8d28d179b..89b84350c 100644 --- a/kernels/fft/common/fft_basic_ops.h +++ b/kernels/fft/common/fft_basic_ops.h @@ -65,6 +65,14 @@ mluOpStatus_t fftQuantMatMul(mluOpHandle_t handle, int m, int k, int n, mluOpDataType_t data_type, void *workspace, size_t workspace_size, const std::string api); +mluOpStatus_t fftGetBatchMatMulBcastWorkspaceSize( + mluOpHandle_t handle, int m, int k, int n, int batch, void *a_ptr, + void *a_pos, void *a_scale, void *b_ptr, void *b_pos, void *b_scale, + void *c_ptr, bool is_trans_a, bool is_trans_b, float alpha, float beta, + mluOpDataType_t a_compute_type, mluOpDataType_t b_compute_type, + mluOpDataType_t data_type, void *workspace, size_t workspace_size, + const std::string api); + mluOpStatus_t fftBatchMatMulBcast(mluOpHandle_t handle, int m, int k, int n, int batch, void *a_ptr, void *a_pos, void *a_scale, void *b_ptr, void *b_pos, @@ -77,15 +85,15 @@ mluOpStatus_t fftBatchMatMulBcast(mluOpHandle_t handle, int m, int k, int n, mluOpStatus_t fftGetTransposeWorkspaceSize(mluOpHandle_t handle, size_t &workspace_size, int dim_num, - int ori_dims[], int permute[], + int64_t ori_dims[], int permute[], mluOpDataType_t data_type, const std::string api); -mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, int ori_dims[], - int transed_dims[], int permute[], void *ori_ptr, - void *transed_ptr, mluOpDataType_t data_type, - void *workspace, size_t workspace_size, - const std::string api); +mluOpStatus_t fftTranspose(mluOpHandle_t handle, int dim_num, + int64_t ori_dims[], int64_t transed_dims[], + int permute[], void *ori_ptr, void *transed_ptr, + mluOpDataType_t data_type, void *workspace, + size_t workspace_size, const std::string api); mluOpStatus_t fftGetOptensorWorkspaceSize(mluOpHandle_t handle, size_t &workspace_size, int elem_num, diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp index 0629ff2fa..986b98f12 100644 --- a/kernels/fft/fft.cpp +++ b/kernels/fft/fft.cpp @@ -1847,7 +1847,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C1D( fft_plan->is_batch_contiguous = (fft_plan->idist == 1 && fft_plan->odist == 1 && fft_plan->istride == fft_plan->batch && - fft_plan->ostride == fft_plan->batch); + fft_plan->ostride == fft_plan->batch) && + (fft_plan->n[0] == fft_plan->inembed[0]); mluOpAllocateC2C1D(handle, fft_plan, input_desc, output_desc, n[0]); int is_row_major = !fft_plan->is_batch_contiguous; fftTwoStepFactor(handle, fft_plan, n[0], fft_plan->factors, is_row_major, @@ -2625,7 +2626,10 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany( fft_plan->fft_type == CNFFT_COMPLEX_HALF2COMPLEX_HALF || n[0] == 1) { fft_plan->prime = 1; } - fft_plan->prime = fft_plan->prime || (n[0] <= 2 && rank == 1); + fft_plan->prime = + fft_plan->prime || + ((n[0] <= 2 || n[0] == 400 || n[0] == 512 || n[0] == 48000) && rank == 1); + /* * decision part */ @@ -2812,6 +2816,17 @@ mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(mluOpHandle_t handle, bool is_in_place = (input == output); VLOG(5) << exec_api << ": in place ? 
" << is_in_place; + + if (fft_plan->rank == 2 && + (mluop::strideCaseWithNotConsistentDense(1, fft_plan->input_desc) || + mluop::strideCaseWithNotConsistentDense(1, fft_plan->output_desc))) { + LOG(ERROR) + << exec_api + << ": 2d stride case with not consistent dense is not supported now."; + status = MLUOP_STATUS_BAD_PARAM; + GEN_CASE_END(); + return status; + } switch (fft_plan->fft_type) { // r2c case CNFFT_HALF2COMPLEX_HALF: @@ -2829,8 +2844,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(mluOpHandle_t handle, status = execRFFT1d(handle, fft_plan, input, scale_factor, workspace, output); } else if (fft_plan->rank == 2) { - status = execRFFT2d(handle, fft_plan, input, scale_factor, workspace, - output); + if (fft_plan->inembed[1] > fft_plan->n[1]) { + LOG(ERROR) << exec_api + << ": inembed[1] > fft_plan->n[1] is not supported now"; + status = MLUOP_STATUS_BAD_PARAM; + + } else { + status = execRFFT2d(handle, fft_plan, input, scale_factor, workspace, + output); + } } else if (fft_plan->rank == 3) { // TODO(who) status = MLUOP_STATUS_NOT_SUPPORTED; @@ -2852,8 +2874,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(mluOpHandle_t handle, status = execFFT1d(handle, fft_plan, input, scale_factor, workspace, output, direction); } else if (fft_plan->rank == 2) { - status = execFFT2d(handle, fft_plan, input, scale_factor, workspace, - output, direction); + if (fft_plan->inembed[1] > fft_plan->n[1]) { + LOG(ERROR) << exec_api + << ": inembed[1] > fft_plan->n[1] is not supported now"; + status = MLUOP_STATUS_BAD_PARAM; + + } else { + status = execFFT2d(handle, fft_plan, input, scale_factor, workspace, + output, direction); + } } else if (fft_plan->rank == 3) { // TODO(who) status = MLUOP_STATUS_NOT_SUPPORTED; @@ -2862,7 +2891,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(mluOpHandle_t handle, // c2r case CNFFT_COMPLEX_HALF2HALF: case CNFFT_COMPLEX_FLOAT2FLOAT: { - if (((fft_plan->idist * 2) < fft_plan->odist) && is_in_place) { + if (((fft_plan->idist * 2) > fft_plan->odist) && is_in_place) { LOG(ERROR) << exec_api << ": output overwritten may occur during an in-place " @@ -2875,8 +2904,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpExecFFT(mluOpHandle_t handle, status = execIRFFT1d(handle, fft_plan, input, scale_factor, workspace, output); } else if (fft_plan->rank == 2) { - status = execIRFFT2d(handle, fft_plan, input, scale_factor, workspace, - output); + if (fft_plan->inembed[1] > (fft_plan->n[1] / 2 + 1)) { + LOG(ERROR) << exec_api + << ": inembed[1] > fft_plan->n[1] is not supported now"; + status = MLUOP_STATUS_BAD_PARAM; + + } else { + status = execIRFFT2d(handle, fft_plan, input, scale_factor, workspace, + output); + } } else if (fft_plan->rank == 3) { // TODO(who) status = MLUOP_STATUS_NOT_SUPPORTED; diff --git a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu index 2d6e72659..55b3a37b6 100644 --- a/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu +++ b/kernels/fft/fft_optm_device/fft_two-level_network_c2c_device.mlu @@ -116,7 +116,7 @@ mluOpStatus_t MLUOP_WIN_API kernelFFT2dButterflyColumn( mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyColumn( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, mluOpFFTPlan_t fft_plan, const int direction, FFTFlag flag) { - VLOG(5) << "Launch Kernel kernelFFT1dButterflyRow <>>"; if (direction == FFT_FORWARD) { diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp index 590d46f1b..c68680f27 100644 --- 
a/kernels/fft/irfft/irfft_host.cpp +++ b/kernels/fft/irfft/irfft_host.cpp @@ -155,7 +155,7 @@ mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { // input trans workspace: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + // 1) const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; int trans_permute[trans_dim_num] = {1, 0}; size_t trans_workspace_size = 0; status = fftGetTransposeWorkspaceSize(handle, trans_workspace_size, @@ -332,7 +332,8 @@ mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { // transpose workspace: batch * (n / 2 + 1) * 2 --> 2 * batch * (n / 2 + // 1) concat workspace: concat do not need workspace now const int trans_1st_dim_num = 2; - int trans_1st_input_dims[trans_1st_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_1st_input_dims[trans_1st_dim_num] = {padded_input_num, + COMPLEX}; int trans_1st_permute[trans_1st_dim_num] = {1, 0}; size_t trans_1st_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -348,7 +349,7 @@ mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += transed_input_size; // input trans workspace: 2 * batch * L * 2^m --> 2 * batch * 2^m * L const int trans_2nd_dim_num = 3; - int trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m}; + int64_t trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m}; int trans_2nd_permute[trans_2nd_dim_num] = {0, 2, 1}; size_t trans_2nd_workspace_size = 0; status = fftGetTransposeWorkspaceSize( @@ -379,12 +380,27 @@ mluOpStatus_t makeIRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += matmul_output_size; // matmul workspace size_t matmul_workspace_size = 0; - status = fftGetQuantizeMatMulWorkspaceSize( - handle, matmul_workspace_size, batch * m, L, L, false, true, - in_e_dtype, in_e_dtype, in_r_dtype, api); - fft_plan->matmul_addrs.internal_workspace_size = - std::max(fft_plan->matmul_addrs.internal_workspace_size, - matmul_workspace_size); + if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) { + status = fftGetQuantizeMatMulWorkspaceSize( + handle, matmul_workspace_size, batch * m, L, L, false, true, + in_e_dtype, in_e_dtype, in_r_dtype, api); + fft_plan->matmul_addrs.internal_workspace_size = + std::max(fft_plan->matmul_addrs.internal_workspace_size, + matmul_workspace_size); + } else { + status = fftGetBatchMatMulBcastWorkspaceSize( + handle, 2 * L, L, m, batch * 2, + fft_plan->matmul_addrs.dft_re_matrix_addr, + fft_plan->matmul_addrs.dft_pos_addr, + fft_plan->matmul_addrs.dft_scale_addr, + fft_plan->matmul_addrs.input_merged_addr, + fft_plan->matmul_addrs.input_pos_addr, + fft_plan->matmul_addrs.input_scale_addr, + fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false, 1.0, + 0.0, in_e_dtype, in_e_dtype, in_r_dtype, + fft_plan->matmul_addrs.internal_workspace_addr, + fft_plan->matmul_addrs.internal_workspace_size, api); + } // optensor workspace size_t optensor_workspace_size = 0; status = @@ -747,7 +763,8 @@ static void configureIRFFT1dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += buffer_size * 2; - if (fft_plan->is_input_contiguous) { + if (fft_plan->is_input_contiguous && + fft_plan->inembed[0] <= fft_plan->n[0] / 2 + 1) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -810,7 +827,6 
@@ static mluOpStatus_t padIRFFT1dContiguousInput(mluOpHandle_t handle, std::string api = "[mluOpExecFFT]"; VLOG(5) << "into padIRFFT1dContiguousInput"; mluOpStatus_t status = MLUOP_STATUS_SUCCESS; - mluOpDataType_t in_c_dtype = fft_plan->input_dtype; mluOpDataType_t in_r_dtype = (in_c_dtype == MLUOP_DTYPE_COMPLEX_HALF) ? MLUOP_DTYPE_HALF @@ -827,8 +843,7 @@ static mluOpStatus_t padIRFFT1dContiguousInput(mluOpHandle_t handle, INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int in_dim_num = 2; - int64_t dims[in_dim_num] = { - batch, std::min(FFT_HALF(n), fft_plan->inembed[0]) * COMPLEX}; + int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0] * COMPLEX}; status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY, in_r_dtype, in_dim_num, dims); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); @@ -839,8 +854,9 @@ static mluOpStatus_t padIRFFT1dContiguousInput(mluOpHandle_t handle, INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); const int pad_dim_num = 4; + int paddings[pad_dim_num] = { - 0, 0, 0, std::max((FFT_HALF(n) - fft_plan->inembed[0]), 0) * COMPLEX}; + 0, 0, 0, (FFT_HALF(n) - fft_plan->inembed[0]) * COMPLEX}; uint64_t padding_value = 0x00000000; DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, @@ -989,8 +1005,8 @@ static mluOpStatus_t mergeIRFFT1dInput(mluOpHandle_t handle, VLOG(5) << "launch mluOpTranspose for input"; int padded_input_num = batch * FFT_HALF(n); const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; - int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; + int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; int trans_permute[trans_dim_num] = {1, 0}; status = @@ -1022,9 +1038,9 @@ static mluOpStatus_t mergeIRFFT1dInput(mluOpHandle_t handle, int dim1_begin = (n % 2) ? 
-1 : -2; int dim1_end = -FFT_HALF(n); - int begin[ss_dim_num] = {0, dim1_begin}; - int end[ss_dim_num] = {COMPLEX * batch, dim1_end}; - int stride[ss_dim_num] = {1, -1}; + int64_t begin[ss_dim_num] = {0, dim1_begin}; + int64_t end[ss_dim_num] = {COMPLEX * batch, dim1_end}; + int64_t stride[ss_dim_num] = {1, -1}; void *ss_input_addr = fft_plan->matmul_addrs.input_transed_addr; void *ss_output_addr = fft_plan->matmul_addrs.input_reversed_addr; @@ -1037,9 +1053,9 @@ static mluOpStatus_t mergeIRFFT1dInput(mluOpHandle_t handle, cnnl_ss_input_desc); DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(ss_output_desc, cnnl_ss_output_desc); - CALL_CNNL(cnnlStridedSlice(cnnl_handle, cnnl_ss_input_desc, ss_input_addr, - begin, end, stride, cnnl_ss_output_desc, - ss_output_addr)); + CALL_CNNL(cnnlStridedSlice_v2(cnnl_handle, cnnl_ss_input_desc, + ss_input_addr, begin, end, stride, + cnnl_ss_output_desc, ss_output_addr)); // reversed input imag part mul -1 int reversed_input_num = batch * (n - FFT_HALF(n)); @@ -1139,8 +1155,8 @@ static mluOpStatus_t transposeIRFFT1dPaddedInput(mluOpHandle_t handle, VLOG(5) << "launch mluOpTranspose for input MATMUL"; int padded_input_num = batch * FFT_HALF(n); const int trans_dim_num = 2; - int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; - int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; + int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX}; + int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num}; int trans_permute[trans_dim_num] = {1, 0}; status = @@ -1157,8 +1173,8 @@ static mluOpStatus_t transposeIRFFT1dPaddedInput(mluOpHandle_t handle, // 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L const int trans_dim_num = 3; - int trans_input_dims[trans_dim_num] = {COMPLEX * batch, L, m}; - int trans_output_dims[trans_dim_num] = {COMPLEX * batch, m, L}; + int64_t trans_input_dims[trans_dim_num] = {COMPLEX * batch, L, m}; + int64_t trans_output_dims[trans_dim_num] = {COMPLEX * batch, m, L}; int trans_permute[trans_dim_num] = {0, 2, 1}; status = diff --git a/kernels/fft/rfft/rfft_host.cpp b/kernels/fft/rfft/rfft_host.cpp index a8ee8c81a..f64354fd8 100644 --- a/kernels/fft/rfft/rfft_host.cpp +++ b/kernels/fft/rfft/rfft_host.cpp @@ -254,7 +254,7 @@ mluOpStatus_t makeRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += transed_input_size; // input trans workspace: batch * L * 2^m --> batch * 2^m * L const int trans_dim_num = 3; - int trans_input_dims[trans_dim_num] = {batch, L, m}; + int64_t trans_input_dims[trans_dim_num] = {batch, L, m}; int trans_permute[trans_dim_num] = {0, 2, 1}; size_t trans_workspace_size = 0; status = fftGetTransposeWorkspaceSize(handle, trans_workspace_size, @@ -283,12 +283,29 @@ mluOpStatus_t makeRFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) { fft_plan->workspace_size += matmul_output_size; // matmul workspace size_t matmul_workspace_size = 0; - status = fftGetQuantizeMatMulWorkspaceSize( - handle, matmul_workspace_size, batch * m, L, L, false, true, - in_e_dtype, in_e_dtype, in_r_dtype, api); - fft_plan->matmul_addrs.internal_workspace_size = - std::max(fft_plan->matmul_addrs.internal_workspace_size, - matmul_workspace_size); + if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) { + status = fftGetQuantizeMatMulWorkspaceSize( + handle, matmul_workspace_size, batch * m, L, L, false, true, + in_e_dtype, in_e_dtype, in_r_dtype, api); + fft_plan->matmul_addrs.internal_workspace_size = + 
std::max(fft_plan->matmul_addrs.internal_workspace_size, + matmul_workspace_size); + } else { + status = fftGetBatchMatMulBcastWorkspaceSize( + handle, + L <= fft_plan->L_sub ? (2 * L) + : (2 * (PAD_UP(L / 2, fft_plan->L_sub) + 1)), + L, m, batch, fft_plan->matmul_addrs.dft_re_matrix_addr, + fft_plan->matmul_addrs.dft_pos_addr, + fft_plan->matmul_addrs.dft_scale_addr, + fft_plan->matmul_addrs.input_pad_addr, + fft_plan->matmul_addrs.input_pos_addr, + fft_plan->matmul_addrs.input_scale_addr, + fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false, 1.0, + 0.0, in_e_dtype, in_e_dtype, in_r_dtype, + fft_plan->matmul_addrs.internal_workspace_addr, + fft_plan->matmul_addrs.internal_workspace_size, api); + } // output merge workspace size_t merge_workspace_size = matmul_output_size; @@ -389,7 +406,7 @@ static void configureRFFT1dWorkspaceAddrs(mluOpHandle_t handle, fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset; offset += buffer_size * 2; - if (fft_plan->is_input_contiguous) { + if (fft_plan->is_input_contiguous && fft_plan->inembed[0] <= fft_plan->n[0]) { fft_plan->mlu_addrs.input = input; } else { fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset; @@ -863,8 +880,8 @@ static mluOpStatus_t transposeRFFT1dPaddedInput(mluOpHandle_t handle, int m = (1 << fft_plan->m); const int trans_dim_num = 3; - int trans_input_dims[trans_dim_num] = {batch, L, m}; - int trans_output_dims[trans_dim_num] = {batch, m, L}; + int64_t trans_input_dims[trans_dim_num] = {batch, L, m}; + int64_t trans_output_dims[trans_dim_num] = {batch, m, L}; int trans_permute[trans_dim_num] = {0, 2, 1}; status = @@ -1221,7 +1238,6 @@ mluOpStatus_t execRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, void *workspace, void *output) { mluOpStatus_t status = MLUOP_STATUS_SUCCESS; std::string api = "[mluOpExecFFT]"; - if (fft_plan->prime) { configureRFFT1dMatmulWorkspaceAddrs(handle, fft_plan, (void *)input, workspace, output); diff --git a/kernels/utils/cnnl_helper.cpp b/kernels/utils/cnnl_helper.cpp index f814c1823..9e5787e7e 100644 --- a/kernels/utils/cnnl_helper.cpp +++ b/kernels/utils/cnnl_helper.cpp @@ -105,6 +105,71 @@ cnnlStatus_t mluOpConvertDescriptor(mluOpTensorDescriptor_t desc, return CNNL_STATUS_SUCCESS; } +cnnlStatus_t mluOpConvertDescriptor_v2(mluOpTensorDescriptor_t desc, + cnnlTensorDescriptor_t _desc) { + if (desc == NULL) { + return CNNL_STATUS_SUCCESS; + } + mluOpDataType_t dtype, onchip_dtype; + mluOpTensorLayout_t layout; + int tensor_dim; + CHECK_FUNC_RETURN( + mluOpGetTensorDescriptor(desc, &layout, &dtype, &tensor_dim, NULL), + MLUOP_STATUS_SUCCESS, "MLUOPS get tensor descriptor failed.", + CNNL_STATUS_INTERNAL_ERROR); + CHECK_FUNC_RETURN(mluOpGetTensorDescriptorOnchipDataType(desc, &onchip_dtype), + MLUOP_STATUS_SUCCESS, + "MLUOPS get tensor descriptor onchip type failed.", + CNNL_STATUS_INTERNAL_ERROR); + int64_t *dims = new int64_t[tensor_dim]; + int64_t *strides = new int64_t[tensor_dim]; + CHECK_FUNC_RETURN(mluOpGetTensorDescriptorEx_v2(desc, &layout, &dtype, + &tensor_dim, dims, strides), + MLUOP_STATUS_SUCCESS, + "MLUOPS get tensor descriptor Ex failed.", + CNNL_STATUS_INTERNAL_ERROR); + CHECK_FUNC_RETURN( + cnnlSetTensorDescriptor_v2( + _desc, + mluOpConvertEnum<mluOpTensorLayout_t, cnnlTensorLayout_t>(layout), + mluOpConvertEnum<mluOpDataType_t, cnnlDataType_t>(dtype), tensor_dim, + dims), + CNNL_STATUS_SUCCESS, "Internal set tensor descriptor failed.", + CNNL_STATUS_INTERNAL_ERROR); + CHECK_FUNC_RETURN( + cnnlSetTensorDescriptorEx_v2( + _desc, + mluOpConvertEnum<mluOpTensorLayout_t, cnnlTensorLayout_t>(layout), + mluOpConvertEnum<mluOpDataType_t, cnnlDataType_t>(dtype), tensor_dim, + dims, strides), +
CNNL_STATUS_SUCCESS, "Internal set tensor descriptor Ex failed.", + CNNL_STATUS_INTERNAL_ERROR); + CHECK_FUNC_RETURN( + cnnlSetTensorDescriptorOnchipDataType( + _desc, + mluOpConvertEnum(onchip_dtype)), + CNNL_STATUS_SUCCESS, "Internal set tensor descriptor Ex failed.", + CNNL_STATUS_INTERNAL_ERROR); + int position; + float scale; + int offset; + CHECK_FUNC_RETURN( + mluOpGetTensorDescriptorPositionScaleAndOffset(desc, &position, &scale, + &offset), + MLUOP_STATUS_SUCCESS, + "MLUOPS get tensor descriptor position scale and offset failed.", + CNNL_STATUS_INTERNAL_ERROR); + CHECK_FUNC_RETURN( + cnnlSetTensorDescriptorPositionScaleAndOffset(_desc, position, scale, + offset), + CNNL_STATUS_SUCCESS, + "Internal set tensor descriptor position scale and offset failed.", + CNNL_STATUS_INTERNAL_ERROR); + delete[] dims; + delete[] strides; + return CNNL_STATUS_SUCCESS; +} + cnnlStatus_t mluOpConvertHandle(mluOpHandle_t handle, cnnlHandle_t _handle) { cnrtQueue_t queue; CHECK_FUNC_RETURN(mluOpGetQueue(handle, &queue), MLUOP_STATUS_SUCCESS, diff --git a/kernels/utils/cnnl_helper.h b/kernels/utils/cnnl_helper.h index f9c43f994..32ca006ab 100644 --- a/kernels/utils/cnnl_helper.h +++ b/kernels/utils/cnnl_helper.h @@ -60,6 +60,9 @@ void mluOpCnnlCheck(mluOpStatus_t result, char const *const func, cnnlStatus_t mluOpConvertDescriptor(mluOpTensorDescriptor_t desc, cnnlTensorDescriptor_t _desc); +cnnlStatus_t mluOpConvertDescriptor_v2(mluOpTensorDescriptor_t desc, + cnnlTensorDescriptor_t _desc); + cnnlStatus_t mluOpConvertHandle(mluOpHandle_t handle, cnnlHandle_t _handle); // Pointer type force convert @@ -87,6 +90,27 @@ DTYPE mluOpPointerForceConvert(STYPE ptr); } \ } +// TensorDescriptor +#define DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR_v2(desc, _desc) \ + cnnlTensorDescriptor_t _desc; \ + { \ + if (desc != NULL) { \ + cnnlStatus_t ret = cnnlCreateTensorDescriptor(&_desc); \ + if (ret != CNNL_STATUS_SUCCESS) { \ + LOG(ERROR) << "CNNL_HELPER: CNNL creates tensor descriptor failed."; \ + return MLUOP_STATUS_INTERNAL_ERROR; \ + } \ + ret = mluOpConvertDescriptor_v2(desc, _desc); \ + if (ret != CNNL_STATUS_SUCCESS) { \ + LOG(ERROR) \ + << "CNNL_HELPER: Internal convert tensor descriptor failed."; \ + return MLUOP_STATUS_INTERNAL_ERROR; \ + } \ + } else { \ + _desc = NULL; \ + } \ + } + #define CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(desc, _desc) \ { \ cnnlStatus_t ret = cnnlCreateTensorDescriptor(&_desc); \ diff --git a/mlu_op.h b/mlu_op.h index a39beaf49..ea7e162ce 100644 --- a/mlu_op.h +++ b/mlu_op.h @@ -14148,12 +14148,8 @@ typedef struct mluOpFFTStruct *mluOpFFTPlan_t; * Otherwise, the memory leak may occur. * * @par Note - * - This function only supports 1D FFT currently. 2D FFT and 3D FFT + * - This function only supports 1D and 2D FFT currently. 3D FFT * will be supported in the future. - * - When the data type of input is float or complex_float, the 1D FFT length should be equal to: - * length = \f$base * 2^ {m}\f$, and the base should be less than or equal to 4096. - * - When the data type of input is half or complex_half, the 1D FFT length should be equal to: - * length = \f$2^{m}\f$. * * @par Example. * - None. @@ -14374,6 +14370,10 @@ mluOpSetFFTReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan, void *rese * when planning to use FFT with half-precision floating-point data, as it limits the flexibility compared to float data * types. * + * - For FFT 2D: + * - real-to-complex FFT: Output numbers / 2 + 1 should not be less than input numbers. 
+ * - complex-to-complex FFT: Output numbers should not be less than input numbers. + * - complex-to-real FFT: Output numbers / 2 + 1 should not be less than input numbers. * * @par API Dependency * - Before calling this function, you need to call the ::mluOpCreateFFTPlan diff --git a/test/mlu_op_gtest/api_gtest/include/api_test_tools.h b/test/mlu_op_gtest/api_gtest/include/api_test_tools.h index f15c074eb..91c3ba5d0 100644 --- a/test/mlu_op_gtest/api_gtest/include/api_test_tools.h +++ b/test/mlu_op_gtest/api_gtest/include/api_test_tools.h @@ -59,7 +59,39 @@ class MLUOpTensorParam { mluOpDataType_t dtype_; int dim_nb_; std::vector<int> dim_size_; + std::vector<int64_t> dim_size_int64_; std::vector<int> dim_stride_; + std::vector<int64_t> dim_stride_int64_; + mluOpDataType_t onchip_dtype_; +}; + +class MLUOpTensorParamInt64 { + public: + MLUOpTensorParamInt64(mluOpTensorLayout_t layout, mluOpDataType_t dtype, + int64_t dim_nb, std::vector<int64_t> dim_size, + std::vector<int64_t> dim_stride = {}, + mluOpDataType_t onchip_dtype = MLUOP_DTYPE_INVALID) { + layout_ = layout; + dtype_ = dtype; + dim_nb_ = dim_nb; + dim_size_ = dim_size; + dim_stride_ = dim_stride; + onchip_dtype_ = onchip_dtype; + } + + mluOpTensorLayout_t get_layout() { return layout_; } + mluOpDataType_t get_dtype() { return dtype_; } + int64_t get_dim_nb() { return dim_nb_; } + std::vector<int64_t> get_dim_size() { return dim_size_; } + std::vector<int64_t> get_dim_stride() { return dim_stride_; } + mluOpDataType_t get_onchip_dtype() { return onchip_dtype_; } + + private: + mluOpTensorLayout_t layout_; + mluOpDataType_t dtype_; + int64_t dim_nb_; + std::vector<int64_t> dim_size_; + std::vector<int64_t> dim_stride_; mluOpDataType_t onchip_dtype_; }; } // namespace mluopapitest diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp index 3833221d5..c2fd648a0 100644 --- a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp +++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_ExecFFT.cpp @@ -68,20 +68,21 @@ class fft_ExecFFT : public testing::Test { const int batch = 2000; const int n[rank] = {400}; const int ndim = rank + 1; - const int input_dim_size[ndim] = {batch, n[0] / 2 + 1}; - const int input_dim_stride[ndim] = {n[0] / 2 + 1, 1}; + const int64_t input_dim_size[ndim] = {batch, n[0] / 2 + 1}; + const int64_t input_dim_stride[ndim] = {n[0] / 2 + 1, 1}; - const int output_dim_size[ndim] = {batch, n[0] / 2 + 1}; - const int output_dim_stride[ndim] = {n[0] / 2 + 1, 1}; + const int64_t output_dim_size[ndim] = {batch, n[0] / 2 + 1}; + const int64_t output_dim_stride[ndim] = {n[0] / 2 + 1, 1}; mluOpCreateTensorDescriptor(&input_desc_); mluOpCreateTensorDescriptor(&output_desc_); - mluOpSetTensorDescriptorEx(input_desc_, MLUOP_LAYOUT_ARRAY, input_data_type, - ndim, input_dim_size, input_dim_stride); + mluOpSetTensorDescriptorEx_v2(input_desc_, MLUOP_LAYOUT_ARRAY, + input_data_type, ndim, input_dim_size, + input_dim_stride); mluOpSetTensorDescriptorOnchipDataType(input_desc_, execution_dtype); - mluOpSetTensorDescriptorEx(output_desc_, MLUOP_LAYOUT_ARRAY, - output_data_type, ndim, output_dim_size, - output_dim_stride); + mluOpSetTensorDescriptorEx_v2(output_desc_, MLUOP_LAYOUT_ARRAY, + output_data_type, ndim, output_dim_size, + output_dim_stride); size_t reservespaceSizeInBytes_ = 64; size_t workspaceSizeInBytes_ = 64; size_t *reservespace_size = &reservespaceSizeInBytes_; diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp index 
479097664..83eee3d15 100644 --- a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp +++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_MakeFFTPlanMany.cpp @@ -46,9 +46,9 @@ class fft_MakeFFTPlanMany : public testing::Test { if (input_desc) { MLUOP_CHECK(mluOpCreateTensorDescriptor(&input_desc_)); - std::vector<int> input_dims{1, 400}; - const int input_dim_stride[2] = {400, 1}; - MLUOP_CHECK(mluOpSetTensorDescriptorEx( + std::vector<int64_t> input_dims{1, 400}; + const int64_t input_dim_stride[2] = {400, 1}; + MLUOP_CHECK(mluOpSetTensorDescriptorEx_v2( input_desc_, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_FLOAT, input_dims.size(), input_dims.data(), input_dim_stride)); MLUOP_CHECK(mluOpSetTensorDescriptorOnchipDataType(input_desc_, @@ -57,9 +57,9 @@ class fft_MakeFFTPlanMany : public testing::Test { if (output_desc) { MLUOP_CHECK(mluOpCreateTensorDescriptor(&output_desc_)); - std::vector<int> output_dims{1, 201}; - const int output_dim_stride[2] = {201, 1}; - MLUOP_CHECK(mluOpSetTensorDescriptorEx( + std::vector<int64_t> output_dims{1, 201}; + const int64_t output_dim_stride[2] = {201, 1}; + MLUOP_CHECK(mluOpSetTensorDescriptorEx_v2( output_desc_, MLUOP_LAYOUT_ARRAY, MLUOP_DTYPE_COMPLEX_FLOAT, output_dims.size(), output_dims.data(), output_dim_stride)); } diff --git a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp index 0b2ac99cc..c47f15fcb 100644 --- a/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp +++ b/test/mlu_op_gtest/api_gtest/src/gtest/fft/fft_general.cpp @@ -32,8 +32,8 @@ #include "api_test_tools.h" namespace mluopapitest { -typedef std::tuple<MLUOpTensorParam, MLUOpTensorParam, int, int, mluOpDevType_t, mluOpStatus_t> +typedef std::tuple<MLUOpTensorParamInt64, MLUOpTensorParamInt64, int, int, mluOpDevType_t, mluOpStatus_t> FFTParams; class fft_general : public testing::TestWithParam<FFTParams> { @@ -44,9 +44,9 @@ class fft_general : public testing::TestWithParam<FFTParams> { MLUOP_CHECK(mluOpCreate(&handle_)); MLUOP_CHECK(mluOpCreateFFTPlan(&fft_plan_)); - MLUOpTensorParam input_params = std::get<0>(GetParam()); + MLUOpTensorParamInt64 input_params = std::get<0>(GetParam()); MLUOP_CHECK(mluOpCreateTensorDescriptor(&input_desc_)); - MLUOP_CHECK(mluOpSetTensorDescriptorEx( + MLUOP_CHECK(mluOpSetTensorDescriptorEx_v2( input_desc_, input_params.get_layout(), input_params.get_dtype(), input_params.get_dim_nb(), input_params.get_dim_size().data(), input_params.get_dim_stride().data())); @@ -54,10 +54,10 @@ class fft_general : public testing::TestWithParam<FFTParams> { MLUOP_CHECK(mluOpSetTensorDescriptorOnchipDataType( input_desc_, input_params.get_onchip_dtype())); - MLUOpTensorParam output_params = std::get<1>(GetParam()); + MLUOpTensorParamInt64 output_params = std::get<1>(GetParam()); MLUOP_CHECK(mluOpCreateTensorDescriptor(&output_desc_)); - MLUOP_CHECK(mluOpSetTensorDescriptorEx( + MLUOP_CHECK(mluOpSetTensorDescriptorEx_v2( output_desc_, output_params.get_layout(), output_params.get_dtype(), output_params.get_dim_nb(), output_params.get_dim_size().data(), output_params.get_dim_stride().data())); @@ -138,13 +138,14 @@ TEST_P(fft_general, negative) { EXPECT_TRUE(compute()); } INSTANTIATE_TEST_CASE_P( zero_element, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({0, 1}), std::vector<int>({1, 1}), - MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + std::vector<int64_t>({0, 1}), + std::vector<int64_t>({1, 1}), MLUOP_DTYPE_FLOAT}), + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1, 1}), std::vector<int>({1, 1})}), + std::vector<int64_t>({1, 1}), + 
std::vector<int64_t>({1, 1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_SUCCESS))); @@ -152,13 +153,14 @@ INSTANTIATE_TEST_CASE_P( negative_2_n, // half,complex_half,fft length can be broken down into 2^m fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 2, - std::vector<int>({1, 7}), std::vector<int>({1, 1}), - MLUOP_DTYPE_HALF}), - testing::Values(MLUOpTensorParam{ + std::vector<int64_t>({1, 7}), + std::vector<int64_t>({1, 1}), MLUOP_DTYPE_HALF}), + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 2, - std::vector<int>({1, 7}), std::vector<int>({1, 1})}), + std::vector<int64_t>({1, 7}), + std::vector<int64_t>({1, 1})}), testing::Values(1), testing::Values(7), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_NOT_SUPPORTED))); @@ -167,105 +169,107 @@ INSTANTIATE_TEST_CASE_P( negative_2_m_l, // float/complex_float,n>4096, fft length can be broken // down into 2^m*l fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2, - std::vector<int>({1, 4097}), std::vector<int>({1, 1}), - MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + std::vector<int64_t>({1, 4097}), + std::vector<int64_t>({1, 1}), MLUOP_DTYPE_FLOAT}), + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2, - std::vector<int>({1, 4097}), - std::vector<int>({1, 1})}), + std::vector<int64_t>({1, 4097}), + std::vector<int64_t>({1, 1})}), testing::Values(1), testing::Values(4097), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_NOT_SUPPORTED))); INSTANTIATE_TEST_CASE_P( negative_rank_1, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(4), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_N_le_0, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(0, -1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_batch, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2, - std::vector<int>({1, 1}), + std::vector<int64_t>({1, 1}), + std::vector<int64_t>({1, 1}), MLUOP_DTYPE_FLOAT}), - std::vector<int>({1, 1}), - MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 2, - std::vector<int>({2, 1}), std::vector<int>({1, 1})}), + 
std::vector<int64_t>({2, 1}), + std::vector<int64_t>({1, 1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_input_stride, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({-1}), + std::vector<int64_t>({1}), std::vector<int64_t>({-1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_output_stride, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({-1})}), + std::vector<int64_t>({1}), + std::vector<int64_t>({-1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_unsupported_dtype_combination, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_HALF, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); INSTANTIATE_TEST_CASE_P( negative_onchip_dtype, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_HALF}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); @@ -273,13 +277,13 @@ INSTANTIATE_TEST_CASE_P( // r2c,output!=n/2+1 INSTANTIATE_TEST_CASE_P( negative_r2c_length, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_FLOAT, 1, - std::vector<int>({4}), std::vector<int>({1}), + std::vector<int64_t>({4}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1})}), + std::vector<int64_t>({1}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(4), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); @@ -287,13 +291,13 @@ INSTANTIATE_TEST_CASE_P( // c2c,output != n 
INSTANTIATE_TEST_CASE_P( negative_c2c_length, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({1}), std::vector<int>({1}), + std::vector<int64_t>({1}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({0}), std::vector<int>({1})}), + std::vector<int64_t>({0}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(1), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM))); @@ -301,13 +305,13 @@ INSTANTIATE_TEST_CASE_P( // c2r,output!=n INSTANTIATE_TEST_CASE_P( negative_c2r_length, fft_general, - testing::Combine(testing::Values(MLUOpTensorParam{ + testing::Combine(testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_COMPLEX_FLOAT, 1, - std::vector<int>({4}), std::vector<int>({1}), + std::vector<int64_t>({4}), std::vector<int64_t>({1}), MLUOP_DTYPE_FLOAT}), - testing::Values(MLUOpTensorParam{ + testing::Values(MLUOpTensorParamInt64{ MLUOP_LAYOUT_NHWC, MLUOP_DTYPE_FLOAT, 1, - std::vector<int>({3}), std::vector<int>({1})}), + std::vector<int64_t>({3}), std::vector<int64_t>({1})}), testing::Values(1), testing::Values(4), testing::Values(MLUOP_UNKNOWN_DEVICE), testing::Values(MLUOP_STATUS_BAD_PARAM)));
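Editor's note: the matmul hunks above all migrate to the same CNNL "query-then-run" pattern. Below is a minimal sketch of that pattern for reference only; it is not part of the patch. The handle (cnnl_handle), the a/b/c/d tensor descriptors, and the device pointers a_ptr, b_ptr, c_ptr, d_ptr are assumed to already exist, and the CALL_CNNL/CNRT_CHECK error wrappers used by the real code are elided. Only CNNL/CNRT calls that appear in this diff are used.

  // Sketch (assumptions as stated above): pick an algorithm heuristically,
  // query its workspace requirement, then run cnnlMatMul_v2 with it.
  cnnlMatMulDescriptor_t desc;
  cnnlMatMulAlgo_t algo;
  cnnlMatMulHeuristicResult_t result;
  cnnlMatMulDescCreate(&desc);
  int32_t trans_a = 0, trans_b = 1;  // example transpose flags
  cnnlSetMatMulDescAttr(desc, CNNL_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t));
  cnnlSetMatMulDescAttr(desc, CNNL_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t));
  cnnlMatMulAlgoCreate(&algo);
  cnnlCreateMatMulHeuristicResult(&result);
  int32_t requested_algo_count = 1, return_algo_count = 0;
  // Ask CNNL for the best algorithm for these shapes/dtypes.
  cnnlGetMatMulAlgoHeuristic(cnnl_handle, desc, cnnl_a_desc, cnnl_b_desc,
                             cnnl_c_desc, cnnl_d_desc, nullptr,
                             requested_algo_count, &result, &return_algo_count);
  // Read back the workspace size that algorithm needs.
  size_t workspace_size = 0;
  cnnlGetMatMulHeuristicResult(result, algo, &workspace_size);
  void *workspace = nullptr;
  if (workspace_size > 0) {
    cnrtMalloc(&workspace, workspace_size);  // free once the queue is drained
  }
  float alpha = 1.0f, beta = 0.0f;
  // d = alpha * op(a) * op(b) + beta * c
  cnnlMatMul_v2(cnnl_handle, desc, algo, &alpha, cnnl_a_desc, a_ptr,
                cnnl_b_desc, b_ptr, &beta, cnnl_c_desc, c_ptr, workspace,
                workspace_size, cnnl_d_desc, d_ptr);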