Skip to content

Commit

Permalink
[Feature](mluOpExecFFT): fix accuracy and performance bug (#1086)
Browse files: browse the repository at this point in the history
Co-authored-by: root <[email protected]>
Co-authored-by: nike-tinghai <[email protected]>
Co-authored-by: niyuming <[email protected]>
Co-authored-by: PetrelYy <[email protected]>
  • Loading branch information
5 people authored Sep 25, 2024
1 parent 8478a98 commit df579dc
Show file tree
Hide file tree
Showing 14 changed files with 600 additions and 204 deletions.
74 changes: 43 additions & 31 deletions kernels/fft/c2c_fft/c2c_fft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
fft_plan->workspace_size += transed_input_size;
// input trans workspace: batch * n * 2 --> 2 * batch * n
const int in_trans_dim_num = 2;
int in_trans_input_dims[in_trans_dim_num] = {padded_input_num, COMPLEX};
int64_t in_trans_input_dims[in_trans_dim_num] = {padded_input_num,
COMPLEX};
int in_trans_permute[in_trans_dim_num] = {1, 0};
size_t in_trans_workspace_size = 0;
status = fftGetTransposeWorkspaceSize(
Expand Down Expand Up @@ -205,8 +206,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {

// output trans workspace: 2 * batch * n --> batch * n * 2
const int out_trans_dim_num = 2;
int out_trans_input_dims[out_trans_dim_num] = {COMPLEX,
per_matmul_output_num};
int64_t out_trans_input_dims[out_trans_dim_num] = {COMPLEX,
per_matmul_output_num};
int out_trans_permute[out_trans_dim_num] = {1, 0};
size_t out_trans_workspace_size = 0;
status = fftGetTransposeWorkspaceSize(
Expand Down Expand Up @@ -332,8 +333,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
// input trans workspace
// 1st transpose: batch * n * 2 --> 2 * batch * n
const int in_trans_1st_dim_num = 2;
int in_trans_1st_input_dims[in_trans_1st_dim_num] = {padded_input_num,
COMPLEX};
int64_t in_trans_1st_input_dims[in_trans_1st_dim_num] = {padded_input_num,
COMPLEX};
int in_trans_1st_permute[in_trans_1st_dim_num] = {1, 0};
size_t in_trans_1st_workspace_size = 0;
status = fftGetTransposeWorkspaceSize(
Expand All @@ -344,8 +345,8 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
in_trans_1st_workspace_size);
// 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
const int in_trans_2nd_dim_num = 3;
int in_trans_2nd_input_dims[in_trans_2nd_dim_num] = {COMPLEX * batch, L,
m};
int64_t in_trans_2nd_input_dims[in_trans_2nd_dim_num] = {COMPLEX * batch,
L, m};
int in_trans_2nd_permute[in_trans_2nd_dim_num] = {0, 2, 1};
size_t in_trans_2nd_workspace_size = 0;
status = fftGetTransposeWorkspaceSize(
Expand Down Expand Up @@ -375,12 +376,28 @@ mluOpStatus_t makeFFT1dPolicy(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan) {
fft_plan->workspace_size += matmul_times * per_matmul_output_size;
// matmul workspace
size_t matmul_workspace_size = 0;
status = fftGetQuantizeMatMulWorkspaceSize(
handle, matmul_workspace_size, batch * m, L, L, false, true,
in_e_dtype, in_e_dtype, in_r_dtype, api);
fft_plan->matmul_addrs.internal_workspace_size =
std::max(fft_plan->matmul_addrs.internal_workspace_size,
matmul_workspace_size);
if (fft_plan->fft_strategy == CNFFT_FUNC_COOLEY_TUKEY) {
status = fftGetQuantizeMatMulWorkspaceSize(
handle, matmul_workspace_size, batch * m, L, L, false, true,
in_e_dtype, in_e_dtype, in_r_dtype, api);
fft_plan->matmul_addrs.internal_workspace_size =
std::max(fft_plan->matmul_addrs.internal_workspace_size,
matmul_workspace_size);
} else {
status = fftGetBatchMatMulBcastWorkspaceSize(
handle, 2 * L, L, m * 2, batch,
fft_plan->matmul_addrs.dft_im_matrix_addr,
fft_plan->matmul_addrs.dft_pos_addr,
fft_plan->matmul_addrs.dft_scale_addr,
fft_plan->matmul_addrs.input_pad_addr,
fft_plan->matmul_addrs.input_pos_addr,
fft_plan->matmul_addrs.input_scale_addr,
fft_plan->matmul_addrs.matmul_re_mul_re_addr, false, false, 1.0,
0.0, in_e_dtype, in_e_dtype, in_r_dtype,
fft_plan->matmul_addrs.internal_workspace_addr,
fft_plan->matmul_addrs.internal_workspace_size, api);
}

// optensor workspace
size_t optensor_workspace_size = 0;
status =
Expand Down Expand Up @@ -1039,7 +1056,9 @@ static void configureFFT1dWorkspaceAddrs(mluOpHandle_t handle,
fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset;
offset += buffer_size * 2;

if (fft_plan->is_input_contiguous || fft_plan->is_batch_contiguous) {
if ((fft_plan->is_input_contiguous &&
fft_plan->inembed[0] <= fft_plan->n[0]) ||
fft_plan->is_batch_contiguous) {
fft_plan->mlu_addrs.input = input;
} else {
fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset;
Expand Down Expand Up @@ -1230,8 +1249,7 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle,
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

const int in_dim_num = 2;
int64_t dims[in_dim_num] = {batch,
std::min(fft_plan->inembed[0], n) * COMPLEX};
int64_t dims[in_dim_num] = {batch, fft_plan->inembed[0] * COMPLEX};
status = mluOpSetTensorDescriptor_v2(input_desc, MLUOP_LAYOUT_ARRAY,
in_r_dtype, in_dim_num, dims);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
Expand All @@ -1242,8 +1260,7 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle,
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

const int pad_dim_num = 4;
int paddings[pad_dim_num] = {
0, 0, 0, std::max((n - fft_plan->inembed[0]), 0) * COMPLEX};
int paddings[pad_dim_num] = {0, 0, 0, (n - fft_plan->inembed[0]) * COMPLEX};
uint64_t padding_value = 0x00000000;

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
Expand All @@ -1255,14 +1272,9 @@ static mluOpStatus_t padFFT1dContiguousInput(mluOpHandle_t handle,
cnnl_padded_input_desc);
CALL_CNNL(cnnlPad(
cnnl_handle, cnnl_input_desc,
// (fft_plan->n[0] > fft_plan->inembed[0]) ? fft_plan->mlu_addrs.input
// : fft_plan->matmul_addrs.input_contiguous_addr,
fft_plan->prime ? fft_plan->matmul_addrs.input_contiguous_addr
: fft_plan->mlu_addrs.input,
paddings, &padding_value, cnnl_padded_input_desc,
// (fft_plan->n[0] > fft_plan->inembed[0]) ?
// fft_plan->mlu_addrs.input_pad_addr
// : fft_plan->matmul_addrs.input_pad_addr
fft_plan->prime ? fft_plan->matmul_addrs.input_pad_addr
: fft_plan->mlu_addrs.input_pad_addr));

Expand Down Expand Up @@ -1396,8 +1408,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle,
VLOG(5) << "launch mluOpTranspose for input CNFFT_FUNC_MATMUL";
int padded_input_num = batch * n;
const int trans_dim_num = 2;
int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
int trans_permute[trans_dim_num] = {1, 0};

status =
Expand All @@ -1415,8 +1427,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle,
// 1st transpose: batch * n * 2 --> 2 * batch * n
int padded_input_num = batch * n;
const int trans_dim_num = 2;
int trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
int trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
int64_t trans_input_dims[trans_dim_num] = {padded_input_num, COMPLEX};
int64_t trans_output_dims[trans_dim_num] = {COMPLEX, padded_input_num};
int trans_permute[trans_dim_num] = {1, 0};

status =
Expand All @@ -1428,8 +1440,8 @@ static mluOpStatus_t transposeFFT1dPaddedInput(mluOpHandle_t handle,

// 2nd transpose: 2 * batch * L * 2^m --> 2 * batch * 2^m * L
const int trans_2nd_dim_num = 3;
int trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m};
int trans_2nd_output_dims[trans_2nd_dim_num] = {COMPLEX * batch, m, L};
int64_t trans_2nd_input_dims[trans_2nd_dim_num] = {COMPLEX * batch, L, m};
int64_t trans_2nd_output_dims[trans_2nd_dim_num] = {COMPLEX * batch, m, L};
int trans_2nd_permute[trans_2nd_dim_num] = {0, 2, 1};

status = fftTranspose(handle, trans_2nd_dim_num, trans_2nd_input_dims,
Expand Down Expand Up @@ -1739,8 +1751,8 @@ static mluOpStatus_t transposeFFT1dOutput(mluOpHandle_t handle,

int output_num = batch * n;
const int trans_dim_num = 2;
int trans_input_dims[trans_dim_num] = {COMPLEX, output_num};
int trans_output_dims[trans_dim_num] = {output_num, COMPLEX};
int64_t trans_input_dims[trans_dim_num] = {COMPLEX, output_num};
int64_t trans_output_dims[trans_dim_num] = {output_num, COMPLEX};
int trans_permute[trans_dim_num] = {1, 0};

status = fftTranspose(
Expand Down
Loading

0 comments on commit df579dc

Please sign in to comment.