diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp
index bd8d12ce6..d855c554f 100644
--- a/kernels/fft/c2c_fft/c2c_fft_host.cpp
+++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp
@@ -618,11 +618,6 @@ mluOpStatus_t setFFT1dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
     CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix,
                                fft_plan->idft_matrix, DFT_TABLE_SIZE,
                                handle->queue, cnrtMemcpyHostToDev));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
   }
   return status;
 }
@@ -698,11 +693,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
                                fft_plan->dft_matrix, DFT_TABLE_SIZE,
                                handle->queue, cnrtMemcpyHostToDev));
 
-    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->factors_2d));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
-
     if (fft_plan->fft_type == CNFFT_HALF2COMPLEX_HALF ||
         fft_plan->fft_type == CNFFT_FLOAT2COMPLEX_FLOAT) {
       fft_plan->mlu_addrs.twiddles_2d =
@@ -722,8 +712,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
       CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix_2d,
                                  fft_plan->dft_matrix_2d, DFT_TABLE_SIZE,
                                  handle->queue, cnrtMemcpyHostToDev));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
     } else if (fft_plan->fft_type == CNFFT_COMPLEX_HALF2HALF ||
                fft_plan->fft_type == CNFFT_COMPLEX_FLOAT2FLOAT) {
       fft_plan->mlu_addrs.twiddles_inv_2d =
@@ -745,8 +733,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
       CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix_2d,
                                  fft_plan->idft_matrix_2d, DFT_TABLE_SIZE,
                                  handle->queue, cnrtMemcpyHostToDev));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
     } else {
       fft_plan->mlu_addrs.twiddles_2d =
           (uint8_t *)fft_plan->reservespace_addr + reservespace_offset;
@@ -798,12 +784,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
       CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.idft_matrix,
                                  fft_plan->idft_matrix, DFT_TABLE_SIZE,
                                  handle->queue, cnrtMemcpyHostToDev));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
-      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
     }
 
   } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
@@ -827,8 +807,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
         CNRT_CHECK(cnrtMemcpyAsync(
             fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d,
             CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
       } break;
       case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
       case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
@@ -864,10 +842,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
         CNRT_CHECK(cnrtMemcpyAsync(
             fft_plan->mlu_addrs.idft_matrix_2d, fft_plan->idft_matrix_2d,
             CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
       }; break;
       case CNFFT_COMPLEX_HALF2HALF:
       case CNFFT_COMPLEX_FLOAT2FLOAT: {
@@ -888,8 +862,6 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
         CNRT_CHECK(cnrtMemcpyAsync(
             fft_plan->mlu_addrs.dft_matrix_2d, fft_plan->dft_matrix_2d,
             CPX_TYPE_SIZE * _n0 * _n0, handle->queue, cnrtMemcpyHostToDev));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
-        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
       }; break;
       default: {
         LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp
index 986b98f12..4d4ab9ef1 100644
--- a/kernels/fft/fft.cpp
+++ b/kernels/fft/fft.cpp
@@ -2700,6 +2700,124 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanMany(
   return MLUOP_STATUS_SUCCESS;
 }
 
+// Frees the host buffers (factors, twiddles, dft_matrix) held by a 1D r2c
+// plan. Nothing is freed when fft_plan->prime is set.
+mluOpStatus_t destroyRFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
+                                       const std::string api) {
+  VLOG(5) << "destroyRFFT1dReserveArea";
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+  if (!fft_plan->prime) {
+    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+  }
+  return status;
+}
+
+// Frees the host buffers (factors, twiddles, dft_matrix) held by a 1D c2r
+// plan. Nothing is freed when fft_plan->prime is set.
+mluOpStatus_t destroyIRFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
+                                        const std::string api) {
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+  if (!fft_plan->prime) {
+    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+  }
+  return status;
+}
+
+// Frees the host buffers of a 1D c2c plan (forward and inverse tables).
+// Nothing is freed when fft_plan->prime is set.
+mluOpStatus_t destroyFFT1dReserveArea(mluOpFFTPlan_t fft_plan,
+                                      const std::string api) {
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+  if (!fft_plan->prime) {
+    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
+  }
+  return status;
+}
+
+// Frees the host buffers of a 2D plan. The set of buffers depends on
+// fft_plan->fft_strategy and fft_plan->fft_type. `api` prefixes error logs;
+// an unknown fft type yields MLUOP_STATUS_NOT_SUPPORTED.
+mluOpStatus_t destroyFFT2dReserveArea(mluOpFFTPlan_t fft_plan,
+                                      const std::string api) {
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+
+  // Validate the type up front so nothing is freed for an unsupported plan.
+  switch (fft_plan->fft_type) {
+    case CNFFT_HALF2COMPLEX_HALF:
+    case CNFFT_COMPLEX_HALF2HALF:
+    case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+    case CNFFT_FLOAT2COMPLEX_FLOAT:
+    case CNFFT_COMPLEX_FLOAT2FLOAT:
+    case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT:
+      break;
+    default: {
+      LOG(ERROR) << api << ": invalid 2d fft type.";
+      status = MLUOP_STATUS_NOT_SUPPORTED;
+      return status;
+    }
+  }
+
+  if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
+    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->factors_2d));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
+    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+    if (fft_plan->fft_type == CNFFT_HALF2COMPLEX_HALF ||
+        fft_plan->fft_type == CNFFT_FLOAT2COMPLEX_FLOAT) {
+      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
+    } else if (fft_plan->fft_type == CNFFT_COMPLEX_HALF2HALF ||
+               fft_plan->fft_type == CNFFT_COMPLEX_FLOAT2FLOAT) {
+      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
+    } else {
+      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles_inv));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
+      CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
+    }
+  } else if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
+    switch (fft_plan->fft_type) {
+      case CNFFT_HALF2COMPLEX_HALF:
+      case CNFFT_FLOAT2COMPLEX_FLOAT: {
+        // R2C
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
+      } break;
+      case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+      case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+        // C2C
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
+        CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix));
+        CNRT_CHECK(cnrtFreeHost(fft_plan->idft_matrix_2d));
+      } break;
+      case CNFFT_COMPLEX_HALF2HALF:
+      case CNFFT_COMPLEX_FLOAT2FLOAT: {
+        // C2R
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
+        CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix_2d));
+      } break;
+      default: {
+        LOG(ERROR) << api << ": invalid 2d fft type.";
+        status = MLUOP_STATUS_NOT_SUPPORTED;
+        return status;
+      }
+    }
+  }
+  return status;
+}
+
 mluOpStatus_t MLUOP_WIN_API mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan) {
   const std::string destroy_api = "[mluOpDestroyFFTPlan]";
   PARAM_CHECK_NE("[mluOpDestroyFFTPlan]", fft_plan, NULL);
@@ -2713,9 +2831,43 @@ mluOpStatus_t MLUOP_WIN_API mluOpDestroyFFTPlan(mluOpFFTPlan_t fft_plan) {
                    mluOpDestroyTensorDescriptor(fft_plan->output_desc) ==
                        MLUOP_STATUS_SUCCESS);
   }
+  // Release the host buffers matching the plan's type and rank before the
+  // plan itself is deleted; plans of any other rank are left untouched here.
+  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
+  switch (fft_plan->fft_type) {
+    // r2c
+    case CNFFT_HALF2COMPLEX_HALF:
+    case CNFFT_FLOAT2COMPLEX_FLOAT: {
+      if (fft_plan->rank == 1) {
+        status = destroyRFFT1dReserveArea(fft_plan, destroy_api);
+      } else if (fft_plan->rank == 2) {
+        status = destroyFFT2dReserveArea(fft_plan, destroy_api);
+      }
+    } break;
+    // c2c
+    case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+    case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+      if (fft_plan->rank == 1) {
+        status = destroyFFT1dReserveArea(fft_plan, destroy_api);
+      } else if (fft_plan->rank == 2) {
+        status = destroyFFT2dReserveArea(fft_plan, destroy_api);
+      }
+    } break;
+    // c2r
+    case CNFFT_COMPLEX_HALF2HALF:
+    case CNFFT_COMPLEX_FLOAT2FLOAT: {
+      if (fft_plan->rank == 1) {
+        status = destroyIRFFT1dReserveArea(fft_plan, destroy_api);
+      } else if (fft_plan->rank == 2) {
+        status = destroyFFT2dReserveArea(fft_plan, destroy_api);
+      }
+    } break;
+    default:
+      break;
+  }
 
   delete fft_plan;
-  return MLUOP_STATUS_SUCCESS;
+  return status;
 }
 
 mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle,
diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp
index 433a3ef61..df5b8fa1b 100644
--- a/kernels/fft/irfft/irfft_host.cpp
+++ b/kernels/fft/irfft/irfft_host.cpp
@@ -592,9 +592,6 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
     CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix,
                                fft_plan->dft_matrix, DFT_TABLE_SIZE,
                                handle->queue, cnrtMemcpyHostToDev));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
   }
   return status;
 }
diff --git a/kernels/fft/rfft/rfft_host.cpp b/kernels/fft/rfft/rfft_host.cpp
index f64354fd8..1bcbbfcb8 100644
--- a/kernels/fft/rfft/rfft_host.cpp
+++ b/kernels/fft/rfft/rfft_host.cpp
@@ -592,9 +592,6 @@ mluOpStatus_t setRFFT1dReserveArea(mluOpHandle_t handle,
     CNRT_CHECK(cnrtMemcpyAsync(fft_plan->mlu_addrs.dft_matrix,
                                fft_plan->dft_matrix, DFT_TABLE_SIZE,
                                handle->queue, cnrtMemcpyHostToDev));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->factors));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->twiddles));
-    CNRT_CHECK(cnrtFreeHost(fft_plan->dft_matrix));
   }
   return status;
 }