diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp index 29e213163..116763023 100644 --- a/kernels/fft/c2c_fft/c2c_fft_host.cpp +++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp @@ -2112,7 +2112,7 @@ mluOpStatus_t execIRFFT1d_v2(mluOpHandle_t handle, // fft_plan->mlu_addrs.input); // INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); // } - + status = execFFTc2r1d(handle, fft_plan, scale_factor); INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS); diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp index e16ad65b9..16247326f 100644 --- a/kernels/fft/fft.cpp +++ b/kernels/fft/fft.cpp @@ -603,13 +603,13 @@ mluOpStatus_t MLUOP_WIN_API fftFactor(const int _n, int *facbuf, } break; - case (32 * 17): - if (n % 32 == 0) { - r = 32; - } else if ((n % 17) == 0) { - r = 17; - } - break; + case (32 * 17): + if (n % 32 == 0) { + r = 32; + } else if ((n % 17) == 0) { + r = 17; + } + break; case 600: if (n % 30 == 0) { @@ -726,19 +726,19 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan, if (is_row_major) { switch (_n) { case (32 * 17): - r = 32 * 17; + r = 32 * 17; break; case (200): - r = 200; + r = 200; break; case (600): - r = 600; + r = 600; break; case (256): - r = 256; + r = 256; break; case 1024: @@ -868,7 +868,29 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan, } } n /= r; - in_stride = _n / r; + + switch (fft_plan->fft_type) { + // r2c + case CNFFT_HALF2COMPLEX_HALF: + case CNFFT_FLOAT2COMPLEX_FLOAT: + case CNFFT_COMPLEX_HALF2HALF: + case CNFFT_COMPLEX_FLOAT2FLOAT: { + if ((n * r) != _n) { + in_stride = (((out_stride / 2) + 1) * section_num) / r; + + } else { + in_stride = _n / r; + } + } + + case CNFFT_COMPLEX_HALF2COMPLEX_HALF: + case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: { + in_stride = _n / r; + } + default: + break; + } + section_num = n; stage_num++; @@ -991,7 +1013,7 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan, switch (fft_plan->fft_type) { // r2c case CNFFT_HALF2COMPLEX_HALF: - case CNFFT_FLOAT2COMPLEX_FLOAT: + case CNFFT_FLOAT2COMPLEX_FLOAT: case CNFFT_COMPLEX_HALF2HALF: case CNFFT_COMPLEX_FLOAT2FLOAT: case CNFFT_COMPLEX_HALF2COMPLEX_HALF: @@ -2226,13 +2248,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle, case CNFFT_COMPLEX_HALF2COMPLEX_HALF: case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: { if (fft_plan->rank == 1) { - if(fft_plan->prime) { - status = setFFT1dReserveArea(handle, fft_plan, api); - }else{ -status = setFFT1dReserveArea_v2(handle, fft_plan, api); + if (fft_plan->prime) { + status = setFFT1dReserveArea(handle, fft_plan, api); + } else { + status = setFFT1dReserveArea_v2(handle, fft_plan, api); } - } else if (fft_plan->rank == 2) { // status = setFFT1dReserveArea(handle, fft_plan, api); status = setFFT2dReserveArea(handle, fft_plan, api); @@ -2244,10 +2265,10 @@ status = setFFT1dReserveArea_v2(handle, fft_plan, api); case CNFFT_COMPLEX_HALF2HALF: case CNFFT_COMPLEX_FLOAT2FLOAT: { if (fft_plan->rank == 1) { - if(fft_plan->prime) { - status = setIRFFT1dReserveArea(handle, fft_plan, api); - }else{ -status = setIRFFT1dReserveArea_v2(handle, fft_plan, api); + if (fft_plan->prime) { + status = setIRFFT1dReserveArea(handle, fft_plan, api); + } else { + status = setIRFFT1dReserveArea_v2(handle, fft_plan, api); } } else if (fft_plan->rank == 2) { // status = setFFT1dReserveArea(handle, fft_plan, api); diff --git a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h index 1a1b8a33c..b72ab76a6 100644 --- a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h +++ b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h @@ -936,7 +936,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpong( small_butterfly_num, para_num, small_in_stride, dir, radix); } else { - // TODO(zrg):check + // TODO(zrg): check computeGenericButterflyLaststageMat( nram_para_store_pong, nram_para_store_pong + max_para_ldst_num * large_radix, @@ -1117,25 +1117,24 @@ __mlu_func__ void computeLargeButterflyOtherstages( int para_load_num = (max_para_ldst_num > (large_butterfly_num - i)) ? (large_butterfly_num - i) : max_para_ldst_num; - if (para_load_num != 1) { - __memcpy_async(nram_para_load_in.r, input + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, - large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, - large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, - sizeof(DT) * para_load_num, SRAM2NRAM, - sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - __memcpy_async( - nram_para_load_tw.i, - cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i, - sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - } + + __memcpy_async(nram_para_load_in.r, input + Fin_stride + i, + sizeof(DT) * para_load_num, GDRAM2NRAM, + sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), + large_radix - 1); + __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i, + sizeof(DT) * para_load_num, GDRAM2NRAM, + sizeof(DT) * para_load_num, large_in_stride * sizeof(DT), + large_radix - 1); + __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, + sizeof(DT) * para_load_num, SRAM2NRAM, + sizeof(DT) * para_load_num, + large_out_stride * sizeof(DT), large_radix - 2); + __memcpy_async( + nram_para_load_tw.i, + cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i, + sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num, + large_out_stride * sizeof(DT), large_radix - 2); } // pipeline: store-stage @@ -1763,7 +1762,7 @@ __mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong( // pipeline: load-stage if (repeat_id < repeat_num) { - if (para_num != 1) { + if (1) { __memcpy_async( nram_para_load_in_ping.r, input_batch + sec_count * large_butterfly_num + butterfly_id, diff --git a/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h b/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h index c2bf4d154..d32c6b5b7 100644 --- a/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h +++ b/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h @@ -56,7 +56,7 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, taskId, repeat_num, remain_num, t_len, t_start, t_end); int radix, section_num, butterfly_num, in_stride, stage_count, value_mul, - small_factors_offset; + out_stride, small_factors_offset; int *small_factors; int last_stage; @@ -78,17 +78,23 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, const int nfft = factors[1]; // first stage - radix = factors[5 + 0]; - section_num = factors[5 + 1]; - in_stride = factors[5 + 3]; - small_factors_offset = factors[5 + 4]; + radix = factors[5 * _stage_count + 0]; + section_num = factors[5 * _stage_count + 1]; + butterfly_num = factors[5 * _stage_count + 3]; + out_stride = factors[5 * _stage_count + 3]; + in_stride = butterfly_num; + small_factors_offset = factors[5 * _stage_count + 4]; + + for (int loop_stage = 2; loop_stage < _stage_count; loop_stage++) { + cur_radix = factors[5 * loop_stage]; + butterfly_num_stage = factors[5 * loop_stage + 2]; + twiddles += (cur_radix - 1) * (butterfly_num_stage / 2); + } // small_factors = factors + small_factors_offset; - stage_count = _stage_count; - last_stage = (stage_count == 1); - - + stage_count = 1; + last_stage = (_stage_count == 1); if (__is_mpu()) { __memcpy_async(sram_factors, factors, FFT_MAXFACTORS * sizeof(int), @@ -96,7 +102,6 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, if (twiddles_size) { __memcpy_async(sram_twiddles, twiddles, twiddles_size * sizeof(DT), GDRAM2SRAM); - } const dft_table_entry *dft_table_gdram = @@ -166,11 +171,11 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, } // sram_large_tw - stage_count--; - if (stage_count == 0) { - // continue; + + if (stage_count == _stage_count) { return; } + stage_count++; // if (__is_mpu()) { // return; @@ -178,7 +183,7 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, // sram_large_tw value_mul = 10; - for (; stage_count > 1; stage_count--) { + for (; stage_count < _stage_count; stage_count++) { // fft_swap_ptr
(&buffer, &output); // FFT_SWAP_PTR(buffer, output); FFT_SWAP_PTR(buffer, output); @@ -198,14 +203,17 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, FFT_SWAP_PTR(odd_extra_buffer, output); } - // value_mul = (_stage_count - stage_count + 1) * 5; + value_mul = (_stage_count - stage_count + 1) * 5; // update parameter - radix = factors[value_mul++]; - section_num = factors[value_mul++]; - butterfly_num = factors[value_mul++]; - in_stride = factors[value_mul++]; - small_factors_offset = factors[value_mul++]; + + radix = factors[value_mul]; + section_num = factors[value_mul + 1]; + butterfly_num = factors[value_mul + 2]; + in_stride = butterfly_num; + out_stride = factors[value_mul + 3]; + small_factors_offset = factors[value_mul + 4]; + twiddles -= (radix - 1) * (butterfly_num / 2 + 1); small_factors = factors + small_factors_offset; @@ -213,28 +221,14 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, // MLULOG("other stage radix: %d \n", radix); if (repeat_num > 0 || taskId < remain_num) { - if (6000 / radix > repeat_num && 0) { - for (int t = t_start; t < t_end; t++) { - DT *output_batch = output + t * (nfft << 1); - DT *buffer_batch = buffer + t * (nfft << 1); - - computeLargeButterflyOtherstages
( - output_batch, buffer_batch, (DT *)twiddles, _twiddles, - sram_dftmtx, section_num, butterfly_num, in_stride, - (void *)nram_buf, small_factors, nfft, direction, 0); - - // __sync(); - } - } else { - computeLargeButterflyOtherstagesBatchPingpong
( - output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, - section_num, butterfly_num, in_stride, (void *)nram_buf, - small_factors, nfft, t_start, t_end, direction, 0); - } + computeLargeButterflyOtherstagesBatchPingpongC2R
( + output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num, + butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft, + t_start, t_end, direction, 0); } } - twiddles += butterfly_num * (radix - 1) * 2; // 2 for complex - } // for (stage_count) + twiddles += (butterfly_num / 2 - 1) * (radix - 1) * 2; // 2 for complex + } // for (stage_count) // __mlu_shared__ DT *sram_tw[2048]; // radix-1024 // __mlu_shared__ DT *sram_tw[64]; // radix-1024 @@ -248,32 +242,21 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors, FFT_SWAP_PTR(buffer, output); // update parameter - radix = factors[value_mul++]; - section_num = factors[value_mul++]; - butterfly_num = factors[value_mul++]; - in_stride = factors[value_mul++]; - small_factors_offset = factors[value_mul]; + radix = factors[5]; + section_num = factors[6]; + butterfly_num = factors[7]; + in_stride = butterfly_num; + out_stride = factors[8]; + small_factors_offset = factors[9]; small_factors = factors + small_factors_offset; if (__is_ipu()) { if (repeat_num > 0 || taskId < remain_num) { - if (0) { - for (int t = t_start; t < t_end; t++) { - DT *output_batch = output + t * (nfft << 1); - DT *buffer_batch = buffer + t * (nfft << 1); - - computeLargeButterflyLaststage
( - output_batch, buffer_batch, (DT *)twiddles, _twiddles, - sram_dftmtx, section_num, butterfly_num, in_stride, - (void *)nram_buf, small_factors, nfft, direction); - } - } else { - computeLargeButterflyLaststageBatchPingpong( - output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, - section_num, butterfly_num, in_stride, (void *)nram_buf, - small_factors, nfft, t_start, t_end, direction); - } + computeLargeButterflyLaststageBatchPingpongC2R( + output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num, + butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft, + t_start, t_end, direction); } } } diff --git a/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h index 42568f848..4939a65fe 100644 --- a/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h +++ b/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h @@ -24,465 +24,6 @@ #include "kernels/fft/fft_optm_device/fft_generic_butterfly.h" #include "kernels/fft/fft_optm_device/fft_vector_butterfly.h" -template -__mlu_func__ void computeLargeButterflyFirststageC2R( - DT *output, DT *input, int large_in_stride, int section_num, - const DT *twiddles, const DT *dft_matrix, void *nram_buf, - const int *small_factors, int dir, int nfft, int last_stage) { - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - - // network info - int radix, small_in_stride, small_stage_count, large_radix, - _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - int tw_offset; - - const int K_num = 64 / sizeof(DT); - int align_K = 0; - - _small_stage_count = small_factors[0]; - large_radix = small_factors[1]; - tw_offset = small_factors[2]; - - // load compute store - // (0) load 0 ping sync() - // (1) compute 0 ping load 1 pong sync() - // (2) store 0 compute 1 pong load 2 ping sync() - // (3) store 1 compute 2 load 3 sync() - - // compute last-large-stage (nram_out_r,nram_out_i) [2, large_radix]-> - // transpose -> [large_radix, 2] - - // complex array -> real array, imag array -> complex array - // first-large-stage complex -> real array, imag array - // other-large-stage none - // last-large-stage real array, imag array -> complex - - // const int max_para_ldst_num = 1; - int max_para_ldst_num = (4096 + large_radix - 1) / large_radix; - max_para_ldst_num = - (section_num < max_para_ldst_num) ? section_num : max_para_ldst_num; - - const DT *small_twiddles = twiddles + tw_offset * 2; // complex - - // assign nram space - int nram_buf_offset = 0; - - DT *nram_in_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_in_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_para_load_ping = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - DT *nram_para_load_pong = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - DT *nram_para_store_ping = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - DT *nram_para_store_pong = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix] - // DT *nram_transpose_load = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - // FFT_CPX_T
nram_transpose_temp; - // temp out-space before transpose - // if last-stage: - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; // complex - - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; // complex - - // nram space used: - // sizeof(DT) * 2 * large_radix * (max_para_ldst_num * 6 + 1) + sizeof(DT) * 2 - // * (max_radix * max_radix) - // + sizeof(DT) * 2 * large_radix * max_para_ldst_num * 4 - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - __memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2, - SRAM2NRAM); - - // ceil - int repeat_num = (section_num + max_para_ldst_num - 1) / max_para_ldst_num; - - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - // pipeline: load-stage - if (repeat_id < repeat_num) { - // MLULOG("pipeline: load-stage.\n"); - int i = max_para_ldst_num * repeat_id; - DT *nram_para_load = - (repeat_id % 2 == 0) ? nram_para_load_ping : nram_para_load_pong; - - // DT *nram_dftmtx = - // (repeat_id % 2 == 0) ? nram_dftmtx_ping : nram_dftmtx_pong; - int para_load_num = (max_para_ldst_num > (section_num - i)) - ? (section_num - i) - : max_para_ldst_num; - if (section_num == 1) { - __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix, - GDRAM2NRAM); - } else { - // gather load - // 2d memcpy - // 0 1 2 3 4 ... 1023 - // GDRAM -> NRAM - // 8bytes radix-1024 - // 64bytes - - __memcpy_async(nram_para_load, input + i * 2, - sizeof(DT) * 2 * para_load_num, GDRAM2NRAM, - sizeof(DT) * 2 * para_load_num, - large_in_stride * sizeof(DT) * 2, large_radix - 1); - } - } - - // pipeline: store-stage - if (repeat_id >= 2) { - // MLULOG("pipeline: store-stage.\n"); - int i = max_para_ldst_num * (repeat_id - 2); - - int para_store_num = (max_para_ldst_num > (section_num - i)) - ? (section_num - i) - : max_para_ldst_num; - - DT *nram_para_store = - (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong; - - if (last_stage) { - if (section_num == 1) { - __memcpy_async(output, nram_para_store, sizeof(DT) * 2 * large_radix, - NRAM2GDRAM); - } else { - // scatter-store - __memcpy_async(output + i * large_radix * 2, nram_para_store, - sizeof(DT) * 2 * para_store_num * large_radix, - NRAM2GDRAM); - } - } else { - // real - __memcpy_async(output + i * large_radix, nram_para_store, - para_store_num * large_radix * sizeof(DT), NRAM2GDRAM); - // imag - __memcpy_async(output + i * large_radix + nfft, - nram_para_store + max_para_ldst_num * large_radix, - para_store_num * large_radix * sizeof(DT), NRAM2GDRAM); - } - } - - // pipeline: compute-stage - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - int i = max_para_ldst_num * (repeat_id - 1); - - DT *nram_para_load = - (repeat_id % 2 != 0) ? nram_para_load_ping : nram_para_load_pong; - DT *nram_para_store = - (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong; - - int para_ldst_num = (max_para_ldst_num > (section_num - i)) - ? (section_num - i) - : max_para_ldst_num; - // // [large_radix, para_ldst_num, 2] -> [para_ldst_num, 2, large_radix] - // __bang_transpose(nram_transpose_load, nram_para_load, large_radix, - // 2 * para_ldst_num); - - // [large_radix, para_ldst_num, 2] -> [2, para_ldst_num, large_radix] - // overlap nram_out_r - // DT *nram_transpose_load = nram_out_r; - // __bang_transpose(nram_transpose_load, nram_para_load, - // large_radix * para_ldst_num, 2); - // // [large_radix, para_ldst_num] -> [para_ldst_num, large_radix] - // __bang_transpose(nram_in_r, nram_transpose_load, large_radix, - // para_ldst_num); - // __bang_transpose(nram_in_i, - // nram_transpose_load + large_radix * para_ldst_num, - // large_radix, para_ldst_num); - - // DT *nram_transpose_load = nram_in_r; - __bang_transpose(nram_in_r, nram_para_load, large_radix * para_ldst_num, - 2); - // [large_radix, para_ldst_num] -> [para_ldst_num, large_radix] - // __bang_transpose(nram_in_r, nram_transpose_load, large_radix, - // para_ldst_num); - // __bang_transpose(nram_in_i, - // nram_transpose_load + large_radix * para_ldst_num, - // large_radix, para_ldst_num); - - for (int compute_id = 0; compute_id < para_ldst_num; - compute_id += para_ldst_num) { - // load real & imag - - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - // __memcpy(nram_in_r, - // nram_transpose_load + compute_id * large_radix * 2, - // large_radix * sizeof(DT) * 2, NRAM2NRAM); - - // first stage - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - default: - MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_in_r, - nram_in_r + large_radix * para_ldst_num, nram_scratch, - nram_dftmtx, small_section_num * para_ldst_num, - small_section_num * para_ldst_num, 1, dir, radix); - break; - } - - // [radix, small_section_num, para_ldst_num] -> - // [small_section_num, para_ldst_num, radix] -> [para_ldst_num, - // small_section_num, radix] - - small_stage_count--; - if (small_stage_count == 0) { - // nram to gdram - - if (last_stage) { - // [2, para_ldst_num, large_radix] -> [para_ldst_num, large_radix, - // 2] - // DT* nram_transpose_store = nram_in_r; - - __bang_transpose(nram_para_store, nram_out_r, 2, - max_para_ldst_num * large_radix); - } else { - // [2, para_ldst_num, large_radix] -> [2, para_ldst_num, - // large_radix] - // TODO(zrg): redundant move - __memcpy(nram_para_store, nram_out_r, - para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM); - __memcpy(nram_para_store + max_para_ldst_num * large_radix, - nram_out_i, para_ldst_num * large_radix * sizeof(DT), - NRAM2NRAM); - } - - continue; - } - - // [small_section_num, para_ldst_num, radix] -> [para_ldst_num, - // small_section_num, radix] - - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_ldst_num, radix, DT) - - value_mul = 8; - // DT *sram_tw = (DT *)sram_buffer; - DT *nram_tw = _nram_tw; - - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - // value_mul = (_small_stage_count - small_stage_count + 1) << 2; - - // // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - // value_mul += 4; - // copy GDRAM2SRAM - - // if (compute_id == 0 && repeat_id == 1 && 0) { - // __memcpy(nram_tw, small_twiddles, - // small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - // GDRAM2NRAM); - // small_twiddles += small_butterfly_num * (radix - 1) * 2; - // } - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - case 2: - // computeRadix2ButterflyOtherstages(Fout, Fin, section_num, - // section_num, 1, dir); - break; - case 3: - computeRadix3ButterflyOtherstages( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_tw, small_section_num, small_butterfly_num, - small_in_stride, dir); - break; - - default: - // computeGenericButterflyOtherstages(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - MLULOG("computeGenericButterflyOtherstagesMat: %d.\n", radix); - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - break; - } - - nram_tw += small_butterfly_num * (radix - 1) * 2; - } // for (stage_count) - - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } - - // MLULOG("butterfly id: %d\n", i); - // last stage - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - // copy GDRAM2SRAM - - // update parameter - // value_mul = _small_stage_count << 2; - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; - - // if (compute_id == 0 && repeat_id == 1 && 0) { - // __memcpy(nram_tw, small_twiddles, - // small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - // GDRAM2NRAM); - // } - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - case 2: - break; - - default: - MLULOG("computeGenericButterflyLaststageMat: %d.\n", radix); - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_ldst_num, small_in_stride, dir, radix); - MLULOG("computeGenericButterflyLaststageMat: %d End.\n", radix); - break; - } - - if (last_stage) { - // [2, para_ldst_num, large_radix] -> [para_ldst_num, large_radix, - // 2] - // DT* nram_transpose_store = nram_in_r; - - __bang_transpose(nram_para_store, nram_out_r, 2, - max_para_ldst_num * large_radix); - } else { - // [2, para_ldst_num, large_radix] -> [2, para_ldst_num, - // large_radix] - // TODO(zrg): redundant move - __memcpy(nram_para_store, nram_out_r, - para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM); - __memcpy(nram_para_store + max_para_ldst_num * large_radix, - nram_out_i, para_ldst_num * large_radix * sizeof(DT), - NRAM2NRAM); - } - - // if (last_stage) { - // // MLULOG("last_stage. \n"); - - // // __memcpy(nram_transpose_temp + (compute_id * 2) * large_radix, - // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // // __memcpy(nram_transpose_temp - // // + (compute_id * 2 + 1) * large_radix, - // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - - // __memcpy(nram_transpose_temp + (compute_id * 2) * large_radix, - // nram_out_r, large_radix * sizeof(DT) * 2, NRAM2NRAM); - // __bang_transpose( - // nram_para_store + (compute_id * 2) * large_radix, - // nram_transpose_temp + (compute_id * 2) * large_radix, 2, - // large_radix); - - // } else { - // // MLULOG("not last_stage. \n"); - // __memcpy(nram_para_store + compute_id * large_radix, nram_out_r, - // large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_para_store + - // (compute_id + max_para_ldst_num) * large_radix, - // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - // } - // MLULOG("last_stage. \n"); - } - } - } - - __sync(); - } -} - template __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( DT *output, DT *input, int large_in_stride, int section_num, @@ -506,8 +47,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( large_radix = small_factors[1]; tw_offset = small_factors[2]; printf("tw_offset: %d\n\n", tw_offset); - const int half_butterfly_num = section_num/2 + 1; - + const int half_butterfly_num = section_num / 2 + 1; max_para_ldst_num = (half_butterfly_num < small_factors[3]) ? half_butterfly_num @@ -535,7 +75,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( // max_para_ldst_num = ((7232) / large_radix > 0) ? (7232) / large_radix : 1; const DT *small_twiddles = twiddles + tw_offset * 2; // complex - + // assign nram space int nram_buf_offset = 0; @@ -580,18 +120,8 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( (DT *)nram_buf + nram_buf_offset + max_radix * max_radix * 2}; nram_buf_offset += max_radix * max_radix * 4; // complex - // nram space used: - // sizeof(DT) * 2 * large_radix * (max_para_ldst_num * 6 + 1) + sizeof(DT) * 2 - // * (max_radix * max_radix) - // + sizeof(DT) * 2 * large_radix * max_para_ldst_num * 4 DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - // DT *nram_temp_r = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += large_radix * max_para_ldst_num; - - // DT *nram_temp_i = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += large_radix * max_para_ldst_num; - __memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2, SRAM2NRAM); @@ -646,15 +176,12 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( } } - // pipeline: store-stage if (repeat_id >= 2) { if (last_stage) { if (half_butterfly_num == 1) { // store only real part - printf(" taskId: %d, store only real part, odis: %ld , size: %ld\n", - taskId, (size_t)output_batch, - size_t(sizeof(DT) * large_radix)); + __memcpy_async(output_batch - 2 * odist, nram_para_store_ping, sizeof(DT) * large_radix, NRAM2GDRAM); @@ -679,10 +206,9 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( // pipeline: compute-stage if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - - for (int j = 0; j < 32; j++) { - printf("694 _nram_tw [%d]: (%f).\n", j, _nram_tw[j]); - } + for (int j = 0; j < 32; j++) { + printf("694 _nram_tw [%d]: (%f).\n", j, _nram_tw[j]); + } DT *nram_in_r = nram_para_store_pong; DT *nram_in_i = nram_para_store_pong + large_radix * max_para_ldst_num; @@ -704,14 +230,10 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( nram_in_r + upper_radix * para_num + para_num * large_radix, 1, lower_radix * para_num); - __memcpy(nram_in_r + upper_radix * para_num + para_num * large_radix, - nram_out_i, para_num * lower_radix * sizeof(DT), NRAM2NRAM); - - - + __bang_mul_scalar( + nram_in_r + upper_radix * para_num + para_num * large_radix, + nram_out_i, -1, para_num * lower_radix); } else { - - __bang_transpose(nram_in_r, nram_para_load_pong, large_radix, 2); __bang_rotate180(nram_out_r, nram_in_r + 1, 1, @@ -723,10 +245,9 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( __bang_rotate180(nram_out_i, nram_in_r + 1 + large_radix, 1, (large_radix - large_radix / 2 - 1)); - __bang_mul_scalar( - nram_in_r + (large_radix / 2 + 1) + large_radix, nram_out_i, -1, - (large_radix - large_radix / 2 - 1)); - + __bang_mul_scalar(nram_in_r + (large_radix / 2 + 1) + large_radix, + nram_out_i, -1, + (large_radix - large_radix / 2 - 1)); } for (int compute_id = 0; compute_id < para_num; @@ -738,7 +259,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( small_in_stride = small_factors[7]; small_stage_count = _small_stage_count; - // first stage if (ld_dft_radix[0] != radix && ld_dft_radix[1] != radix) { ld_dft_radix[1] = ld_dft_radix[0]; @@ -933,17 +453,13 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( FFT_SWAP_PTR(nram_dftmtx[0], nram_dftmtx[1]); } - switch (radix) { default: if (last_stage) { - if (nram_in_r == nram_para_store_pong) { FFT_SWAP_PTR(nram_para_store_pong, nram_para_load_pong); } - - computeGenericButterflyLaststageMat( nram_para_store_pong, nram_out_i, nram_in_r, nram_in_i, nram_scratch, nram_dftmtx[0], nram_tw, small_section_num, @@ -951,7 +467,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( radix); } else { - computeGenericButterflyLaststageMat( nram_para_store_pong, nram_para_store_pong + max_para_ldst_num * large_radix, @@ -983,29 +498,36 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R( } template -__mlu_func__ void computeLargeButterflyOtherstages( +__mlu_func__ void computeLargeButterflyOtherstagesBatchPingpongC2R( DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, const DT *dft_matrix, int large_section_num, int large_butterfly_num, int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - int dir, int last_stage) { + const int t_start, const int t_end, int dir, int last_stage) { // return; const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - const int K_num = 64 / sizeof(DT); - int align_K = 0; + int radix, small_in_stride, small_stage_count, large_radix, _small_stage_count; int small_section_num, small_butterfly_num, value_mul; const int large_out_stride = large_butterfly_num; - int tw_offset; + const int half_butterfly_num = large_butterfly_num / 2 + 1; + int tw_offset; + const int K_num = 64 / sizeof(DT); + int align_K = 0; _small_stage_count = small_factors[0]; large_radix = small_factors[1]; tw_offset = small_factors[2]; - + const int upper_radix = (large_radix + 1) / 2; + const int lower_radix = large_radix - upper_radix; const DT *small_twiddles = _twiddles + tw_offset * 2; // complex - const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix; + // const int max_para_ldst_num = (6144 + large_radix - 1) / large_radix; + // int max_para_ldst_num = (6400 + large_radix - 1) / large_radix; + int max_para_ldst_num = (half_butterfly_num < small_factors[3]) + ? half_butterfly_num + : small_factors[3]; // int para_ldst_num; // TODO(zrg): save nram space. @@ -1017,19 +539,7 @@ __mlu_func__ void computeLargeButterflyOtherstages( // FFT_CPX_T
*nram_out = &nram_in[large_radix]; // FFT_CPX_T
*nram_buf = &nram_in[large_radix * 2]; int nram_buf_offset = 0; - DT *nram_in_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_in_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - - DT *nram_out_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * max_para_ldst_num; - // parallel load/store space FFT_CPX_T
nram_para_load_in_ping = { (DT *)nram_buf + nram_buf_offset, (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; @@ -1040,12 +550,7 @@ __mlu_func__ void computeLargeButterflyOtherstages( (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - FFT_CPX_T
nram_para_load_tw_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - FFT_CPX_T
nram_para_load_tw_pong = { + FFT_CPX_T
nram_para_load_tw = { (DT *)nram_buf + nram_buf_offset, (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex @@ -1060,34 +565,11 @@ __mlu_func__ void computeLargeButterflyOtherstages( (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - // overlap nram_in - FFT_CPX_T
nram_transpose_temp; - // temp out-space before transpose - // if last-stage: - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - nram_transpose_temp = { - (DT *)nram_in_r, - (DT *)nram_in_r + large_radix * ((int)last_stage) + - large_radix * (1 - (int)last_stage) * max_para_ldst_num}; // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; nram_buf_offset += large_radix * 2; // complex - // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix] - // FFT_CPX_T
nram_transpose_load = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - int ld_dft_radix = -1; const int max_radix = 64; DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; @@ -1103,348 +585,151 @@ __mlu_func__ void computeLargeButterflyOtherstages( nram_buf_offset += large_radix * max_para_ldst_num * 4; // complex - // size: (large_radix - 1) * max_para_ldst_num - // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num]; - // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num]; - - int Fin_stride = 0, Fout_stride = 0; - int sec_count; - int repeat_num = - (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num; - for (sec_count = 0; sec_count < large_section_num; ++sec_count) { - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - // small_twiddles = _small_twiddles; - - // pipeline: load-stage - - if (repeat_id < repeat_num) { - // MLULOG("pipeline: load-stage.\n"); - int i = max_para_ldst_num * repeat_id; - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 == 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 == 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - int para_load_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - if (para_load_num != 1) { - __memcpy_async(nram_para_load_in.r, input + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, - large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i, - sizeof(DT) * para_load_num, GDRAM2NRAM, - sizeof(DT) * para_load_num, - large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, - sizeof(DT) * para_load_num, SRAM2NRAM, - sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - __memcpy_async( - nram_para_load_tw.i, - cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i, - sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num, - large_out_stride * sizeof(DT), large_radix - 2); - } - } + // overlap nram_in + FFT_CPX_T
nram_transpose_temp; + // temp out-space before transpose + // if last-stage: + // compute_id 0 r + // compute_id 0 i + // compute_id 1 r + // compute_id 1 i + // else: + // compute_id 0 r + // compute_id 1 i + // compute_id 0 r + // compute_id 1 i + nram_transpose_temp = { + (DT *)nram_buf + nram_buf_offset, + (DT *)nram_buf + nram_buf_offset + large_radix * ((int)last_stage) + + large_radix * (1 - (int)last_stage) * max_para_ldst_num}; - // pipeline: store-stage - if (repeat_id >= 2) { - // MLULOG("pipeline: store-stage.\n"); - int i = max_para_ldst_num * (repeat_id - 2); + // size: (large_radix - 1) * max_para_ldst_num + // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num]; + // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num]; - int para_store_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; + // int Fin_stride = 0, Fout_stride = 0; + int sec_count; + // int repeat_num = - FFT_CPX_T
nram_para_store = - (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong; + int repeat_num = (t_end - t_start); + const int odist = last_stage ? nfft : nfft << 1; + const int idist = (nfft / 2 + 1) << 1; + input += t_start * idist; + output += t_start * odist; - if (last_stage) { - // __memcpy_async( - // output + (Fout_stride + i * large_radix) * 2, - // nram_para_store.r, - // sizeof(DT) * 2 * para_store_num * large_radix, NRAM2GDRAM); - - __memcpy_async(output + (Fout_stride + i) * 2, nram_para_store.r, - sizeof(DT) * 2 * para_store_num, NRAM2GDRAM, - large_out_stride * 2 * sizeof(DT), - sizeof(DT) * 2 * para_store_num, large_radix - 1); - } else { - // // real - // __memcpy_async(output + Fout_stride + i * large_radix, - // nram_para_store.r, - // para_store_num * large_radix * sizeof(DT), - // NRAM2GDRAM); - // // imag - // __memcpy_async(output + Fout_stride + i * large_radix + nfft, - // nram_para_store.i, - // para_store_num * large_radix * sizeof(DT), - // NRAM2GDRAM); - // real - __memcpy_async(output + Fout_stride + i, nram_para_store.r, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - // imag - __memcpy_async(output + Fout_stride + i + nfft, nram_para_store.i, - para_store_num * sizeof(DT), NRAM2GDRAM, - large_out_stride * sizeof(DT), - sizeof(DT) * para_store_num, large_radix - 1); - } - } - // __sync(); - // pipeline: compute-stage + for (int butterfly_id = 0; butterfly_id < half_butterfly_num; + butterfly_id += max_para_ldst_num) { + for (sec_count = 0; sec_count < large_section_num; ++sec_count) { + DT *output_batch = output; + DT *input_batch = input; + int para_num = (max_para_ldst_num > (half_butterfly_num - butterfly_id)) + ? (half_butterfly_num - butterfly_id) + : max_para_ldst_num; - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - // MLULOG("pipeline: compute-stage.\n"); - int i = max_para_ldst_num * (repeat_id - 1); - - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 != 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 != 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong; - - int para_ldst_num = (max_para_ldst_num > (large_butterfly_num - i)) - ? (large_butterfly_num - i) - : max_para_ldst_num; - - // __bang_transpose(nram_transpose_load, nram_para_load, large_radix, - // 2 * para_ldst_num); - - // rotation-large - __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_ldst_num, - nram_para_load_tw.i, para_ldst_num * (large_radix - 1)); - __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_ldst_num, - nram_para_load_tw.r, para_ldst_num * (large_radix - 1)); - - __bang_sub(nram_para_load_in.r + para_ldst_num, CPX_MUL_RR, CPX_MUL_II, - para_ldst_num * (large_radix - 1)); - __bang_add(nram_para_load_in.i + para_ldst_num, CPX_MUL_RI, CPX_MUL_IR, - para_ldst_num * (large_radix - 1)); - - // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r, - // large_radix, para_ldst_num); - // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i, - // large_radix, para_ldst_num); - - // for (int compute_id = 0; compute_id < para_ldst_num; compute_id++) { - for (int compute_id = 0; compute_id < para_ldst_num; - compute_id += para_ldst_num) { - // load real & imag + for (int repeat_id = 0; repeat_id < repeat_num + 2; + ++repeat_id, input_batch += idist, output_batch += odist) { + // small_twiddles = _small_twiddles; - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; + // pipeline: load-stage - // __memcpy(nram_in_r, - // nram_transpose_load + compute_id * large_radix * 2, - // large_radix * sizeof(DT) * 2, NRAM2NRAM); + if (repeat_id < repeat_num) { + if (1) { + __memcpy_async( + nram_para_load_in_ping.r, + input_batch + sec_count * half_butterfly_num + butterfly_id, + sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num, + large_in_stride * sizeof(DT), upper_radix - 1); + __memcpy_async(nram_para_load_in_ping.i, + input_batch + nfft + sec_count * half_butterfly_num + + butterfly_id, + sizeof(DT) * para_num, GDRAM2NRAM, + sizeof(DT) * para_num, large_in_stride * sizeof(DT), + upper_radix - 1); - // first stage - // if(0) - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } + __memcpy_async( + nram_para_load_in_ping.r + upper_radix * para_num, + input_batch + sec_count * half_butterfly_num + butterfly_id + + (half_butterfly_num - butterfly_id - para_num + 1), + sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num, + large_in_stride * sizeof(DT), lower_radix - 1); + __memcpy_async( + nram_para_load_in_ping.i + upper_radix * para_num, + input_batch + nfft + sec_count * half_butterfly_num + + butterfly_id + + (half_butterfly_num - butterfly_id - para_num + 1), + sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num, + large_in_stride * sizeof(DT), lower_radix - 1); - if (dft_table[entry].radix == -1) { - break; - } + if (repeat_id == 0 && sec_count == 0) { + __memcpy_async( + nram_para_load_tw.r, cur_large_twiddles + butterfly_id, + sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num, + large_out_stride * sizeof(DT), large_radix - 2); + __memcpy_async( + nram_para_load_tw.i, + cur_large_twiddles + half_butterfly_num * (large_radix - 1) + + butterfly_id, + sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num, + large_out_stride * sizeof(DT), large_radix - 2); } } + } - switch (radix) { - default: - // computeGenericButterflyFirststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); - - // para_ldst_num = 1 - // in: [radix, butterfly_num] - // butterfly: [radix, radix] * [radix, butterfly_num] - // out_butterfly: [radix, butterfly_num] - // out: [butterfly_num, radix] - - // para_ldst_num != 1 - // in: [radix, butterfly_num, para_ldst_num] == [large_radix, - // para_ldst_num] butterfly: [radix, radix] * [radix, - // butterfly_num, para_ldst_num] out_butterfly: [radix, - // butterfly_num, para_ldst_num] == [radix, butterfly_num * - // para_ldst_num] out: [butterfly_num, para_ldst_num, radix] - - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_para_load_in.r, - nram_para_load_in.i, nram_scratch, nram_dftmtx, - small_section_num * para_ldst_num, - small_section_num * para_ldst_num, 1, dir, radix); - break; + // pipeline: store-stage + if (repeat_id >= 2) { + if (last_stage) { + __memcpy_async(output_batch - odist * 2 + + (sec_count * large_radix * half_butterfly_num + + butterfly_id) * + 2, + nram_para_store_ping.r, sizeof(DT) * 2 * para_num, + NRAM2GDRAM, large_out_stride * 2 * sizeof(DT), + sizeof(DT) * 2 * para_num, large_radix - 1); + } else { + // real + __memcpy_async(output_batch - odist * 2 + + sec_count * large_radix * half_butterfly_num + + butterfly_id, + nram_para_store_ping.r, para_num * sizeof(DT), + NRAM2GDRAM, large_out_stride * sizeof(DT), + sizeof(DT) * para_num, large_radix - 1); + // imag + __memcpy_async(output_batch - odist * 2 + + sec_count * large_radix * half_butterfly_num + + butterfly_id + nfft, + nram_para_store_ping.i, para_num * sizeof(DT), + NRAM2GDRAM, large_out_stride * sizeof(DT), + sizeof(DT) * para_num, large_radix - 1); } + } + // __sync(); + // pipeline: compute-stage - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } - - // [radix, small_section_num, para_ldst_num] -> - // [small_section_num, para_ldst_num, radix] -> - // [para_ldst_num, small_section_num, radix] -> - // [small_section_num, radix, para_ldst_num] == - // [large_radix, para_ldst_num] - - small_stage_count--; - if (small_stage_count == 0) { - // nram to gdram - - // if (last_stage) { - // // [2, para_ldst_num, large_radix] -> [para_ldst_num, - // large_radix, - // // 2] - // // DT* nram_transpose_store = nram_in_r; - - // __bang_transpose(nram_in_r, nram_out_r, 2, - // max_para_ldst_num * large_radix); - - // } else { - // // [2, para_ldst_num, large_radix] -> [2, para_ldst_num, - // // large_radix] - // // TODO(zrg): redundant move - // __memcpy(nram_in_r, nram_out_r, - // para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_in_i, - // nram_out_i, para_ldst_num * large_radix * sizeof(DT), - // NRAM2NRAM); - // } - - // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2] - if (last_stage) { - __memcpy(nram_transpose_temp.r + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_ldst_num - 1); - - __memcpy(nram_transpose_temp.i + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_ldst_num - 1); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } - - continue; - } + if (repeat_id >= 1 && repeat_id < repeat_num + 1) { + // MLULOG("pipeline: compute-stage.\n"); + // int i = max_para_ldst_num * (repeat_id - 1); - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // DT* nram_transpose_store = nram_in_r; - - // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num; - // para_ldst_id++) { - // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix, - // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - - // 1); - - // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix, - // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - - // 1); - // } + DT *nram_in_r = nram_para_load_in_pong.r; + DT *nram_in_i = nram_para_load_in_pong.i; - // after first stage: [butterfly_num, para_ldst_num, radix] - // other in: [para_ldst_num, butterfly_num, radix] == [para_ldst_num, - // large_radix] - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_ldst_num, radix, DT) - - // TODO(zrg) : add not last-stage - // if (small_stage_count == 0) { - // // if last-stage: stride = large_radix * 2 - // // compute_id 0 r - // // compute_id 0 i - // // compute_id 1 r - // // compute_id 1 i - // // else: stride = large_radix - // // compute_id 0 r - // // compute_id 1 i - // // compute_id 0 r - // // compute_id 1 i - - // // [radix, small_section_num, para_ldst_num] -> - // // [small_section_num, para_ldst_num, radix] -> - // // [para_ldst_num, small_section_num, radix] -> - // // [small_section_num, radix, para_ldst_num] == - // // [large_radix, para_ldst_num] - - // __memcpy(nram_transpose_temp.r + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - // para_ldst_num - 1); - - // __memcpy(nram_transpose_temp.i + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - // para_ldst_num - 1); - - // // __memcpy(nram_transpose_temp.r + - // // compute_id * large_radix * (1 + (int)last_stage), - // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // // __memcpy(nram_transpose_temp.i + - // // compute_id * large_radix * (1 + (int)last_stage), - // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - - // // __bang_transpose(nram_transpose_temp.r, nram_transpose_temp.r, - // // max_para_ldst_num * 2, large_radix); - // continue; - // } + DT *nram_out_r = nram_para_store_pong.r; + DT *nram_out_i = nram_para_store_pong.i; - // DT *sram_tw = (DT *)sram_buffer; - DT *nram_tw = _nram_tw; - value_mul = 8; + for (int compute_id = 0; compute_id < para_num; + compute_id += para_num) { + // load real & imag - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); + radix = small_factors[4]; + small_section_num = small_factors[5]; + small_in_stride = small_factors[7]; + small_stage_count = _small_stage_count; - // value_mul = (_small_stage_count - small_stage_count + 1) * 4; - // // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - // copy GDRAM2SRAM + // __memcpy(nram_in_r, + // nram_transpose_load + compute_id * large_radix * 2, + // large_radix * sizeof(DT) * 2, NRAM2NRAM); + // first stage + // if(0) if (ld_dft_radix != radix) { ld_dft_radix = radix; for (int entry = 0;; entry++) { @@ -1462,2040 +747,329 @@ __mlu_func__ void computeLargeButterflyOtherstages( } } - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - small_twiddles += small_butterfly_num * (radix - 1) * 2; - } - switch (radix) { - // case 2: - // // computeRadix2ButterflyOtherstages(Fout, Fin, section_num, - // // section_num, 1, dir); - // break; - // case 3: - // computeRadix3ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; default: - // computeGenericButterflyOtherstages(Fout, buffer, twiddles, + // computeGenericButterflyFirststage(Fout, buffer, twiddles, // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_ldst_num, small_in_stride, dir, - radix); - break; - } - - nram_tw += small_butterfly_num * (radix - 1) * 2; - } // for (stage_count) - - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } - - // MLULOG("butterfly id: %d\n", i); - // last stage - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // copy GDRAM2SRAM - - // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; + MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - } + // para_num = 1 + // in: [radix, butterfly_num] + // butterfly: [radix, radix] * [radix, butterfly_num] + // out_butterfly: [radix, butterfly_num] + // out: [butterfly_num, radix] - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, - &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } + // para_num != 1 + // in: [radix, butterfly_num, para_num] == [large_radix, + // para_num] butterfly: [radix, radix] * [radix, + // butterfly_num, para_num] out_butterfly: [radix, + // butterfly_num, para_num] == [radix, butterfly_num * + // para_num] out: [butterfly_num, para_num, radix] - if (dft_table[entry].radix == -1) { - break; - } - } - } - switch (radix) { - // case 2: - // break; - // case 3: - // computeRadix3ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - default: - // computeGenericButterflyLaststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_ldst_num, small_in_stride, dir, - radix); + computeGenericButterflyFirststageMat( + nram_out_r, nram_out_i, nram_para_load_in_pong.r, + nram_para_load_in_pong.i, nram_scratch, nram_dftmtx, + small_section_num * para_num, small_section_num * para_num, + 1, dir, radix); break; } - // if last-stage: stride = large_radix * 2 - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: stride = large_radix - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - // __memcpy(nram_transpose_temp.r + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_transpose_temp.i + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); + // for (int j = 0; j < large_radix; j++) { + // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); + // } - if (last_stage) { - __memcpy(nram_transpose_temp.r + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_ldst_num - 1); - - __memcpy(nram_transpose_temp.i + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_ldst_num - 1); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_ldst_num * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num, - large_radix); - } + // [radix, small_section_num, para_num] -> + // [small_section_num, para_num, radix] -> + // [para_num, small_section_num, radix] -> + // [small_section_num, radix, para_num] == + // [large_radix, para_num] - // __bang_transpose(nram_para_store, nram_transpose_temp.r, - // max_para_ldst_num * 2, large_radix); - } - } - } - // __sync(); + small_stage_count--; + if (small_stage_count == 0) { + // nram to gdram - __sync(); - } - Fin_stride += large_butterfly_num; - Fout_stride += large_radix * large_butterfly_num; - } -} - -template -__mlu_func__ void computeLargeButterflyLaststage( - DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, - const DT *dft_matrix, int large_section_num, int large_butterfly_num, - int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - int dir) { - computeLargeButterflyOtherstages(output, input, cur_large_twiddles, _twiddles, - dft_matrix, large_section_num, - large_butterfly_num, large_in_stride, - nram_buf, small_factors, nfft, dir, 1); -} - -template -__mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong( - DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, - const DT *dft_matrix, int large_section_num, int large_butterfly_num, - int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - const int t_start, const int t_end, int dir, int last_stage) { - // return; - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - - int radix, small_in_stride, small_stage_count, large_radix, - _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - - const int large_out_stride = large_butterfly_num; - int tw_offset; - const int K_num = 64 / sizeof(DT); - int align_K = 0; - _small_stage_count = small_factors[0]; - large_radix = small_factors[1]; - tw_offset = small_factors[2]; - - const DT *small_twiddles = _twiddles + tw_offset * 2; // complex - - // const int max_para_ldst_num = (6144 + large_radix - 1) / large_radix; - // int max_para_ldst_num = (6400 + large_radix - 1) / large_radix; - int max_para_ldst_num = (large_butterfly_num < small_factors[3]) - ? large_butterfly_num - : small_factors[3]; - - // int para_ldst_num; - // TODO(zrg): save nram space. - // __nram__ DT nram_space[MAX_BUTTERFLY_ON_CHIP * 2]; - // 0 1 2 3 4 5 - // 0 1 3 - // DT *nram_buf_end = (DT*)&((uint8_t*)nram_buf)[NRAM_BUFFER_SIZE]; - // FFT_CPX_T
*nram_in = (FFT_CPX_T
*)nram_buffer; - // FFT_CPX_T
*nram_out = &nram_in[large_radix]; - // FFT_CPX_T
*nram_buf = &nram_in[large_radix * 2]; - int nram_buf_offset = 0; - - FFT_CPX_T
nram_para_load_in_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - FFT_CPX_T
nram_para_load_in_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - FFT_CPX_T
nram_para_load_tw = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - // FFT_CPX_T
nram_para_load_tw_ping = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - // FFT_CPX_T
nram_para_load_tw_pong = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - FFT_CPX_T
nram_para_store_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - FFT_CPX_T
nram_para_store_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; // complex - - // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix] - // FFT_CPX_T
nram_transpose_load = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num}; - // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex - - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; // complex - - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - // temp overlap with "nram_scratch" - DT *CPX_MUL_RR = nram_scratch; - DT *CPX_MUL_RI = &CPX_MUL_RR[large_radix * max_para_ldst_num]; - DT *CPX_MUL_IR = &CPX_MUL_RI[large_radix * max_para_ldst_num]; - DT *CPX_MUL_II = &CPX_MUL_IR[large_radix * max_para_ldst_num]; - - nram_buf_offset += large_radix * max_para_ldst_num * 4; // complex - - // overlap nram_in - FFT_CPX_T
nram_transpose_temp; - // temp out-space before transpose - // if last-stage: - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - nram_transpose_temp = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * ((int)last_stage) + - large_radix * (1 - (int)last_stage) * max_para_ldst_num}; - - // size: (large_radix - 1) * max_para_ldst_num - // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num]; - // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num]; - - // int Fin_stride = 0, Fout_stride = 0; - int sec_count; - // int repeat_num = - // (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num; - int repeat_num = (t_end - t_start); - input += t_start * (nfft << 1); - output += t_start * (nfft << 1); - - for (int butterfly_id = 0; butterfly_id < large_butterfly_num; - butterfly_id += max_para_ldst_num) { - for (sec_count = 0; sec_count < large_section_num; ++sec_count) { - DT *output_batch = output; - DT *input_batch = input; - int para_num = (max_para_ldst_num > (large_butterfly_num - butterfly_id)) - ? (large_butterfly_num - butterfly_id) - : max_para_ldst_num; - - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id, - input_batch += (nfft << 1), output_batch += (nfft << 1)) { - // small_twiddles = _small_twiddles; - - // pipeline: load-stage - - if (repeat_id < repeat_num) { - if (para_num != 1) { - __memcpy_async( - nram_para_load_in_ping.r, - input_batch + sec_count * large_butterfly_num + butterfly_id, - sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num, - large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async(nram_para_load_in_ping.i, - input_batch + nfft + - sec_count * large_butterfly_num + butterfly_id, - sizeof(DT) * para_num, GDRAM2NRAM, - sizeof(DT) * para_num, large_in_stride * sizeof(DT), - large_radix - 1); - if (repeat_id == 0 && sec_count == 0) { - __memcpy_async( - nram_para_load_tw.r, cur_large_twiddles + butterfly_id, - sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num, - large_out_stride * sizeof(DT), large_radix - 2); - __memcpy_async( - nram_para_load_tw.i, - cur_large_twiddles + large_butterfly_num * (large_radix - 1) + - butterfly_id, - sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num, - large_out_stride * sizeof(DT), large_radix - 2); - } - } - } - - // pipeline: store-stage - if (repeat_id >= 2) { - if (last_stage) { - __memcpy_async(output_batch - (nfft << 2) + - (sec_count * large_radix * large_butterfly_num + - butterfly_id) * - 2, - nram_para_store_ping.r, sizeof(DT) * 2 * para_num, - NRAM2GDRAM, large_out_stride * 2 * sizeof(DT), - sizeof(DT) * 2 * para_num, large_radix - 1); - } else { - // real - __memcpy_async(output_batch - (nfft << 2) + - sec_count * large_radix * large_butterfly_num + - butterfly_id, - nram_para_store_ping.r, para_num * sizeof(DT), - NRAM2GDRAM, large_out_stride * sizeof(DT), - sizeof(DT) * para_num, large_radix - 1); - // imag - __memcpy_async(output_batch - (nfft << 2) + - sec_count * large_radix * large_butterfly_num + - butterfly_id + nfft, - nram_para_store_ping.i, para_num * sizeof(DT), - NRAM2GDRAM, large_out_stride * sizeof(DT), - sizeof(DT) * para_num, large_radix - 1); - } - } - // __sync(); - // pipeline: compute-stage - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - // MLULOG("pipeline: compute-stage.\n"); - // int i = max_para_ldst_num * (repeat_id - 1); - - DT *nram_in_r = nram_para_load_in_pong.r; - DT *nram_in_i = nram_para_load_in_pong.i; - - DT *nram_out_r = nram_para_store_pong.r; - DT *nram_out_i = nram_para_store_pong.i; - - // rotation-large - __bang_mul(CPX_MUL_RR, nram_para_load_in_pong.r + para_num, - nram_para_load_tw.r, para_num * (large_radix - 1)); - __bang_mul(CPX_MUL_II, nram_para_load_in_pong.i + para_num, - nram_para_load_tw.i, para_num * (large_radix - 1)); - __bang_mul(CPX_MUL_RI, nram_para_load_in_pong.r + para_num, - nram_para_load_tw.i, para_num * (large_radix - 1)); - __bang_mul(CPX_MUL_IR, nram_para_load_in_pong.i + para_num, - nram_para_load_tw.r, para_num * (large_radix - 1)); - - __bang_sub(nram_para_load_in_pong.r + para_num, CPX_MUL_RR, - CPX_MUL_II, para_num * (large_radix - 1)); - __bang_add(nram_para_load_in_pong.i + para_num, CPX_MUL_RI, - CPX_MUL_IR, para_num * (large_radix - 1)); - - // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r, - // large_radix, para_num); - // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i, - // large_radix, para_num); - - // for (int compute_id = 0; compute_id < para_num; compute_id++) - // { - for (int compute_id = 0; compute_id < para_num; - compute_id += para_num) { - // load real & imag - - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - // __memcpy(nram_in_r, - // nram_transpose_load + compute_id * large_radix * 2, - // large_radix * sizeof(DT) * 2, NRAM2NRAM); - - // first stage - // if(0) - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, - &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - default: - // computeGenericButterflyFirststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); - - // para_num = 1 - // in: [radix, butterfly_num] - // butterfly: [radix, radix] * [radix, butterfly_num] - // out_butterfly: [radix, butterfly_num] - // out: [butterfly_num, radix] - - // para_num != 1 - // in: [radix, butterfly_num, para_num] == [large_radix, - // para_num] butterfly: [radix, radix] * [radix, - // butterfly_num, para_num] out_butterfly: [radix, - // butterfly_num, para_num] == [radix, butterfly_num * - // para_num] out: [butterfly_num, para_num, radix] - - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_para_load_in_pong.r, - nram_para_load_in_pong.i, nram_scratch, nram_dftmtx, - small_section_num * para_num, small_section_num * para_num, - 1, dir, radix); - break; - } - - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } - - // [radix, small_section_num, para_num] -> - // [small_section_num, para_num, radix] -> - // [para_num, small_section_num, radix] -> - // [small_section_num, radix, para_num] == - // [large_radix, para_num] - - small_stage_count--; - if (small_stage_count == 0) { - // nram to gdram - - // if (last_stage) { - // // [2, para_num, large_radix] -> [para_num, - // large_radix, - // // 2] - // // DT* nram_transpose_store = nram_in_r; - - // __bang_transpose(nram_in_r, nram_out_r, 2, - // max_para_num * large_radix); - - // } else { - // // [2, para_num, large_radix] -> [2, para_num, - // // large_radix] - // // TODO(zrg): redundant move - // __memcpy(nram_in_r, nram_out_r, - // para_num * large_radix * sizeof(DT), - // NRAM2NRAM); - // __memcpy(nram_in_i, - // nram_out_i, para_num * large_radix * - // sizeof(DT), NRAM2NRAM); - // } - - // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2] - if (last_stage) { - __memcpy(nram_transpose_temp.r + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_num - 1); - - __memcpy(nram_transpose_temp.i + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_num - 1); - - __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r, - para_num * 2, large_radix); - } else { - if (nram_out_r == nram_para_store_pong.r) { - FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r) - FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i) - } - __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num, - large_radix); - __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num, - large_radix); - } - - continue; - } - - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // DT* nram_transpose_store = nram_in_r; - - // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num; - // para_ldst_id++) { - // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix, - // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - // - 1); - - // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix, - // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - // - 1); - // } - - // after first stage: [butterfly_num, para_ldst_num, radix] - // other in: [para_ldst_num, butterfly_num, radix] == - // [para_ldst_num, large_radix] - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_num, radix, DT) - - // TODO(zrg) : add not last-stage - // if (small_stage_count == 0) { - // // if last-stage: stride = large_radix * 2 - // // compute_id 0 r - // // compute_id 0 i - // // compute_id 1 r - // // compute_id 1 i - // // else: stride = large_radix - // // compute_id 0 r - // // compute_id 1 i - // // compute_id 0 r - // // compute_id 1 i - - // // [radix, small_section_num, para_ldst_num] -> - // // [small_section_num, para_ldst_num, radix] -> - // // [para_ldst_num, small_section_num, radix] -> - // // [small_section_num, radix, para_ldst_num] == - // // [large_radix, para_ldst_num] - - // __memcpy(nram_transpose_temp.r + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * - // large_radix, para_ldst_num - 1); - - // __memcpy(nram_transpose_temp.i + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * - // large_radix, para_ldst_num - 1); - - // // __memcpy(nram_transpose_temp.r + - // // compute_id * large_radix * (1 + - // (int)last_stage), - // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // // __memcpy(nram_transpose_temp.i + - // // compute_id * large_radix * (1 + - // (int)last_stage), - // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - - // // __bang_transpose(nram_transpose_temp.r, - // nram_transpose_temp.r, - // // max_para_ldst_num * 2, large_radix); - // continue; - // } - - // DT *sram_tw = (DT *)sram_buffer; - DT *nram_tw = _nram_tw; - value_mul = 8; - - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - // value_mul = (_small_stage_count - small_stage_count + 1) * 4; - // // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - // copy GDRAM2SRAM - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - small_twiddles += small_butterfly_num * (radix - 1) * 2; - } - - switch (radix) { - // case 2: - // // computeRadix2ButterflyOtherstages(Fout, Fin, - // section_num, - // // section_num, 1, dir); - // break; - // case 3: - // computeRadix3ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - default: - // computeGenericButterflyOtherstages(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, - nram_scratch, nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_num, small_in_stride, dir, - radix); - break; - } - - nram_tw += small_butterfly_num * (radix - 1) * 2; - } // for (stage_count) - - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } - - // MLULOG("butterfly id: %d\n", i); - // last stage - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // copy GDRAM2SRAM - - // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; - - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - } - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy( - nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - switch (radix) { - // case 2: - // break; - // case 3: - // computeRadix3ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - default: - // computeGenericButterflyLaststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, - nram_scratch, nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_num, small_in_stride, dir, - radix); - break; - } - - // if last-stage: stride = large_radix * 2 - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: stride = large_radix - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - // __memcpy(nram_transpose_temp.r + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_transpose_temp.i + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - - if (last_stage) { - __memcpy(nram_transpose_temp.r + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_num - 1); - - __memcpy(nram_transpose_temp.i + - compute_id * large_radix * (1 + (int)last_stage), - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_num - 1); - - __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r, - para_num * 2, large_radix); - } else { - if (nram_out_r == nram_para_store_pong.r) { - FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r) - FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i) - } - - __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num, - large_radix); - __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num, - large_radix); - } - - // __bang_transpose(nram_para_store, nram_transpose_temp.r, - // max_para_ldst_num * 2, large_radix); - } - } - } - // __sync(); - - __sync(); - FFT_SWAP_PTR(nram_para_load_in_ping.r, nram_para_load_in_pong.r) - FFT_SWAP_PTR(nram_para_load_in_ping.i, nram_para_load_in_pong.i) - FFT_SWAP_PTR(nram_para_store_ping.r, nram_para_store_pong.r) - FFT_SWAP_PTR(nram_para_store_ping.i, nram_para_store_pong.i) - } - } - // Fin_stride += large_butterfly_num; - // Fout_stride += large_radix * large_butterfly_num; - } -} - -template -__mlu_func__ void computeLargeButterflyLaststageBatchPingpong( - DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, - const DT *dft_matrix, int large_section_num, int large_butterfly_num, - int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - const int t_start, const int t_end, int dir) { - computeLargeButterflyOtherstagesBatchPingpong( - output, input, cur_large_twiddles, _twiddles, dft_matrix, - large_section_num, large_butterfly_num, large_in_stride, nram_buf, - small_factors, nfft, t_start, t_end, dir, 1); -} - -template -__mlu_func__ void computeLargeButterflyFirststageColumn( - DT *output, DT *input, int large_in_stride, int section_num, - const DT *twiddles, const DT *dft_matrix, void *nram_buf, - const int *small_factors, int dir, int nfft, int last_stage, - const int para_batch, const int nb) { - // constant - // const int para_batch = 3; - const int K_num = 64 / sizeof(DT); - int align_K = 0; - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - // test - // for(int i =0; i<3; i++){ - // MLULOG("entry: %d, dft_table.radix: %d, dft_table.offset: %d.\n", - // i, dft_table[i].radix, dft_table[i].offset); - // } - // network info - int radix, small_in_stride, small_stage_count, large_radix, - _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - int tw_offset; - // int max_radix = small_factors[4]; - _small_stage_count = small_factors[0]; - large_radix = small_factors[1]; - tw_offset = small_factors[2]; - - // for (int i=2; i<= _small_stage_count; i++) { - - // max_radix = max(small_factors[i*4], max_radix); - - // } - - // load compute store - // (0) load 0 ping sync() - // (1) compute 0 ping load 1 pong sync() - // (2) store 0 compute 1 pong load 2 ping sync() - // (3) store 1 compute 2 load 3 sync() - - // compute last-large-stage (nram_out_r,nram_out_i) [2, large_radix]-> - // transpose -> [large_radix, 2] - - // complex array -> real array, imag array -> complex array - // first-large-stage complex -> real array, imag array - // other-large-stage none - // last-large-stage real array, imag array -> complex - const DT *small_twiddles = twiddles + tw_offset * 2; // complex - - // assign nram space - int nram_buf_offset = 0; - - // parallel load/store space - // sizeof(DT) * 2 * large_radix * para_batch * 4 - DT *nram_para_load_ping = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch * 2; // complex - - DT *nram_para_load_pong = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch * 2; // complex - - DT *nram_para_store_ping = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch * 2; // complex - - DT *nram_para_store_pong = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch * 2; // complex - - // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix] - // DT *nram_transpose_load = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - // FFT_CPX_T
nram_transpose_temp; - // temp out-space before transpose - // if last-stage: - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; // complex - - // load dftmtx sample - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; // complex - - // const int ld_dft_radix = 16; - // DT *nram_dftmtx8 = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += 8 * 8 * 2; // complex - - // for (int entry = 0;; entry++) { - // if (dft_table[entry].radix == 8) { - // __memcpy_async(nram_dftmtx8, &dft_matrix[dft_table[entry].offset * 2], - // sizeof(DT) * 2 * 8 * 8, GDRAM2NRAM); - // break; - // } - // if (dft_table[entry].radix == -1) { - // break; - // } - // } - - // nram space used: - // sizeof(DT) * 2 * large_radix * (para_batch * 6 + 1) + sizeof(DT) * 2 - // * (max_radix * max_radix) - // + sizeof(DT) * 2 * large_radix * para_batch * 4 - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - // DT *nram_transpose_temp = nram_scratch; - // overlap nram_scratch - // nram_transpose_temp = { - // (DT *)nram_scratch, - // (DT *)nram_scratch + large_radix * ((int)last_stage) + - // large_radix * (1 - (int)last_stage) * para_batch}; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - __memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2, - SRAM2NRAM); - - // return; - // ceil - int repeat_num = section_num; - // MLULOG("repeat_num column: %d\n", repeat_num); - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - // pipeline: load-stage - - if (repeat_id < repeat_num) { - // MLULOG("pipeline: load-stage.\n"); - int i = repeat_id; - - // DT *nram_dftmtx = - // (repeat_id % 2 == 0) ? nram_dftmtx_ping : nram_dftmtx_pong; - - // if (section_num == 1) { - // __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix, - // GDRAM2NRAM); - // } else { - // // gather load - // // 2d memcpy - // // 0 1 2 3 4 ... 1023 - // // GDRAM -> NRAM - // // 8bytes radix-1024 - // // 64bytes - - // __memcpy_async(nram_para_load, input + i * 2, - // sizeof(DT) * 2 * para_batch, GDRAM2NRAM, - // sizeof(DT) * 2 * para_batch, - // large_in_stride * sizeof(DT) * 2, large_radix - 1); - // } - - // if(0) - // __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix * - // para_batch, - // GDRAM2NRAM); - // if(0) - __memcpy_async(nram_para_load_ping, input + i * 2 * nb, - sizeof(DT) * 2 * para_batch, GDRAM2NRAM, - sizeof(DT) * 2 * para_batch, - nb * large_in_stride * sizeof(DT) * 2, large_radix - 1); - } - - // pipeline: store-stage - - if (repeat_id >= 2) { - // MLULOG("pipeline: store-stage.\n"); - int i = (repeat_id - 2); - - if (last_stage) { - // if(0) - // __memcpy_async(output + i * large_radix * 2 * nb1, - // nram_para_store_ping, - // para_batch * sizeof(DT) * 2, NRAM2GDRAM, - // nb1 * 2 * sizeof(DT), para_batch * sizeof(DT) * 2, - // large_radix - 1); - __memcpy_async(output + i * large_radix * 2 * nb, nram_para_store_ping, - para_batch * sizeof(DT) * 2, NRAM2GDRAM, - nb * 2 * sizeof(DT), para_batch * sizeof(DT) * 2, - large_radix - 1); - // __memcpy_async(output + i * large_radix * 2 * para_batch, - // nram_para_store_ping, - // 2* para_batch * large_radix * sizeof(DT), NRAM2GDRAM); - } else { - // // real - // __memcpy_async(output + i * large_radix * nb1, nram_para_store, - // para_batch * sizeof(DT), NRAM2GDRAM, nb1 * sizeof(DT), - // para_batch * sizeof(DT), large_radix - 1); - // // imag - // __memcpy_async(output + i * large_radix * nb1 + nb0 * nb1, - // nram_para_store + para_batch * large_radix, - // para_batch * sizeof(DT), NRAM2GDRAM, nb1 * sizeof(DT), - // para_batch * sizeof(DT), large_radix - 1); - - // real - __memcpy_async(output + i * large_radix * para_batch, - nram_para_store_ping, - para_batch * large_radix * sizeof(DT), NRAM2GDRAM); - // imag - __memcpy_async( - output + i * large_radix * para_batch + nfft * para_batch, - nram_para_store_ping + para_batch * large_radix, - para_batch * large_radix * sizeof(DT), NRAM2GDRAM); - } - } - - // pipeline: compute-stage - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - // int i = (repeat_id - 1); - - // // [large_radix, para_batch, 2] -> [para_batch, 2, large_radix] - // __bang_transpose(nram_transpose_load, nram_para_load, large_radix, - // 2 * para_batch); - - // [large_radix, para_batch, 2] -> [2, para_batch, large_radix] - // overlap nram_out_r - // DT *nram_transpose_load = nram_out_r; - // __bang_transpose(nram_transpose_load, nram_para_load, - // large_radix * para_batch, 2); - // // [large_radix, para_batch] -> [para_batch, large_radix] - // __bang_transpose(nram_in_r, nram_transpose_load, large_radix, - // para_batch); - // __bang_transpose(nram_in_i, - // nram_transpose_load + large_radix * para_batch, - // large_radix, para_batch); - - DT *nram_in_r = nram_para_store_pong; - DT *nram_in_i = nram_para_store_pong + large_radix * para_batch; - - DT *nram_out_r = nram_para_load_pong; - DT *nram_out_i = nram_para_load_pong + large_radix * para_batch; - - // DT *nram_transpose_load = nram_in_r; - __bang_transpose(nram_in_r, nram_para_load_pong, large_radix * para_batch, - 2); - // [large_radix, para_batch] -> [para_batch, large_radix] - // __bang_transpose(nram_in_r, nram_transpose_load, large_radix, - // para_batch); - // __bang_transpose(nram_in_i, - // nram_transpose_load + large_radix * para_batch, - // large_radix, para_batch); - - for (int compute_id = 0; compute_id < para_batch; - compute_id += para_batch) { - // load real & imag - - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - // __memcpy(nram_in_r, - // nram_transpose_load + compute_id * large_radix * 2, - // large_radix * sizeof(DT) * 2, NRAM2NRAM); - - // first stage - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - default: - MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, small_section_num * para_batch, - small_section_num * para_batch, 1, dir, radix); - break; - } - // for (int j = 0; j < ld_dft_radix * align_K; j++) { - // printf("nram_dftmtx [%d][%d]: (%f, %f).\n", (ld_dft_radix * - // align_K)nram_dftmtx[j], nram_dftmtx[j + ld_dft_radix * align_K]); - // } - - // [radix, small_section_num, para_batch] -> - // [small_section_num, para_batch, radix] -> [para_batch, - // small_section_num, radix] - - // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix, - // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_batch * radix * sizeof(DT), small_section_num - 1); - - // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix, - // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_batch * radix * sizeof(DT), small_section_num - 1); - // __sync_move(); - - small_stage_count--; - if (small_stage_count == 0) { - // nram to gdram - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix); - __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix); - - if (nram_out_r == nram_para_store_pong) { - FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong) - } - - if (last_stage) { - // [2, para_batch, large_radix] -> [para_batch, large_radix, - // 2] - // DT* nram_transpose_store = nram_in_r; - - __bang_transpose(nram_para_store_pong, nram_out_r, 2, - para_batch * large_radix); - } else { - // [2, para_batch, large_radix] -> [2, para_batch, - // large_radix] - // TODO(zrg): redundant move - __memcpy(nram_para_store_pong, nram_out_r, - para_batch * large_radix * sizeof(DT), NRAM2NRAM); - __memcpy(nram_para_store_pong + para_batch * large_radix, - nram_out_i, para_batch * large_radix * sizeof(DT), - NRAM2NRAM); - } - - continue; - } - - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_batch, radix, DT) - - value_mul = 8; - // DT *sram_tw = (DT *)sram_buffer; - DT *nram_tw = _nram_tw; - - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - - // value_mul = (_small_stage_count - small_stage_count + 1) << 2; - - // // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - // value_mul += 4; - // copy GDRAM2SRAM - - // if (compute_id == 0 && repeat_id == 1 && 0) { - // __memcpy(nram_tw, small_twiddles, - // small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - // GDRAM2NRAM); - // small_twiddles += small_butterfly_num * (radix - 1) * 2; - // } - - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } - - switch (radix) { - default: - // computeGenericButterflyOtherstages(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - MLULOG("computeGenericButterflyOtherstagesMat: %d.\n", radix); - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_batch, small_in_stride, dir, radix); - break; - } + // if (last_stage) { + // // [2, para_num, large_radix] -> [para_num, + // large_radix, + // // 2] + // // DT* nram_transpose_store = nram_in_r; - nram_tw += small_butterfly_num * (radix - 1) * 2; - } // for (stage_count) + // __bang_transpose(nram_in_r, nram_out_r, 2, + // max_para_num * large_radix); - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } + // } else { + // // [2, para_num, large_radix] -> [2, para_num, + // // large_radix] + // // TODO(zrg): redundant move + // __memcpy(nram_in_r, nram_out_r, + // para_num * large_radix * sizeof(DT), + // NRAM2NRAM); + // __memcpy(nram_in_i, + // nram_out_i, para_num * large_radix * + // sizeof(DT), NRAM2NRAM); + // } - // MLULOG("butterfly id: %d\n", i); - // last stage - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); + // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2] + if (last_stage) { + __memcpy(nram_transpose_temp.r + + compute_id * large_radix * (1 + (int)last_stage), + nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, + sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, + para_num - 1); - // copy GDRAM2SRAM + __memcpy(nram_transpose_temp.i + + compute_id * large_radix * (1 + (int)last_stage), + nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, + sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, + para_num - 1); - // update parameter - // value_mul = _small_stage_count << 2; - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; + __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r, + para_num * 2, large_radix); + } else { + // rotation-large + __bang_mul(CPX_MUL_RR, nram_out_r + para_num, + nram_para_load_tw.r, para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_II, nram_out_i + para_num, + nram_para_load_tw.i, para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_RI, nram_out_r + para_num, + nram_para_load_tw.i, para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_IR, nram_out_i + para_num, + nram_para_load_tw.r, para_num * (large_radix - 1)); + + __bang_sub(nram_out_r + para_num, CPX_MUL_RR, CPX_MUL_II, + para_num * (large_radix - 1)); + __bang_add(nram_out_i + para_num, CPX_MUL_RI, CPX_MUL_IR, + para_num * (large_radix - 1)); - // if (compute_id == 0 && repeat_id == 1 && 0) { - // __memcpy(nram_tw, small_twiddles, - // small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - // GDRAM2NRAM); - // } - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; + if (nram_out_r == nram_para_store_pong.r) { + FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r) + FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i) + } + __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num, + large_radix); + __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num, + large_radix); } - if (dft_table[entry].radix == -1) { - break; - } + continue; } - } - - switch (radix) { - case 2: - break; - - default: - MLULOG("computeGenericButterflyLaststageMat: %d.\n", radix); - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, small_butterfly_num, - para_batch, small_in_stride, dir, radix); - MLULOG("computeGenericButterflyLaststageMat: %d End.\n", radix); - break; - } - - // [2, para_batch, large_radix] -> [2, large_radix, para_batch] - - if (last_stage) { - // TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, - // nram_in_i, 2, para_batch, large_radix, DT) - - // [2, para_batch, large_radix] -> [para_batch, large_radix, - // 2] - // DT* nram_transpose_store = nram_in_r; FFT_SWAP_PTR(nram_out_r, nram_in_r); FFT_SWAP_PTR(nram_out_i, nram_in_i); - __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix); - __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix); - - if (nram_out_r == nram_para_store_pong) { - FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong) - } - - __bang_transpose(nram_para_store_pong, nram_out_r, 2, - para_batch * large_radix); - // [2, para_batch, large_radix] -> [para_batch, 2, large_radix] -> - // [large_radix, para_batch, 2] - // __bang_transpose(nram_para_store, nram_out_r, para_batch * 2, - // large_radix); - } else { - // [2, para_batch, large_radix] -> [2, para_batch, - // large_radix] - // TODO(zrg): test - if (nram_out_r == nram_para_store_pong) { - FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong) - } - - __bang_transpose(nram_para_store_pong, nram_out_r, para_batch, - large_radix); - __bang_transpose(nram_para_store_pong + para_batch * large_radix, - nram_out_i, para_batch, large_radix); - - // __memcpy(nram_para_store, nram_out_r, - // para_batch * large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_para_store + para_batch * large_radix, nram_out_i, - // para_batch * large_radix * sizeof(DT), NRAM2NRAM); - } - } - } - } - - __sync(); - FFT_SWAP_PTR(nram_para_load_ping, nram_para_load_pong) - FFT_SWAP_PTR(nram_para_store_ping, nram_para_store_pong) - } -} - -template -__mlu_func__ void computeLargeButterflyOtherstagesColumn( - DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, - const DT *dft_matrix, int large_section_num, int large_butterfly_num, - int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - int dir, int last_stage, int para_batch, int nb) { - // return; - const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix; - - int radix, small_in_stride, small_stage_count, large_radix, - _small_stage_count; - int small_section_num, small_butterfly_num, value_mul; - - const int large_out_stride = large_butterfly_num; - int tw_offset; - - _small_stage_count = small_factors[0]; - large_radix = small_factors[1]; - tw_offset = small_factors[2]; - - const int K_num = 64 / sizeof(DT); - int align_K = 0; - - const DT *small_twiddles = _twiddles + tw_offset * 2; // complex - // const DT *small_twiddles; // complex - - // MLULOG("small_section_num: %d.\n\n\n", small_section_num); - // max num for parallel load/store - // const int max_para_ldst_num = 2187 / large_radix; - // const int max_para_ldst_num = (2187 + large_radix - 1) / large_radix; - // const int max_para_ldst_num = (512 + large_radix - 1) / large_radix; - // const int max_para_ldst_num = 20; - // const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix; - - // int para_ldst_num; - // TODO(zrg): save nram space. - // __nram__ DT nram_space[MAX_BUTTERFLY_ON_CHIP * 2]; - // 0 1 2 3 4 5 - // 0 1 3 - // DT *nram_buf_end = (DT*)&((uint8_t*)nram_buf)[NRAM_BUFFER_SIZE]; - // FFT_CPX_T
*nram_in = (FFT_CPX_T
*)nram_buffer; - // FFT_CPX_T
*nram_out = &nram_in[large_radix]; - // FFT_CPX_T
*nram_buf = &nram_in[large_radix * 2]; - int nram_buf_offset = 0; - DT *nram_in_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch; - - DT *nram_in_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch; - - DT *nram_out_r = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch; - - DT *nram_out_i = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * para_batch; - - // parallel load/store space - FFT_CPX_T
nram_para_load_in_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - nram_buf_offset += large_radix * para_batch * 2; // complex - - FFT_CPX_T
nram_para_load_in_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - nram_buf_offset += large_radix * para_batch * 2; // complex - - // FFT_CPX_T
nram_para_load_tw_ping = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - // FFT_CPX_T
nram_para_load_tw_pong = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - FFT_CPX_T
nram_para_load_tw_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + (large_radix - 1)}; - nram_buf_offset += (large_radix - 1) * 2; // complex - - FFT_CPX_T
nram_para_load_tw_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + (large_radix - 1)}; - nram_buf_offset += (large_radix - 1) * 2; // complex - - FFT_CPX_T
nram_para_store_ping = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - nram_buf_offset += large_radix * para_batch * 2; // complex - - FFT_CPX_T
nram_para_store_pong = { - (DT *)nram_buf + nram_buf_offset, - (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - nram_buf_offset += large_radix * para_batch * 2; // complex - - // overlap nram_in - FFT_CPX_T
nram_transpose_temp; - // temp out-space before transpose - // if last-stage: - // compute_id 0 r - // compute_id 0 i - // compute_id 1 r - // compute_id 1 i - // else: - // compute_id 0 r - // compute_id 1 i - // compute_id 0 r - // compute_id 1 i - nram_transpose_temp = {(DT *)nram_in_r, - (DT *)nram_in_r + large_radix * ((int)last_stage) + - large_radix * (1 - (int)last_stage) * para_batch}; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - DT *_nram_tw = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += large_radix * 2; // complex - - // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix] - // FFT_CPX_T
nram_transpose_load = { - // (DT *)nram_buf + nram_buf_offset, - // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch}; - // nram_buf_offset += large_radix * para_batch * 2; // complex - - // load dftmtx sample - // const int ld_dft_radix = 16; - // DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - // nram_buf_offset += ld_dft_radix * ld_dft_radix * 2; // complex - - // for (int entry = 0;; entry++) { - // if (dft_table[entry].radix == ld_dft_radix) { - // __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - // sizeof(DT) * 2 * ld_dft_radix * ld_dft_radix, GDRAM2NRAM); - // break; - // } - - // if (dft_table[entry].radix == -1) { - // break; - // } - // } - int ld_dft_radix = -1; - const int max_radix = 64; - DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset; - nram_buf_offset += max_radix * max_radix * 2; // complex - - DT *nram_scratch = (DT *)nram_buf + nram_buf_offset; - - // temp overlap with "nram_scratch" - DT *CPX_MUL_RR = nram_scratch; - DT *CPX_MUL_RI = &CPX_MUL_RR[(large_radix - 1) * para_batch]; - DT *CPX_MUL_IR = &CPX_MUL_RI[(large_radix - 1) * para_batch]; - DT *CPX_MUL_II = &CPX_MUL_IR[(large_radix - 1) * para_batch]; - - nram_buf_offset += (large_radix - 1) * para_batch * 4; // complex - - // size: (large_radix - 1) * para_batch - // DT *scratch_tw_r = &CPX_MUL_II[large_radix * para_batch]; - // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * para_batch]; - - // if (nram_buf == NULL) { - // MLULOG("nram_buf: NULL.\n"); - // } - // if (input == NULL) { - // MLULOG("input: NULL.\n"); - // } - // if (output == NULL) { - // MLULOG("output: NULL.\n"); - // } - - // if (cur_large_twiddles == NULL) { - // MLULOG("twiddles: NULL.\n"); - // } - - // __nram__ DT *nram_buf = &nram_space[MAX_BUTTERFLY_ON_CHIP*2]; - - // DT *odd_extra_buffer = buffer + nfft*2; // for in_place temp buffer - - // const int para_num = 1; - - int Fin_stride = 0, Fout_stride = 0; - int sec_count; - int repeat_num = large_butterfly_num; - - for (sec_count = 0; sec_count < large_section_num; ++sec_count) { - for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) { - // small_twiddles = _small_twiddles; - - // pipeline: load-stage - - if (repeat_id < repeat_num) { - // MLULOG("pipeline: load-stage.\n"); - int i = repeat_id; - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 == 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 == 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - if (para_batch != 1 || 1) { - __memcpy_async( - nram_para_load_in.r, input + (Fin_stride + i) * para_batch, - sizeof(DT) * para_batch, GDRAM2NRAM, sizeof(DT) * para_batch, - para_batch * large_in_stride * sizeof(DT), large_radix - 1); - __memcpy_async( - nram_para_load_in.i, - input + para_batch * nfft + (Fin_stride + i) * para_batch, - sizeof(DT) * para_batch, GDRAM2NRAM, sizeof(DT) * para_batch, - para_batch * large_in_stride * sizeof(DT), large_radix - 1); - - // __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i, - // sizeof(DT), SRAM2NRAM, sizeof(DT), - // large_out_stride * sizeof(DT), large_radix - 2); - // __memcpy_async( - // nram_para_load_tw.i, - // cur_large_twiddles + large_butterfly_num * (large_radix - 1) + - // i, sizeof(DT), SRAM2NRAM, sizeof(DT), large_out_stride * - // sizeof(DT), large_radix - 2); - __memcpy_async(nram_para_load_tw.r, - cur_large_twiddles + i * (large_radix - 1), - sizeof(DT) * (large_radix - 1) * 2, SRAM2NRAM); - __memcpy_async(nram_para_load_tw.i, - cur_large_twiddles + - large_butterfly_num * (large_radix - 1) + - i * (large_radix - 1), - sizeof(DT) * (large_radix - 1), SRAM2NRAM); - } - } - - // pipeline: store-stage - if (repeat_id >= 2) { - // MLULOG("pipeline: store-stage.\n"); - int i = (repeat_id - 2); - - // int para_store_num = (max_para_ldst_num > (large_butterfly_num - i)) - // ? (large_butterfly_num - i) - // : max_para_ldst_num; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong; - - if (last_stage) { - // __memcpy_async( - // output + (Fout_stride + i * large_radix) * 2, - // nram_para_store.r, - // sizeof(DT) * 2 * para_store_num * large_radix, NRAM2GDRAM); - - __memcpy_async(output + (Fout_stride + i) * 2 * nb, nram_para_store.r, - sizeof(DT) * 2 * para_batch, NRAM2GDRAM, - nb * large_out_stride * 2 * sizeof(DT), - sizeof(DT) * 2 * para_batch, large_radix - 1); - } else { - // // real - // __memcpy_async(output + Fout_stride + i * large_radix, - // nram_para_store.r, - // para_store_num * large_radix * sizeof(DT), - // NRAM2GDRAM); - // // imag - // __memcpy_async(output + Fout_stride + i * large_radix + nfft, - // nram_para_store.i, - // para_store_num * large_radix * sizeof(DT), - // NRAM2GDRAM); - // real - __memcpy_async(output + (Fout_stride + i) * para_batch, - nram_para_store.r, para_batch * sizeof(DT), NRAM2GDRAM, - para_batch * large_out_stride * sizeof(DT), - sizeof(DT) * para_batch, large_radix - 1); - // imag - __memcpy_async( - output + (Fout_stride + i) * para_batch + nfft * para_batch, - nram_para_store.i, para_batch * sizeof(DT), NRAM2GDRAM, - para_batch * large_out_stride * sizeof(DT), - sizeof(DT) * para_batch, large_radix - 1); - } - } - // __sync(); - // pipeline: compute-stage - - if (repeat_id >= 1 && repeat_id < repeat_num + 1) { - // MLULOG("pipeline: compute-stage.\n"); - // int i = (repeat_id - 1); - - FFT_CPX_T
nram_para_load_in = (repeat_id % 2 != 0) - ? nram_para_load_in_ping - : nram_para_load_in_pong; - - FFT_CPX_T
nram_para_load_tw = (repeat_id % 2 != 0) - ? nram_para_load_tw_ping - : nram_para_load_tw_pong; - - FFT_CPX_T
nram_para_store = - (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong; - - // __bang_transpose(nram_transpose_load, nram_para_load, large_radix, - // 2 * para_ldst_num); - - // rotation-large - // __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_batch, - // nram_para_load_tw.r, para_batch * (large_radix - 1)); - // __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_batch, - // nram_para_load_tw.i, para_batch * (large_radix - 1)); - // __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_batch, - // nram_para_load_tw.i, para_batch * (large_radix - 1)); - // __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_batch, - // nram_para_load_tw.r, para_batch * (large_radix - 1)); - - // __bang_sub(nram_para_load_in.r + para_batch, CPX_MUL_RR, CPX_MUL_II, - // para_batch * (large_radix - 1)); - // __bang_add(nram_para_load_in.i + para_batch, CPX_MUL_RI, CPX_MUL_IR, - // para_batch * (large_radix - 1)); - if (1) { - for (int i = 1; i < large_radix; i++) { - // __memcpy(&Fin.r[nram_in_offset], - // &nram_in_r[butterfly_num * section_num * i], - // butterfly_num * section_num * sizeof(DT), NRAM2NRAM, - // butterfly_num * section_num * sizeof(DT), - // butterfly_num * section_num * radix * sizeof(DT), - // para_large_butterfly - 1); - - // __memcpy(&Fin.i[nram_in_offset], - // &nram_in_i[butterfly_num * section_num * i], - // butterfly_num * section_num * sizeof(DT), NRAM2NRAM, - // butterfly_num * section_num * sizeof(DT), - // butterfly_num * section_num * radix * sizeof(DT), - // para_large_butterfly - 1); - - __bang_mul_scalar(&CPX_MUL_RR[(i - 1) * para_batch], - nram_para_load_in.r + para_batch * i, - nram_para_load_tw.r[(i - 1)], para_batch); - __bang_mul_scalar(&CPX_MUL_RI[(i - 1) * para_batch], - nram_para_load_in.r + para_batch * i, - nram_para_load_tw.i[(i - 1)], para_batch); - __bang_mul_scalar(&CPX_MUL_IR[(i - 1) * para_batch], - nram_para_load_in.i + para_batch * i, - nram_para_load_tw.r[(i - 1)], para_batch); - __bang_mul_scalar(&CPX_MUL_II[(i - 1) * para_batch], - nram_para_load_in.i + para_batch * i, - nram_para_load_tw.i[(i - 1)], para_batch); - - // __bang_cycle_mul(&CPX_MUL_RR[(i - 1) * para_batch], - // nram_para_load_in.r + para_batch * i, - // &nram_para_load_tw.r[(i - 1) * - // large_butterfly_num], para_batch, 1); - // __bang_cycle_mul(&CPX_MUL_RI[(i - 1) * para_batch], - // nram_para_load_in.r + para_batch * i, - // &nram_para_load_tw.i[(i - 1) * - // large_butterfly_num], para_batch, 1); - // __bang_cycle_mul(&CPX_MUL_IR[(i - 1) * para_batch], - // nram_para_load_in.i + para_batch * i, - // &nram_para_load_tw.r[(i - 1) * - // large_butterfly_num], para_batch, 1); - // __bang_cycle_mul(&CPX_MUL_II[(i - 1) * para_batch], - // nram_para_load_in.i + para_batch * i, - // &nram_para_load_tw.i[(i - 1) * - // large_butterfly_num], para_batch, 1); - } - __bang_sub(nram_para_load_in.r + para_batch, CPX_MUL_RR, CPX_MUL_II, - para_batch * (large_radix - 1)); - __bang_add(nram_para_load_in.i + para_batch, CPX_MUL_RI, CPX_MUL_IR, - para_batch * (large_radix - 1)); - } - // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r, - // large_radix, para_batch); - // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i, - // large_radix, para_batch); - - // for (int compute_id = 0; compute_id < para_batch; compute_id++) { - for (int compute_id = 0; compute_id < para_batch; - compute_id += para_batch) { - // load real & imag - - radix = small_factors[4]; - small_section_num = small_factors[5]; - small_in_stride = small_factors[7]; - small_stage_count = _small_stage_count; - - // __memcpy(nram_in_r, - // nram_transpose_load + compute_id * large_radix * 2, - // large_radix * sizeof(DT) * 2, NRAM2NRAM); - - // first stage - // if(0) - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; - } - - if (dft_table[entry].radix == -1) { - break; - } - } - } + // DT* nram_transpose_store = nram_in_r; - switch (radix) { - default: - // computeGenericButterflyFirststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix); + // after first stage: [butterfly_num, para_ldst_num, radix] + // other in: [para_ldst_num, butterfly_num, radix] == + // [para_ldst_num, large_radix] + TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, + small_section_num, para_num, radix, DT) - computeGenericButterflyFirststageMat( - nram_out_r, nram_out_i, nram_para_load_in.r, - nram_para_load_in.i, nram_scratch, nram_dftmtx, - small_section_num * para_batch, - small_section_num * para_batch, 1, dir, radix); - break; - } - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } + // DT *sram_tw = (DT *)sram_buffer; + DT *nram_tw = _nram_tw; + value_mul = 8; - // [radix, small_section_num, para_ldst_num] -> - // [small_section_num, para_ldst_num, radix] -> - // [para_ldst_num, small_section_num, radix] -> - // [small_section_num, radix, para_ldst_num] == - // [large_radix, para_ldst_num] + for (; small_stage_count > 1; small_stage_count--) { + FFT_SWAP_PTR(nram_out_r, nram_in_r); + FFT_SWAP_PTR(nram_out_i, nram_in_i); - small_stage_count--; - if (small_stage_count == 0) { - // FFT_SWAP_PTR(nram_out_r, nram_in_r); - // FFT_SWAP_PTR(nram_out_i, nram_in_i); - // __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix); - // __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix); + // value_mul = (_small_stage_count - small_stage_count + 1) * 4; + // // update parameter + radix = small_factors[value_mul++]; + small_section_num = small_factors[value_mul++]; + small_butterfly_num = small_factors[value_mul++]; + small_in_stride = small_factors[value_mul++]; + // copy GDRAM2SRAM - // nram to gdram + if (ld_dft_radix != radix) { + ld_dft_radix = radix; + for (int entry = 0;; entry++) { + if (dft_table[entry].radix == ld_dft_radix) { + align_K = K_num * ((radix + K_num - 1) / K_num); + __memcpy( + nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], + sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); + break; + } - // if (last_stage) { - // // [2, para_ldst_num, large_radix] -> [para_ldst_num, - // large_radix, - // // 2] - // // DT* nram_transpose_store = nram_in_r; + if (dft_table[entry].radix == -1) { + break; + } + } + } - // __bang_transpose(nram_in_r, nram_out_r, 2, - // max_para_ldst_num * large_radix); - - // } else { - // // [2, para_ldst_num, large_radix] -> [2, para_ldst_num, - // // large_radix] - // // TODO(zrg): redundant move - // __memcpy(nram_in_r, nram_out_r, - // para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM); - // __memcpy(nram_in_i, - // nram_out_i, para_ldst_num * large_radix * sizeof(DT), - // NRAM2NRAM); - // } + if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { + __memcpy(nram_tw, small_twiddles, + small_butterfly_num * (radix - 1) * sizeof(DT) * 2, + SRAM2NRAM); + small_twiddles += small_butterfly_num * (radix - 1) * 2; + } - if (last_stage) { - __memcpy(nram_transpose_temp.r + compute_id * large_radix * 2, - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_batch - 1); - - __memcpy(nram_transpose_temp.i + compute_id * large_radix * 2, - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_batch - 1); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_batch * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_batch, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_batch, - large_radix); - } + switch (radix) { + default: + // computeGenericButterflyOtherstages(Fout, buffer, twiddles, + // radix, section_num, butterfly_num, in_stride, 0, dir); + computeGenericButterflyOtherstagesMat( + nram_out_r, nram_out_i, nram_in_r, nram_in_i, + nram_scratch, nram_dftmtx, nram_tw, small_section_num, + small_butterfly_num, para_num, small_in_stride, dir, + radix); + break; + } - continue; - } + nram_tw += small_butterfly_num * (radix - 1) * 2; + } // for (stage_count) - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // DT* nram_transpose_store = nram_in_r; - - // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num; - // para_ldst_id++) { - // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix, - // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - - // 1); - - // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix, - // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix, - // NRAM2NRAM, sizeof(DT) * radix, - // para_ldst_num * radix * sizeof(DT), small_section_num - - // 1); - // } + // for (int j = 0; j < large_radix; j++) { + // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); + // } - TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i, - small_section_num, para_batch, radix, DT) - - // TODO(zrg) : add not last-stage - // if (small_stage_count == 0) { - // // if last-stage: stride = large_radix * 2 - // // compute_id 0 r - // // compute_id 0 i - // // compute_id 1 r - // // compute_id 1 i - // // else: stride = large_radix - // // compute_id 0 r - // // compute_id 1 i - // // compute_id 0 r - // // compute_id 1 i - - // // [radix, small_section_num, para_ldst_num] -> - // // [small_section_num, para_ldst_num, radix] -> - // // [para_ldst_num, small_section_num, radix] -> - // // [small_section_num, radix, para_ldst_num] == - // // [large_radix, para_ldst_num] - - // __memcpy(nram_transpose_temp.r + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - // para_ldst_num - 1); - - // __memcpy(nram_transpose_temp.i + - // compute_id * large_radix * (1 + (int)last_stage), - // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - // para_ldst_num - 1); - - // // __memcpy(nram_transpose_temp.r + - // // compute_id * large_radix * (1 + (int)last_stage), - // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); - // // __memcpy(nram_transpose_temp.i + - // // compute_id * large_radix * (1 + (int)last_stage), - // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - - // // __bang_transpose(nram_transpose_temp.r, nram_transpose_temp.r, - // // max_para_ldst_num * 2, large_radix); - // continue; - // } + // MLULOG("butterfly id: %d\n", i); + // last stage + { + FFT_SWAP_PTR(nram_out_r, nram_in_r); + FFT_SWAP_PTR(nram_out_i, nram_in_i); + // copy GDRAM2SRAM - // DT *sram_tw = (DT *)sram_buffer; - DT *nram_tw = _nram_tw; - value_mul = 8; + // update parameter + radix = small_factors[value_mul++]; + small_section_num = small_factors[value_mul++]; + small_butterfly_num = small_factors[value_mul++]; + small_in_stride = small_factors[value_mul]; - for (; small_stage_count > 1; small_stage_count--) { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); + if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { + __memcpy(nram_tw, small_twiddles, + small_butterfly_num * (radix - 1) * sizeof(DT) * 2, + SRAM2NRAM); + } - // value_mul = (_small_stage_count - small_stage_count + 1) * 4; - // // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul++]; - // copy GDRAM2SRAM + if (ld_dft_radix != radix) { + ld_dft_radix = radix; + for (int entry = 0;; entry++) { + if (dft_table[entry].radix == ld_dft_radix) { + align_K = K_num * ((radix + K_num - 1) / K_num); + __memcpy( + nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2], + sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); + break; + } - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, - &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; + if (dft_table[entry].radix == -1) { + break; + } } - - if (dft_table[entry].radix == -1) { + } + switch (radix) { + // case 2: + // break; + // case 3: + // computeRadix3ButterflyLaststage( + // nram_out_r, nram_out_i, nram_in_r, nram_in_i, + // nram_scratch, nram_tw, small_section_num, + // small_butterfly_num, small_in_stride, dir); + // break; + // case 9: + // computeRadix9ButterflyLaststage( + // nram_out_r, nram_out_i, nram_in_r, nram_in_i, + // nram_scratch, nram_tw, small_section_num, + // small_butterfly_num, small_in_stride, dir); + // break; + default: + // computeGenericButterflyLaststage(Fout, buffer, twiddles, + // radix, section_num, butterfly_num, in_stride, 0, dir); + computeGenericButterflyLaststageMat( + nram_out_r, nram_out_i, nram_in_r, nram_in_i, + nram_scratch, nram_dftmtx, nram_tw, small_section_num, + small_butterfly_num, para_num, small_in_stride, dir, + radix); break; - } } - } - - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - small_twiddles += small_butterfly_num * (radix - 1) * 2; - } - - switch (radix) { - // case 2: - // // computeRadix2ButterflyOtherstages(Fout, Fin, section_num, - // // section_num, 1, dir); - // break; - // case 3: - // computeRadix3ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyOtherstages( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - default: - // computeGenericButterflyOtherstages(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyOtherstagesMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_batch, small_in_stride, dir, - radix); - break; - } - nram_tw += small_butterfly_num * (radix - 1) * 2; - } // for (stage_count) - - // for (int j = 0; j < large_radix; j++) { - // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]); - // } + // if last-stage: stride = large_radix * 2 + // compute_id 0 r + // compute_id 0 i + // compute_id 1 r + // compute_id 1 i + // else: stride = large_radix + // compute_id 0 r + // compute_id 1 i + // compute_id 0 r + // compute_id 1 i + // __memcpy(nram_transpose_temp.r + + // compute_id * large_radix * (1 + (int)last_stage), + // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM); + // __memcpy(nram_transpose_temp.i + + // compute_id * large_radix * (1 + (int)last_stage), + // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM); - // MLULOG("butterfly id: %d\n", i); - // last stage - { - FFT_SWAP_PTR(nram_out_r, nram_in_r); - FFT_SWAP_PTR(nram_out_i, nram_in_i); - // copy GDRAM2SRAM + // rotation-large + __bang_mul(CPX_MUL_RR, nram_out_r + para_num, nram_para_load_tw.r, + para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_II, nram_out_i + para_num, nram_para_load_tw.i, + para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_RI, nram_out_r + para_num, nram_para_load_tw.i, + para_num * (large_radix - 1)); + __bang_mul(CPX_MUL_IR, nram_out_i + para_num, nram_para_load_tw.r, + para_num * (large_radix - 1)); + + __bang_sub(nram_out_r + para_num, CPX_MUL_RR, CPX_MUL_II, + para_num * (large_radix - 1)); + __bang_add(nram_out_i + para_num, CPX_MUL_RI, CPX_MUL_IR, + para_num * (large_radix - 1)); - // update parameter - radix = small_factors[value_mul++]; - small_section_num = small_factors[value_mul++]; - small_butterfly_num = small_factors[value_mul++]; - small_in_stride = small_factors[value_mul]; + if (last_stage) { + __memcpy(nram_transpose_temp.r + + compute_id * large_radix * (1 + (int)last_stage), + nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, + sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, + para_num - 1); - if (sec_count == 0 && compute_id == 0 && repeat_id == 1) { - __memcpy(nram_tw, small_twiddles, - small_butterfly_num * (radix - 1) * sizeof(DT) * 2, - SRAM2NRAM); - } + __memcpy(nram_transpose_temp.i + + compute_id * large_radix * (1 + (int)last_stage), + nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, + sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, + para_num - 1); - if (ld_dft_radix != radix) { - ld_dft_radix = radix; - for (int entry = 0;; entry++) { - if (dft_table[entry].radix == ld_dft_radix) { - align_K = K_num * ((radix + K_num - 1) / K_num); - __memcpy(nram_dftmtx, - &dft_matrix[dft_table[entry].offset * 2], - sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM); - break; + __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r, + para_num * 2, large_radix); + } else { + if (nram_out_r == nram_para_store_pong.r) { + FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r) + FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i) } - if (dft_table[entry].radix == -1) { - break; - } + __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num, + large_radix); + __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num, + large_radix); } - } - switch (radix) { - // case 2: - // break; - // case 3: - // computeRadix3ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - // case 9: - // computeRadix9ButterflyLaststage( - // nram_out_r, nram_out_i, nram_in_r, nram_in_i, - // nram_scratch, nram_tw, small_section_num, - // small_butterfly_num, small_in_stride, dir); - // break; - default: - // computeGenericButterflyLaststage(Fout, buffer, twiddles, - // radix, section_num, butterfly_num, in_stride, 0, dir); - computeGenericButterflyLaststageMat( - nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch, - nram_dftmtx, nram_tw, small_section_num, - small_butterfly_num, para_batch, small_in_stride, dir, - radix); - break; - } - if (last_stage) { - // [2, para_batch, large_radix] -> [large_radix, para_batch, 2] - __memcpy(nram_transpose_temp.r + compute_id * large_radix * 2, - nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_batch - 1); - - __memcpy(nram_transpose_temp.i + compute_id * large_radix * 2, - nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM, - sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix, - para_batch - 1); - - __bang_transpose(nram_para_store.r, nram_transpose_temp.r, - para_batch * 2, large_radix); - } else { - __bang_transpose(nram_para_store.r, nram_out_r, para_batch, - large_radix); - __bang_transpose(nram_para_store.i, nram_out_i, para_batch, - large_radix); + // __bang_transpose(nram_para_store, nram_transpose_temp.r, + // max_para_ldst_num * 2, large_radix); } - - // __bang_transpose(nram_para_store, nram_transpose_temp.r, - // max_para_ldst_num * 2, large_radix); } } - } + // __sync(); - __sync(); + __sync(); + FFT_SWAP_PTR(nram_para_load_in_ping.r, nram_para_load_in_pong.r) + FFT_SWAP_PTR(nram_para_load_in_ping.i, nram_para_load_in_pong.i) + FFT_SWAP_PTR(nram_para_store_ping.r, nram_para_store_pong.r) + FFT_SWAP_PTR(nram_para_store_ping.i, nram_para_store_pong.i) + } } - Fin_stride += large_butterfly_num; - Fout_stride += large_radix * large_butterfly_num; + // Fin_stride += large_butterfly_num; + // Fout_stride += large_radix * large_butterfly_num; } } template -__mlu_func__ void computeLargeButterflyLaststageColumn( +__mlu_func__ void computeLargeButterflyLaststageBatchPingpongC2R( DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles, const DT *dft_matrix, int large_section_num, int large_butterfly_num, int large_in_stride, void *nram_buf, const int *small_factors, int nfft, - int dir, int para_batch, int nb) { - computeLargeButterflyOtherstagesColumn( + const int t_start, const int t_end, int dir) { + computeLargeButterflyOtherstagesBatchPingpongC2R( output, input, cur_large_twiddles, _twiddles, dft_matrix, large_section_num, large_butterfly_num, large_in_stride, nram_buf, - small_factors, nfft, dir, 1, para_batch, nb); + small_factors, nfft, t_start, t_end, dir, 1); } diff --git a/kernels/fft/irfft/irfft.h b/kernels/fft/irfft/irfft.h index ccfd02896..b4f305793 100644 --- a/kernels/fft/irfft/irfft.h +++ b/kernels/fft/irfft/irfft.h @@ -33,8 +33,8 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle, const std::string api); mluOpStatus_t setIRFFT1dReserveArea_v2(mluOpHandle_t handle, - mluOpFFTPlan_t fft_plan, - const std::string api); + mluOpFFTPlan_t fft_plan, + const std::string api); mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan, const void *input, const float scale_factor, diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp index 95f510a9b..ecee58a08 100644 --- a/kernels/fft/irfft/irfft_host.cpp +++ b/kernels/fft/irfft/irfft_host.cpp @@ -546,8 +546,8 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle, } mluOpStatus_t setIRFFT1dReserveArea_v2(mluOpHandle_t handle, - mluOpFFTPlan_t fft_plan, - const std::string api) { + mluOpFFTPlan_t fft_plan, + const std::string api) { mluOpStatus_t status = MLUOP_STATUS_SUCCESS; VLOG(5) << "Into configure IRFFT1d ReserveArea Addrs (zrg)";