Skip to content

Commit

Permalink
[Fix](mluOpSetFFTReserveArea): fix fft bug. (#1100)
Browse files Browse the repository at this point in the history
Co-authored-by: niyuming <[email protected]>
  • Loading branch information
aokbok and niyuming authored Oct 10, 2024
1 parent ef37890 commit 5e50714
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 35 deletions.
28 changes: 0 additions & 28 deletions kernels/fft/c2c_fft/c2c_fft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,11 +618,6 @@ mluOpStatus_t setFFT1dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix,
fft_plan->idft_matrix, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
}
return status;
}
Expand Down Expand Up @@ -698,11 +693,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
fft_plan->dft_matrix, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));

CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->factors_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));

if (fft_plan->fft_type == CNFFT_HALF2COMPLEX_HALF ||
fft_plan->fft_type == CNFFT_FLOAT2COMPLEX_FLOAT) {
fft_plan->mlu_addrs.twiddles_2d =
Expand All @@ -722,8 +712,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix_2d,
fft_plan->dft_matrix_2d, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
} else if (fft_plan->fft_type == CNFFT_COMPLEX_HALF2HALF ||
fft_plan->fft_type == CNFFT_COMPLEX_FLOAT2FLOAT) {
fft_plan->mlu_addrs.twiddles_inv_2d =
Expand All @@ -745,8 +733,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix_2d,
fft_plan->idft_matrix_2d, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
} else {
fft_plan->mlu_addrs.twiddles_2d =
(uint8_t *)fft_plan->reservespace_addr + reservespace_offset;
Expand Down Expand Up @@ -798,12 +784,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix,
fft_plan->idft_matrix, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
}

} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
Expand All @@ -827,8 +807,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(
fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d,
CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
} break;
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
Expand Down Expand Up @@ -864,10 +842,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(
fft_plan->mlu_addrs.idft_matrix_2d, fft_plan->idft_matrix_2d,
CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
}; break;
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
Expand All @@ -888,8 +862,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
CNRT_CHECK(cnrtMemcpyAsync(
fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d,
CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
Expand Down
148 changes: 147 additions & 1 deletion kernels/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2700,6 +2700,122 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany(
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t destroyRFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
const std::string api) {
VLOG(5) << "setRFFT1dReserveArea";
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
if (!fft_plan->prime) {
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
}
return status;
}

mluOpStatus_t destroyIRFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
const std::string api) {
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
if (!fft_plan->prime) {
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
}
return status;
}

mluOpStatus_t destroyFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
const std::string api) {
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
if (!fft_plan->prime) {
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
}
return status;
}

mluOpStatus_t destroyFFT2dReserveArea(mluOpFFTPlan_t fft_plan,
const std::string api) {
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;

const std::string make_plan_api = "[setFFT2dReserveArea]";

size_t CPX_TYPE_SIZE = 0;

switch (fft_plan->fft_type) {
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF: {
CPX_TYPE_SIZE = 2 * 2;
} break;
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
CPX_TYPE_SIZE = 4 * 2;
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
status = MLUOP_STATUS_NOT_SUPPORTED;
return status;
}
}

if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->factors_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
if (fft_plan->fft_type == CNFFT_HALF2COMPLEX_HALF ||
fft_plan->fft_type == CNFFT_FLOAT2COMPLEX_FLOAT) {
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
} else if (fft_plan->fft_type == CNFFT_COMPLEX_HALF2HALF ||
fft_plan->fft_type == CNFFT_COMPLEX_FLOAT2FLOAT) {
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
} else {
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
}

} else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
switch (fft_plan->fft_type) {
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_FLOAT2COMPLEX_FLOAT: {
// R2C
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
} break;
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
// C2C
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
}; break;
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
// C2R
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
status = MLUOP_STATUS_NOT_SUPPORTED;
return status;
}
}
}
return status;
}

mluOpStatus_t MLUOP_WIN_API mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan) {
const std::string destroy_api = "[mluOpDestroyFFTPlan]";
PARAM_CHECK_NE("[mluOpDestroyFFTPlan]", fft_plan, NULL);
Expand All @@ -2713,9 +2829,39 @@ mluOpStatus_t MLUOP_WIN_API mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan) {
mluOpDestroyTensorDescriptor(fft_plan->output_desc) ==
MLUOP_STATUS_SUCCESS);
}
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
switch (fft_plan->fft_type) {
// r2c
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_FLOAT2COMPLEX_FLOAT: {
if (fft_plan->rank == 1) {
status = destroyRFFT1dReserveArea(fft_plan, destroy_api);
} else if (fft_plan->rank == 2) {
status = destroyFFT2dReserveArea(fft_plan, destroy_api);
}
}; break;
// c2c
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
if (fft_plan->rank == 1) {
status = destroyFFT1dReserveArea(fft_plan, destroy_api);
} else if (fft_plan->rank == 2) {
status = destroyFFT2dReserveArea(fft_plan, destroy_api);
}
}; break;
// c2r
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
if (fft_plan->rank == 1) {
status = destroyIRFFT1dReserveArea(fft_plan, destroy_api);
} else if (fft_plan->rank == 2) {
status = destroyFFT2dReserveArea(fft_plan, destroy_api);
}
}; break;
}

delete fft_plan;
return MLUOP_STATUS_SUCCESS;
return status;
}

mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle,
Expand Down
3 changes: 0 additions & 3 deletions kernels/fft/irfft/irfft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,6 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix,
fft_plan->dft_matrix, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
}
return status;
}
Expand Down
3 changes: 0 additions & 3 deletions kernels/fft/rfft/rfft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,9 +592,6 @@ mluOpStatus_t setRFFT1dReserveArea(mluOpHandle_t handle,
CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix,
fft_plan->dft_matrix, DFT_TABLE_SIZE,
handle->queue, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
}
return status;
}
Expand Down

0 comments on commit 5e50714

Please sign in to comment.