fix bugs
nike-tinghai committed Jul 15, 2024
1 parent db6591e commit 36a9560
Showing 9 changed files with 38 additions and 66 deletions.
43 changes: 23 additions & 20 deletions kernels/fft/fft.cpp
@@ -1093,7 +1093,8 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
status =
fftFactor(r, facbuf, small_factors_offset, factor_type, large_count);
INTERNAL_CHECK("[fftTwoStepFactor]", status == MLUOP_STATUS_SUCCESS);
setMaxParallelNum(fft_plan, cur_facbuf, stage_num, r, is_row_major);
status = setMaxParallelNum(fft_plan, cur_facbuf, stage_num, r, is_row_major);
INTERNAL_CHECK("[fftTwoStepFactor]", status == MLUOP_STATUS_SUCCESS);

out_stride *= r;
large_count++;
@@ -1250,8 +1251,8 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
para_num = section_num * 2;
align_M = radix;
align_K = K_num * ((radix + K_num - 1) / K_num);
align_N = para_num;
// align_N = 64 * ((para_num + 64 - 1) / 64);
// align_N = para_num;
align_N = 64 * ((para_num + 64 - 1) / 64);
space_need_matmul_tmp =
((align_M * 2 > align_K) ? (align_M * 2) : align_K) * align_N *
2 * TYPE_SIZE;
@@ -1261,8 +1262,8 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
para_num = butterfly_num * section_num;
align_M = radix;
align_K = K_num * ((radix + K_num - 1) / K_num);
// align_N = 64 * ((para_num + 64 - 1) / 64);
align_N = para_num;
align_N = 64 * ((para_num + 64 - 1) / 64);
// align_N = para_num;

space_need_matmul_tmp = 0;
space_need_matmul_tmp += (align_N * align_K * 2 * TYPE_SIZE);
@@ -1338,8 +1339,8 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
para_num = section_num * 2;
align_M = radix;
align_K = K_num * ((radix + K_num - 1) / K_num);
align_N = para_num;
// align_N = 64 * ((para_num + 64 - 1) / 64);
// align_N = para_num;
align_N = 64 * ((para_num + 64 - 1) / 64);
space_need_matmul_tmp =
((align_M * 2 > align_K) ? (align_M * 2) : align_K) * align_N *
2 * TYPE_SIZE;
@@ -1349,8 +1350,8 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
para_num = butterfly_num * section_num;
align_M = radix;
align_K = K_num * ((radix + K_num - 1) / K_num);
// align_N = 64 * ((para_num + 64 - 1) / 64);
align_N = para_num;
align_N = 64 * ((para_num + 64 - 1) / 64);
// align_N = para_num;

space_need_matmul_tmp = 0;
space_need_matmul_tmp += (align_N * align_K * 2 * TYPE_SIZE);
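
The two calParallelNumLowBound hunks above swap which align_N assignment is active: rather than using the raw parallel count, the matmul N dimension is now rounded up to a multiple of 64 before the scratch-space estimate is taken. A minimal standalone sketch of that round-up arithmetic (the helper name align_up and the sample value are illustrative, not part of the library):

```cpp
#include <cstddef>
#include <cstdio>

// Round n up to the nearest multiple of `align` (align > 0), the same pattern
// as `64 * ((para_num + 64 - 1) / 64)` in the diff above.
static size_t align_up(size_t n, size_t align) {
  return align * ((n + align - 1) / align);
}

int main() {
  size_t para_num = 100;                    // hypothetical butterfly_num * section_num
  size_t align_N = align_up(para_num, 64);  // 100 -> 128
  std::printf("para_num=%zu align_N=%zu\n", para_num, align_N);
  return 0;
}
```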
@@ -1395,14 +1396,14 @@ mluOpStatus_t MLUOP_WIN_API setMaxParallelNum(mluOpFFTPlan_t fft_plan,
const int max_radix = 64;
size_t TYPE_SIZE = 0;
int max_parallel_num = 0;
int nram_space_need = 0;
size_t nram_space_need = 0;
int nram_space_need_tw = 0;
int nram_space_need_dftmtx = (stage == 1)
? max_radix * max_radix * 2 * 2
: max_radix * max_radix * 2; // complex
// int nram_space_need_dftmtx_align = 0;
int space_need_matmul = 0;
int space_need_matmul_tmp = 0;
size_t space_need_matmul = 0;
size_t space_need_matmul_tmp = 0;
int small_stage_num = facbuf[0];
int radix = 0;
int section_num = 0;
@@ -1643,7 +1644,7 @@ mluOpStatus_t MLUOP_WIN_API setMaxParallelNum(mluOpFFTPlan_t fft_plan,

// nram_space_need += space_need_matmul;
nram_space_need_tw = large_radix * 2 * TYPE_SIZE; // complex
const size_t nram_space_remain =
const int nram_space_remain =
(nram_space_size - nram_space_need_tw - nram_space_need_dftmtx);
max_parallel_num =
nram_space_remain / (nram_space_need + space_need_matmul);
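
The setMaxParallelNum hunks above also widen nram_space_need, space_need_matmul, and space_need_matmul_tmp from int to size_t. The commit message does not say why; a plausible reading is guarding the byte products against 32-bit wrap-around, sketched below with made-up operand sizes (unsigned arithmetic is used so the wrap is well defined and observable):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Made-up operands whose byte product is exactly 2^32.
  uint32_t align_K = 64, align_N = 1u << 22;  // hypothetical matmul dimensions
  uint32_t TYPE_SIZE = 8;                     // bytes per complex<float> element

  // 32-bit accumulation wraps modulo 2^32 and yields 0 here ...
  uint32_t narrow = align_N * align_K * 2u * TYPE_SIZE;
  // ... while size_t (typically 64-bit) keeps the exact byte count.
  size_t wide = static_cast<size_t>(align_N) * align_K * 2 * TYPE_SIZE;

  std::printf("32-bit: %u bytes, size_t: %zu bytes\n", narrow, wide);
  return 0;
}
```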
@@ -1825,18 +1826,20 @@ mluOpStatus_t MLUOP_WIN_API mluOpAllocateRFFT2D(
size_t workspace_size = 0, reservespace_size = 0;

mluOpDataType_t out_c_dtype = fft_plan->output_dtype;
size_t out_c_dtype_size = mluOpDataTypeBytes(out_c_dtype);
mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
size_t complex_dtype_size =
    (mluOpDataTypeBytes(out_c_dtype) > mluOpDataTypeBytes(in_c_dtype))
        ? mluOpDataTypeBytes(out_c_dtype)
        : mluOpDataTypeBytes(in_c_dtype);

int batch = fft_plan->batch;
size_t buffer_size = batch * out_c_dtype_size * _n0 * _n1;
size_t buffer_size = batch * complex_dtype_size * _n0 * _n1;

size_t twiddles_size = out_c_dtype_size * _n0;
size_t twiddles_size_2d = out_c_dtype_size * _n1;
size_t twiddles_size = complex_dtype_size * _n0;
size_t twiddles_size_2d = complex_dtype_size * _n1;

if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
reservespace_size = out_c_dtype_size * _n0 * _n0 * 2 +
out_c_dtype_size * _n1 * _n1 * 2; /* DFT matrix */
workspace_size = out_c_dtype_size * _n1 * _n0 * batch * 6;
reservespace_size = complex_dtype_size * _n0 * _n0 * 2 +
complex_dtype_size * _n1 * _n1 * 2; /* DFT matrix */
workspace_size = complex_dtype_size * _n1 * _n0 * batch * 6;
} else if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */
+ sizeof(int) * (FFT_MAXFACTORS) + twiddles_size * 2 +
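
mluOpAllocateRFFT2D now sizes its buffer, twiddle, reserve-space, and workspace allocations from whichever of the input and output element types is wider, instead of always using the output type. A toy illustration of that selection with assumed byte widths (none of these numbers come from the library):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Assumed widths: e.g. a C2R-style plan whose complex input elements
  // (8 bytes) are wider than its real output elements (4 bytes).
  size_t in_bytes = 8, out_bytes = 4;
  size_t complex_dtype_size = std::max(out_bytes, in_bytes);  // picks 8, not 4

  size_t n0 = 128, n1 = 256, batch = 2;                       // made-up plan sizes
  size_t buffer_size = batch * complex_dtype_size * n0 * n1;  // 524288 bytes
  std::printf("complex_dtype_size=%zu buffer_size=%zu\n", complex_dtype_size,
              buffer_size);
  return 0;
}
```

Taking the larger of the two widths presumably covers both R2C and C2R plans, where the wider element type can sit on either side of the transform.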
5 changes: 0 additions & 5 deletions kernels/fft/fft.h
@@ -379,11 +379,6 @@ mluOpStatus_t MLUOP_WIN_API kernelIRFFT2dButterflyRow(cnrtDim3_t k_dim,
mluOpFFTPlan_t fft_plan,
FFTFlag flag);

// Executes the 2D Butterfly FFT kernel for rows, converting complex to real,
// with the specified dimensions, function type, queue, FFT plan, direction, and flag.
mluOpStatus_t MLUOP_WIN_API kernelFFT2dButterflyRowC2R(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpFFTPlan_t fft_plan, int direction, FFTFlag flag);

// Executes the complex-to-complex FFT/DFT matrix kernel with the specified dimensions,
// function type, queue, FFT plan, input real data type, and size.
9 changes: 5 additions & 4 deletions kernels/fft/fft_optm_device/fft_c2c_stockham_gdram.h
@@ -24,8 +24,7 @@
#include "kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h"

extern __nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
extern __nram__ int nram_factors[FFT_MAXFACTORS];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];
__mlu_shared__ char sram_buffer[MAX_SRAM_SIZE];
extern __wram__ char wram_buffer[MAX_WRAM_SIZE];

@@ -40,7 +39,8 @@ __mlu_func__ void computeMutiStageOnchip(DT *input, DT *output, int *factors,
int repeat_num = total_num / taskDim;
int remain_num = total_num % taskDim;

char *nram_buf = nram_buffer;
char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);
int *nram_factors = (int *)nram_buffer;

int t_len = repeat_num + ((remain_num > 0 && taskId < remain_num) ? 1 : 0);
int t_start = taskId - remain_num <= 0 ? taskId * (repeat_num + 1)
@@ -223,7 +223,8 @@ __mlu_func__ void computeMutiStageOnchipColumn(DT *input, DT *output,
int repeat_num = total_num / taskDim;
int remain_num = total_num % taskDim;

char *nram_buf = nram_buffer;
char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);
int *nram_factors = (int *)nram_buffer;

int t_len = repeat_num + ((remain_num > 0 && taskId < remain_num) ? 1 : 0);
int t_start = taskId - remain_num <= 0 ? taskId * (repeat_num + 1)
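
This header, and its C2R/R2C counterparts below, replace the separate `__nram__ int nram_factors[FFT_MAXFACTORS]` array: the factor table now occupies the first FFT_MAXFACTORS * sizeof(int) bytes of nram_buffer, and the working pointer nram_buf starts immediately after it. A host-side sketch of that carve-up with placeholder sizes (on the device the buffer is the __nram__ array shown in the diff):

```cpp
#include <cstring>

constexpr int FFT_MAXFACTORS = 128;  // placeholder; the real constant lives in fft.h
alignas(alignof(int)) static char nram_buffer[64 * 1024];  // stand-in for the __nram__ buffer

void carve_buffer(const int *factors_src, int factor_count) {
  // Factor table at the very head of the buffer ...
  int *nram_factors = reinterpret_cast<int *>(nram_buffer);
  std::memcpy(nram_factors, factors_src, factor_count * sizeof(int));

  // ... and all remaining scratch starts just past it, mirroring
  // `char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);` in the diff.
  char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);
  (void)nram_buf;  // further use of the scratch region is elided in this sketch
}

int main() {
  const int factors[3] = {2, 4, 8};  // made-up factorization
  carve_buffer(factors, 3);
  return 0;
}
```

Folding the factors into the head of the general buffer drops one fixed global array while keeping the same total budget; correspondingly, the nram_buffer declaration no longer subtracts FFT_MAXFACTORS * 4 from its size.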
6 changes: 3 additions & 3 deletions kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h
@@ -24,8 +24,7 @@
#include "kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h"

extern __nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
extern __nram__ int nram_factors[FFT_MAXFACTORS];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];
__mlu_shared__ char sram_buffer[MAX_SRAM_SIZE];
extern __wram__ char wram_buffer[MAX_WRAM_SIZE];

@@ -40,7 +39,8 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
int repeat_num = total_num / taskDim;
int remain_num = total_num % taskDim;

char *nram_buf = nram_buffer;
char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);
int *nram_factors = (int *)nram_buffer;

// Each core processes "t_len" blocks; the "remain_num" leftover blocks are
// assigned one each to the first "remain_num" cores.
7 changes: 4 additions & 3 deletions kernels/fft/fft_optm_device/fft_r2c_stockham_gdram.h
@@ -24,8 +24,7 @@
#include "kernels/fft/fft_optm_device/fft_r2c_stockham_nram.h"

extern __nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
extern __nram__ int nram_factors[FFT_MAXFACTORS];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];
__mlu_shared__ char sram_buffer[MAX_SRAM_SIZE];
extern __wram__ char wram_buffer[MAX_WRAM_SIZE];

@@ -40,7 +39,9 @@ __mlu_func__ void computeMutiStageR2COnchip(DT *input, DT *output, int *factors,
int repeat_num = total_num / taskDim;
int remain_num = total_num % taskDim;

char *nram_buf = nram_buffer;
char *nram_buf = nram_buffer + FFT_MAXFACTORS * sizeof(int);
int *nram_factors = (int *)nram_buffer;

int t_len = repeat_num + ((remain_num > 0 && taskId < remain_num) ? 1 : 0);
int t_start = taskId - remain_num <= 0 ? taskId * (repeat_num + 1)
: (remain_num * (repeat_num + 1) +
@@ -30,9 +30,8 @@
#include "kernels/fft/fft_optm_device/fft_c2c_stockham_gdram.h"

__nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];

__nram__ int nram_factors[FFT_MAXFACTORS];
__wram__ char wram_buffer[MAX_WRAM_SIZE];

// Kernel function for 1D FFT butterfly operations on rows.
26 changes: 1 addition & 25 deletions kernels/fft/fft_optm_device/fft_two-level_network_c2r_device.mlu
@@ -30,8 +30,7 @@
#include "kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h"

__nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
__nram__ int nram_factors[FFT_MAXFACTORS];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];
__wram__ char wram_buffer[MAX_WRAM_SIZE];

__mlu_global__ void MLUKernelFFT1dButterflyRowC2R(
@@ -59,29 +58,6 @@ __mlu_global__ void MLUKernelFFT1dButterflyRowC2R(
}
}

mluOpStatus_t MLUOP_WIN_API kernelFFT2dButterflyRowC2R(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
mluOpFFTPlan_t fft_plan, int direction, FFTFlag flag) {
VLOG(5) << "Launch Kernel kernelFFT1dButterflyRow <<Union"
<< k_type / CORE_DIM << ", " << k_dim.x << ", " << k_dim.y << ", "
<< k_dim.z << ">>>";
if (direction == FFT_FORWARD) {
KERNEL_CHECK((MLUKernelFFT1dButterflyRowC2R<<<k_dim, k_type, queue>>>(
fft_plan->mlu_addrs.input, fft_plan->mlu_addrs.output,
fft_plan->mlu_addrs.factors, fft_plan->mlu_addrs.twiddles,
fft_plan->mlu_addrs.twiddles_end, fft_plan->mlu_addrs.dft_matrix,
fft_plan->mlu_addrs.buffer_buf, fft_plan->n[0], flag,
fft_plan->output_dtype)));
} else {
KERNEL_CHECK((MLUKernelFFT1dButterflyRowC2R<<<k_dim, k_type, queue>>>(
fft_plan->mlu_addrs.output, fft_plan->mlu_addrs.output,
fft_plan->mlu_addrs.factors, fft_plan->mlu_addrs.twiddles_inv,
fft_plan->mlu_addrs.twiddles_inv_end, fft_plan->mlu_addrs.idft_matrix,
fft_plan->mlu_addrs.buffer_buf, fft_plan->n[0], flag,
fft_plan->output_dtype)));
}
return MLUOP_STATUS_SUCCESS;
}

mluOpStatus_t MLUOP_WIN_API kernelFFT1dButterflyRowC2R(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
@@ -30,8 +30,7 @@
#include "kernels/fft/fft_optm_device/fft_r2c_stockham_gdram.h"

__nram__ char
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024 - FFT_MAXFACTORS * 4];
__nram__ int nram_factors[FFT_MAXFACTORS];
nram_buffer[MAX_NRAM_SIZE + REM_FOR_STACK - 32 * 1024];

__wram__ char wram_buffer[MAX_WRAM_SIZE];

2 changes: 0 additions & 2 deletions kernels/fft/irfft/irfft_host.cpp
@@ -1825,7 +1825,6 @@ mluOpStatus_t execIRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
fft_plan, FFT_IFFT);

INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = kernelIRFFT2dButterflyRow(k_dim, k_type, handle->queue, fft_plan,
FFT_IFFT);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
@@ -1840,7 +1839,6 @@ mluOpStatus_t execIRFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
fft_plan->mlu_addrs.output =
(void *)((uint64_t)(fft_plan->mlu_addrs.output) -
fft_plan->batch * odist);

status = makeIRFFT2dContiguousOutput(handle, fft_plan, output,
fft_plan->mlu_addrs.output);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
