Skip to content

Commit

Permalink
[Fix](mluOpExecFFT): fix FFT bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
niyuming committed Oct 28, 2024
1 parent 9ff2f87 commit 6a938f0
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 59 deletions.
61 changes: 31 additions & 30 deletions kernels/fft/irfft/irfft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2033,44 +2033,45 @@ mluOpStatus_t execIRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
(void *)((uint64_t)(fft_plan->mlu_addrs.output) -
fft_plan->batch * odist);
}
} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

if (scale_factor != 1.0) {
const float alpha[2] = {scale_factor, 0.0};
const float beta[2] = {0.0, 0.0};
mluOpTensorDescriptor_t c_desc = nullptr;
status = mluOpCreateTensorDescriptor(&c_desc);
const int out_dim_num = 3;
int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0],
fft_plan->n[1]};
status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY,
fft_plan->output_dtype, 3, dims);
status = mluOpSetTensorDescriptorOnchipDataType(
c_desc, fft_plan->execution_dtype);
status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
cnnl_handle); // convert to cnnl_handle
if (scale_factor != 1.0) {
const float alpha[2] = {scale_factor, 0.0};
const float beta[2] = {0.0, 0.0};
mluOpTensorDescriptor_t c_desc = nullptr;
status = mluOpCreateTensorDescriptor(&c_desc);
const int out_dim_num = 3;
int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0],
fft_plan->n[1]};
status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY,
fft_plan->output_dtype, 3, dims);
status = mluOpSetTensorDescriptorOnchipDataType(
c_desc, fft_plan->execution_dtype);

DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc);
DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
cnnl_handle); // convert to cnnl_handle

CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha,
cnnl_output_desc, fft_plan->mlu_addrs.output,
&beta, cnnl_output_desc,
fft_plan->mlu_addrs.output));
DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc);
DESTROY_CNNL_HANDLE(cnnl_handle);
}
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc);

CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha,
cnnl_output_desc, fft_plan->mlu_addrs.output,
&beta, cnnl_output_desc,
fft_plan->mlu_addrs.output));
DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc);
DESTROY_CNNL_HANDLE(cnnl_handle);
}
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
status = makeIRFFT2dContiguousOutput(handle, fft_plan, output,
fft_plan->mlu_addrs.output);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
status = computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}
return status;
}
59 changes: 30 additions & 29 deletions kernels/fft/rfft/rfft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1625,42 +1625,43 @@ mluOpStatus_t execRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
(void *)((uint64_t)(fft_plan->mlu_addrs.output) -
fft_plan->batch * odist);
}
} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
status = computeFFT2dMatMulRowR2C(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
status = computeFFT2dMatMulColumnR2C(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}

if (scale_factor != 1.0) {
const float alpha[2] = {scale_factor, 0.0};
const float beta[2] = {0.0, 0.0};
mluOpTensorDescriptor_t c_desc = nullptr;
status = mluOpCreateTensorDescriptor(&c_desc);
const int out_dim_num = 3;
int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0],
fft_plan->n[1] / 2 + 1};
status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY,
fft_plan->output_dtype, 3, dims);
status = mluOpSetTensorDescriptorOnchipDataType(
c_desc, fft_plan->execution_dtype);
if (scale_factor != 1.0) {
const float alpha[2] = {scale_factor, 0.0};
const float beta[2] = {0.0, 0.0};
mluOpTensorDescriptor_t c_desc = nullptr;
status = mluOpCreateTensorDescriptor(&c_desc);
const int out_dim_num = 3;
int64_t dims[out_dim_num] = {fft_plan->batch, fft_plan->n[0],
fft_plan->n[1] / 2 + 1};
status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY,
fft_plan->output_dtype, 3, dims);
status = mluOpSetTensorDescriptorOnchipDataType(
c_desc, fft_plan->execution_dtype);

DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
cnnl_handle); // convert to cnnl_handle
DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
cnnl_handle); // convert to cnnl_handle

DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc);
DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_output_desc);

CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha,
cnnl_output_desc, fft_plan->mlu_addrs.output,
&beta, cnnl_output_desc,
fft_plan->mlu_addrs.output));
DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc);
DESTROY_CNNL_HANDLE(cnnl_handle);
}
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
CALL_CNNL(cnnlTransform_v2(cnnl_handle, CNNL_POINTER_MODE_HOST, &alpha,
cnnl_output_desc, fft_plan->mlu_addrs.output,
&beta, cnnl_output_desc,
fft_plan->mlu_addrs.output));
DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_output_desc);
DESTROY_CNNL_HANDLE(cnnl_handle);
}
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
status = makeRFFT2dContiguousOutput(handle, fft_plan, output);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
status = computeFFT2dMatMulRowR2C(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
status = computeFFT2dMatMulColumnR2C(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}
return status;
}

0 comments on commit 6a938f0

Please sign in to comment.