diff --git a/kernels/fft/c2c_fft/c2c_fft_host.cpp b/kernels/fft/c2c_fft/c2c_fft_host.cpp
index 29e213163..116763023 100644
--- a/kernels/fft/c2c_fft/c2c_fft_host.cpp
+++ b/kernels/fft/c2c_fft/c2c_fft_host.cpp
@@ -2112,7 +2112,7 @@ mluOpStatus_t execIRFFT1d_v2(mluOpHandle_t handle,
// fft_plan->mlu_addrs.input);
// INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
// }
-
+
status = execFFTc2r1d(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
diff --git a/kernels/fft/fft.cpp b/kernels/fft/fft.cpp
index e16ad65b9..16247326f 100644
--- a/kernels/fft/fft.cpp
+++ b/kernels/fft/fft.cpp
@@ -603,13 +603,13 @@ mluOpStatus_t MLUOP_WIN_API fftFactor(const int _n, int *facbuf,
}
break;
- case (32 * 17):
- if (n % 32 == 0) {
- r = 32;
- } else if ((n % 17) == 0) {
- r = 17;
- }
- break;
+ case (32 * 17):
+ if (n % 32 == 0) {
+ r = 32;
+ } else if ((n % 17) == 0) {
+ r = 17;
+ }
+ break;
case 600:
if (n % 30 == 0) {
@@ -726,19 +726,19 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
if (is_row_major) {
switch (_n) {
case (32 * 17):
- r = 32 * 17;
+ r = 32 * 17;
break;
case (200):
- r = 200;
+ r = 200;
break;
case (600):
- r = 600;
+ r = 600;
break;
case (256):
- r = 256;
+ r = 256;
break;
case 1024:
@@ -868,7 +868,29 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
}
}
n /= r;
- in_stride = _n / r;
+
+    // Select the input stride for this stage according to the transform type.
+    switch (fft_plan->fft_type) {
+      // Real-to-complex and complex-to-real transform types.
+      case CNFFT_HALF2COMPLEX_HALF:
+      case CNFFT_FLOAT2COMPLEX_FLOAT:
+      case CNFFT_COMPLEX_HALF2HALF:
+      case CNFFT_COMPLEX_FLOAT2FLOAT: {
+        if ((n * r) != _n) {
+          // Not the first factor (n * r no longer equals the full length):
+          // derive the stride from (out_stride / 2 + 1) and section_num.
+          in_stride = (((out_stride / 2) + 1) * section_num) / r;
+        } else {
+          in_stride = _n / r;
+        }
+        // Must not fall through: the complex-to-complex case below would
+        // overwrite in_stride with _n / r and discard the value chosen above.
+        break;
+      }
+      // Complex-to-complex transform types.
+      case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
+      case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
+        in_stride = _n / r;
+        break;
+      }
+      default:
+        break;
+    }
+
section_num = n;
stage_num++;
@@ -991,7 +1013,7 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
switch (fft_plan->fft_type) {
// r2c
case CNFFT_HALF2COMPLEX_HALF:
- case CNFFT_FLOAT2COMPLEX_FLOAT:
+ case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
@@ -2226,13 +2248,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle,
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
if (fft_plan->rank == 1) {
- if(fft_plan->prime) {
- status = setFFT1dReserveArea(handle, fft_plan, api);
- }else{
-status = setFFT1dReserveArea_v2(handle, fft_plan, api);
+ if (fft_plan->prime) {
+ status = setFFT1dReserveArea(handle, fft_plan, api);
+ } else {
+ status = setFFT1dReserveArea_v2(handle, fft_plan, api);
}
-
} else if (fft_plan->rank == 2) {
// status = setFFT1dReserveArea(handle, fft_plan, api);
status = setFFT2dReserveArea(handle, fft_plan, api);
@@ -2244,10 +2265,10 @@ status = setFFT1dReserveArea_v2(handle, fft_plan, api);
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
if (fft_plan->rank == 1) {
- if(fft_plan->prime) {
- status = setIRFFT1dReserveArea(handle, fft_plan, api);
- }else{
-status = setIRFFT1dReserveArea_v2(handle, fft_plan, api);
+ if (fft_plan->prime) {
+ status = setIRFFT1dReserveArea(handle, fft_plan, api);
+ } else {
+ status = setIRFFT1dReserveArea_v2(handle, fft_plan, api);
}
} else if (fft_plan->rank == 2) {
// status = setFFT1dReserveArea(handle, fft_plan, api);
diff --git a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h
index 1a1b8a33c..b72ab76a6 100644
--- a/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h
+++ b/kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h
@@ -936,7 +936,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpong(
small_butterfly_num, para_num, small_in_stride, dir,
radix);
} else {
- // TODO(zrg):check
+ // TODO(zrg): check
computeGenericButterflyLaststageMat(
nram_para_store_pong,
nram_para_store_pong + max_para_ldst_num * large_radix,
@@ -1117,25 +1117,24 @@ __mlu_func__ void computeLargeButterflyOtherstages(
int para_load_num = (max_para_ldst_num > (large_butterfly_num - i))
? (large_butterfly_num - i)
: max_para_ldst_num;
- if (para_load_num != 1) {
- __memcpy_async(nram_para_load_in.r, input + Fin_stride + i,
- sizeof(DT) * para_load_num, GDRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i,
- sizeof(DT) * para_load_num, GDRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
- sizeof(DT) * para_load_num, SRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- __memcpy_async(
- nram_para_load_tw.i,
- cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i,
- sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- }
+
+ // Load stage: queue the asynchronous copies for this batch of para_load_num
+ // butterflies. Real and imaginary parts are kept in separate buffers.
+ // NOTE(review): the trailing arguments are assumed to be
+ // (dst_stride, src_stride, segnum) of the strided copy form -- confirm
+ // against the BANG C reference.
+ //
+ // Input, real part: GDRAM -> NRAM, source rows large_in_stride elements apart.
+ __memcpy_async(nram_para_load_in.r, input + Fin_stride + i,
+ sizeof(DT) * para_load_num, GDRAM2NRAM,
+ sizeof(DT) * para_load_num, large_in_stride * sizeof(DT),
+ large_radix - 1);
+ // Input, imaginary part: same layout, source shifted by nfft elements.
+ __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i,
+ sizeof(DT) * para_load_num, GDRAM2NRAM,
+ sizeof(DT) * para_load_num, large_in_stride * sizeof(DT),
+ large_radix - 1);
+ // Twiddles, real part: SRAM -> NRAM, source rows large_out_stride elements
+ // apart; the last argument is large_radix - 2 here versus large_radix - 1
+ // for the input loads above.
+ __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
+ sizeof(DT) * para_load_num, SRAM2NRAM,
+ sizeof(DT) * para_load_num,
+ large_out_stride * sizeof(DT), large_radix - 2);
+ // Twiddles, imaginary part: source starts
+ // large_butterfly_num * (large_radix - 1) elements after the real part.
+ __memcpy_async(
+ nram_para_load_tw.i,
+ cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i,
+ sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num,
+ large_out_stride * sizeof(DT), large_radix - 2);
}
// pipeline: store-stage
@@ -1763,7 +1762,7 @@ __mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong(
// pipeline: load-stage
if (repeat_id < repeat_num) {
- if (para_num != 1) {
+ if (1) {
__memcpy_async(
nram_para_load_in_ping.r,
input_batch + sec_count * large_butterfly_num + butterfly_id,
diff --git a/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h b/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h
index c2bf4d154..d32c6b5b7 100644
--- a/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h
+++ b/kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h
@@ -56,7 +56,7 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
taskId, repeat_num, remain_num, t_len, t_start, t_end);
int radix, section_num, butterfly_num, in_stride, stage_count, value_mul,
- small_factors_offset;
+ out_stride, small_factors_offset;
int *small_factors;
int last_stage;
@@ -78,17 +78,23 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
const int nfft = factors[1];
// first stage
- radix = factors[5 + 0];
- section_num = factors[5 + 1];
- in_stride = factors[5 + 3];
- small_factors_offset = factors[5 + 4];
+ // First stage: read its parameters from the 5-int factor record at index
+ // _stage_count (slots +0 .. +4 of factors[5 * _stage_count]).
+ radix = factors[5 * _stage_count + 0];
+ section_num = factors[5 * _stage_count + 1];
+ // NOTE(review): butterfly_num and out_stride both read slot +3 here, while
+ // the per-stage update later in this function reads butterfly_num from
+ // slot +2 and out_stride from slot +3 -- confirm that +3 is intended.
+ butterfly_num = factors[5 * _stage_count + 3];
+ out_stride = factors[5 * _stage_count + 3];
+ // The input stride of this stage is taken to be butterfly_num.
+ in_stride = butterfly_num;
+ small_factors_offset = factors[5 * _stage_count + 4];
+
+ // Advance the twiddle pointer over the records 2 .. _stage_count - 1, by
+ // (radix - 1) * (butterfly_num / 2) elements for each of those records.
+ for (int loop_stage = 2; loop_stage < _stage_count; loop_stage++) {
+ cur_radix = factors[5 * loop_stage];
+ butterfly_num_stage = factors[5 * loop_stage + 2];
+ twiddles += (cur_radix - 1) * (butterfly_num_stage / 2);
+ }
// small_factors = factors + small_factors_offset;
- stage_count = _stage_count;
- last_stage = (stage_count == 1);
-
-
+ stage_count = 1;
+ last_stage = (_stage_count == 1);
if (__is_mpu()) {
__memcpy_async(sram_factors, factors, FFT_MAXFACTORS * sizeof(int),
@@ -96,7 +102,6 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
if (twiddles_size) {
__memcpy_async(sram_twiddles, twiddles, twiddles_size * sizeof(DT),
GDRAM2SRAM);
-
}
const dft_table_entry *dft_table_gdram =
@@ -166,11 +171,11 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
}
// sram_large_tw
- stage_count--;
- if (stage_count == 0) {
- // continue;
+
+ if (stage_count == _stage_count) {
return;
}
+ stage_count++;
// if (__is_mpu()) {
// return;
@@ -178,7 +183,7 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
// sram_large_tw
value_mul = 10;
- for (; stage_count > 1; stage_count--) {
+ for (; stage_count < _stage_count; stage_count++) {
// fft_swap_ptr<DT>(&buffer, &output);
// FFT_SWAP_PTR(buffer, output);
FFT_SWAP_PTR(buffer, output);
@@ -198,14 +203,17 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
FFT_SWAP_PTR(odd_extra_buffer, output);
}
- // value_mul = (_stage_count - stage_count + 1) * 5;
+ value_mul = (_stage_count - stage_count + 1) * 5;
// update parameter
- radix = factors[value_mul++];
- section_num = factors[value_mul++];
- butterfly_num = factors[value_mul++];
- in_stride = factors[value_mul++];
- small_factors_offset = factors[value_mul++];
+
+ // Load this stage's parameters from the 5-int record starting at
+ // factors[value_mul]: radix, section_num, butterfly_num, out_stride and
+ // small_factors_offset. in_stride is set equal to butterfly_num.
+ radix = factors[value_mul];
+ section_num = factors[value_mul + 1];
+ butterfly_num = factors[value_mul + 2];
+ in_stride = butterfly_num;
+ // NOTE(review): out_stride is assigned here but is not among the arguments
+ // of the butterfly call that follows -- confirm it is still needed.
+ out_stride = factors[value_mul + 3];
+ small_factors_offset = factors[value_mul + 4];
+ // Step the twiddle pointer back to this stage's table.
+ // NOTE(review): this rewind uses (butterfly_num / 2 + 1) per radix step,
+ // whereas the advance before the first stage uses (butterfly_num / 2) --
+ // confirm the two are meant to differ.
+ twiddles -= (radix - 1) * (butterfly_num / 2 + 1);
small_factors = factors + small_factors_offset;
@@ -213,28 +221,14 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
// MLULOG("other stage radix: %d \n", radix);
if (repeat_num > 0 || taskId < remain_num) {
- if (6000 / radix > repeat_num && 0) {
- for (int t = t_start; t < t_end; t++) {
- DT *output_batch = output + t * (nfft << 1);
- DT *buffer_batch = buffer + t * (nfft << 1);
-
- computeLargeButterflyOtherstages(
- output_batch, buffer_batch, (DT *)twiddles, _twiddles,
- sram_dftmtx, section_num, butterfly_num, in_stride,
- (void *)nram_buf, small_factors, nfft, direction, 0);
-
- // __sync();
- }
- } else {
- computeLargeButterflyOtherstagesBatchPingpong(
- output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx,
- section_num, butterfly_num, in_stride, (void *)nram_buf,
- small_factors, nfft, t_start, t_end, direction, 0);
- }
+ computeLargeButterflyOtherstagesBatchPingpongC2R(
+ output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num,
+ butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft,
+ t_start, t_end, direction, 0);
}
}
- twiddles += butterfly_num * (radix - 1) * 2; // 2 for complex
- } // for (stage_count)
+ twiddles += (butterfly_num / 2 - 1) * (radix - 1) * 2; // 2 for complex
+ } // for (stage_count)
// __mlu_shared__ DT *sram_tw[2048]; // radix-1024
// __mlu_shared__ DT *sram_tw[64]; // radix-1024
@@ -248,32 +242,21 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
FFT_SWAP_PTR(buffer, output);
// update parameter
- radix = factors[value_mul++];
- section_num = factors[value_mul++];
- butterfly_num = factors[value_mul++];
- in_stride = factors[value_mul++];
- small_factors_offset = factors[value_mul];
+ radix = factors[5];
+ section_num = factors[6];
+ butterfly_num = factors[7];
+ in_stride = butterfly_num;
+ out_stride = factors[8];
+ small_factors_offset = factors[9];
small_factors = factors + small_factors_offset;
if (__is_ipu()) {
if (repeat_num > 0 || taskId < remain_num) {
- if (0) {
- for (int t = t_start; t < t_end; t++) {
- DT *output_batch = output + t * (nfft << 1);
- DT *buffer_batch = buffer + t * (nfft << 1);
-
- computeLargeButterflyLaststage(
- output_batch, buffer_batch, (DT *)twiddles, _twiddles,
- sram_dftmtx, section_num, butterfly_num, in_stride,
- (void *)nram_buf, small_factors, nfft, direction);
- }
- } else {
- computeLargeButterflyLaststageBatchPingpong(
- output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx,
- section_num, butterfly_num, in_stride, (void *)nram_buf,
- small_factors, nfft, t_start, t_end, direction);
- }
+ computeLargeButterflyLaststageBatchPingpongC2R(
+ output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num,
+ butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft,
+ t_start, t_end, direction);
}
}
}
diff --git a/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h b/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h
index 42568f848..4939a65fe 100644
--- a/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h
+++ b/kernels/fft/fft_optm_device/fft_c2r_stockham_nram.h
@@ -24,465 +24,6 @@
#include "kernels/fft/fft_optm_device/fft_generic_butterfly.h"
#include "kernels/fft/fft_optm_device/fft_vector_butterfly.h"
-template
-__mlu_func__ void computeLargeButterflyFirststageC2R(
- DT *output, DT *input, int large_in_stride, int section_num,
- const DT *twiddles, const DT *dft_matrix, void *nram_buf,
- const int *small_factors, int dir, int nfft, int last_stage) {
- const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix;
-
- // network info
- int radix, small_in_stride, small_stage_count, large_radix,
- _small_stage_count;
- int small_section_num, small_butterfly_num, value_mul;
- int tw_offset;
-
- const int K_num = 64 / sizeof(DT);
- int align_K = 0;
-
- _small_stage_count = small_factors[0];
- large_radix = small_factors[1];
- tw_offset = small_factors[2];
-
- // load compute store
- // (0) load 0 ping sync()
- // (1) compute 0 ping load 1 pong sync()
- // (2) store 0 compute 1 pong load 2 ping sync()
- // (3) store 1 compute 2 load 3 sync()
-
- // compute last-large-stage (nram_out_r,nram_out_i) [2, large_radix]->
- // transpose -> [large_radix, 2]
-
- // complex array -> real array, imag array -> complex array
- // first-large-stage complex -> real array, imag array
- // other-large-stage none
- // last-large-stage real array, imag array -> complex
-
- // const int max_para_ldst_num = 1;
- int max_para_ldst_num = (4096 + large_radix - 1) / large_radix;
- max_para_ldst_num =
- (section_num < max_para_ldst_num) ? section_num : max_para_ldst_num;
-
- const DT *small_twiddles = twiddles + tw_offset * 2; // complex
-
- // assign nram space
- int nram_buf_offset = 0;
-
- DT *nram_in_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_in_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_out_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_out_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_para_load_ping = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- DT *nram_para_load_pong = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- DT *nram_para_store_ping = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- DT *nram_para_store_pong = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix]
- // DT *nram_transpose_load = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- // FFT_CPX_T nram_transpose_temp;
- // temp out-space before transpose
- // if last-stage:
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else:
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
-
- DT *_nram_tw = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * 2; // complex
-
- int ld_dft_radix = -1;
- const int max_radix = 64;
- DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += max_radix * max_radix * 2; // complex
-
- // nram space used:
- // sizeof(DT) * 2 * large_radix * (max_para_ldst_num * 6 + 1) + sizeof(DT) * 2
- // * (max_radix * max_radix)
- // + sizeof(DT) * 2 * large_radix * max_para_ldst_num * 4
- DT *nram_scratch = (DT *)nram_buf + nram_buf_offset;
-
- __memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2,
- SRAM2NRAM);
-
- // ceil
- int repeat_num = (section_num + max_para_ldst_num - 1) / max_para_ldst_num;
-
- for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) {
- // pipeline: load-stage
- if (repeat_id < repeat_num) {
- // MLULOG("pipeline: load-stage.\n");
- int i = max_para_ldst_num * repeat_id;
- DT *nram_para_load =
- (repeat_id % 2 == 0) ? nram_para_load_ping : nram_para_load_pong;
-
- // DT *nram_dftmtx =
- // (repeat_id % 2 == 0) ? nram_dftmtx_ping : nram_dftmtx_pong;
- int para_load_num = (max_para_ldst_num > (section_num - i))
- ? (section_num - i)
- : max_para_ldst_num;
- if (section_num == 1) {
- __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix,
- GDRAM2NRAM);
- } else {
- // gather load
- // 2d memcpy
- // 0 1 2 3 4 ... 1023
- // GDRAM -> NRAM
- // 8bytes radix-1024
- // 64bytes
-
- __memcpy_async(nram_para_load, input + i * 2,
- sizeof(DT) * 2 * para_load_num, GDRAM2NRAM,
- sizeof(DT) * 2 * para_load_num,
- large_in_stride * sizeof(DT) * 2, large_radix - 1);
- }
- }
-
- // pipeline: store-stage
- if (repeat_id >= 2) {
- // MLULOG("pipeline: store-stage.\n");
- int i = max_para_ldst_num * (repeat_id - 2);
-
- int para_store_num = (max_para_ldst_num > (section_num - i))
- ? (section_num - i)
- : max_para_ldst_num;
-
- DT *nram_para_store =
- (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong;
-
- if (last_stage) {
- if (section_num == 1) {
- __memcpy_async(output, nram_para_store, sizeof(DT) * 2 * large_radix,
- NRAM2GDRAM);
- } else {
- // scatter-store
- __memcpy_async(output + i * large_radix * 2, nram_para_store,
- sizeof(DT) * 2 * para_store_num * large_radix,
- NRAM2GDRAM);
- }
- } else {
- // real
- __memcpy_async(output + i * large_radix, nram_para_store,
- para_store_num * large_radix * sizeof(DT), NRAM2GDRAM);
- // imag
- __memcpy_async(output + i * large_radix + nfft,
- nram_para_store + max_para_ldst_num * large_radix,
- para_store_num * large_radix * sizeof(DT), NRAM2GDRAM);
- }
- }
-
- // pipeline: compute-stage
-
- if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
- int i = max_para_ldst_num * (repeat_id - 1);
-
- DT *nram_para_load =
- (repeat_id % 2 != 0) ? nram_para_load_ping : nram_para_load_pong;
- DT *nram_para_store =
- (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong;
-
- int para_ldst_num = (max_para_ldst_num > (section_num - i))
- ? (section_num - i)
- : max_para_ldst_num;
- // // [large_radix, para_ldst_num, 2] -> [para_ldst_num, 2, large_radix]
- // __bang_transpose(nram_transpose_load, nram_para_load, large_radix,
- // 2 * para_ldst_num);
-
- // [large_radix, para_ldst_num, 2] -> [2, para_ldst_num, large_radix]
- // overlap nram_out_r
- // DT *nram_transpose_load = nram_out_r;
- // __bang_transpose(nram_transpose_load, nram_para_load,
- // large_radix * para_ldst_num, 2);
- // // [large_radix, para_ldst_num] -> [para_ldst_num, large_radix]
- // __bang_transpose(nram_in_r, nram_transpose_load, large_radix,
- // para_ldst_num);
- // __bang_transpose(nram_in_i,
- // nram_transpose_load + large_radix * para_ldst_num,
- // large_radix, para_ldst_num);
-
- // DT *nram_transpose_load = nram_in_r;
- __bang_transpose(nram_in_r, nram_para_load, large_radix * para_ldst_num,
- 2);
- // [large_radix, para_ldst_num] -> [para_ldst_num, large_radix]
- // __bang_transpose(nram_in_r, nram_transpose_load, large_radix,
- // para_ldst_num);
- // __bang_transpose(nram_in_i,
- // nram_transpose_load + large_radix * para_ldst_num,
- // large_radix, para_ldst_num);
-
- for (int compute_id = 0; compute_id < para_ldst_num;
- compute_id += para_ldst_num) {
- // load real & imag
-
- radix = small_factors[4];
- small_section_num = small_factors[5];
- small_in_stride = small_factors[7];
- small_stage_count = _small_stage_count;
-
- // __memcpy(nram_in_r,
- // nram_transpose_load + compute_id * large_radix * 2,
- // large_radix * sizeof(DT) * 2, NRAM2NRAM);
-
- // first stage
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- default:
- MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
- computeGenericButterflyFirststageMat(
- nram_out_r, nram_out_i, nram_in_r,
- nram_in_r + large_radix * para_ldst_num, nram_scratch,
- nram_dftmtx, small_section_num * para_ldst_num,
- small_section_num * para_ldst_num, 1, dir, radix);
- break;
- }
-
- // [radix, small_section_num, para_ldst_num] ->
- // [small_section_num, para_ldst_num, radix] -> [para_ldst_num,
- // small_section_num, radix]
-
- small_stage_count--;
- if (small_stage_count == 0) {
- // nram to gdram
-
- if (last_stage) {
- // [2, para_ldst_num, large_radix] -> [para_ldst_num, large_radix,
- // 2]
- // DT* nram_transpose_store = nram_in_r;
-
- __bang_transpose(nram_para_store, nram_out_r, 2,
- max_para_ldst_num * large_radix);
- } else {
- // [2, para_ldst_num, large_radix] -> [2, para_ldst_num,
- // large_radix]
- // TODO(zrg): redundant move
- __memcpy(nram_para_store, nram_out_r,
- para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM);
- __memcpy(nram_para_store + max_para_ldst_num * large_radix,
- nram_out_i, para_ldst_num * large_radix * sizeof(DT),
- NRAM2NRAM);
- }
-
- continue;
- }
-
- // [small_section_num, para_ldst_num, radix] -> [para_ldst_num,
- // small_section_num, radix]
-
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- small_section_num, para_ldst_num, radix, DT)
-
- value_mul = 8;
- // DT *sram_tw = (DT *)sram_buffer;
- DT *nram_tw = _nram_tw;
-
- for (; small_stage_count > 1; small_stage_count--) {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- // value_mul = (_small_stage_count - small_stage_count + 1) << 2;
-
- // // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul++];
- // value_mul += 4;
- // copy GDRAM2SRAM
-
- // if (compute_id == 0 && repeat_id == 1 && 0) {
- // __memcpy(nram_tw, small_twiddles,
- // small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- // GDRAM2NRAM);
- // small_twiddles += small_butterfly_num * (radix - 1) * 2;
- // }
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- case 2:
- // computeRadix2ButterflyOtherstages(Fout, Fin, section_num,
- // section_num, 1, dir);
- break;
- case 3:
- computeRadix3ButterflyOtherstages(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_tw, small_section_num, small_butterfly_num,
- small_in_stride, dir);
- break;
-
- default:
- // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- MLULOG("computeGenericButterflyOtherstagesMat: %d.\n", radix);
- computeGenericButterflyOtherstagesMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num, small_butterfly_num,
- para_ldst_num, small_in_stride, dir, radix);
- break;
- }
-
- nram_tw += small_butterfly_num * (radix - 1) * 2;
- } // for (stage_count)
-
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
-
- // MLULOG("butterfly id: %d\n", i);
- // last stage
- {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- // copy GDRAM2SRAM
-
- // update parameter
- // value_mul = _small_stage_count << 2;
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul];
-
- // if (compute_id == 0 && repeat_id == 1 && 0) {
- // __memcpy(nram_tw, small_twiddles,
- // small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- // GDRAM2NRAM);
- // }
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- case 2:
- break;
-
- default:
- MLULOG("computeGenericButterflyLaststageMat: %d.\n", radix);
- computeGenericButterflyLaststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num, small_butterfly_num,
- para_ldst_num, small_in_stride, dir, radix);
- MLULOG("computeGenericButterflyLaststageMat: %d End.\n", radix);
- break;
- }
-
- if (last_stage) {
- // [2, para_ldst_num, large_radix] -> [para_ldst_num, large_radix,
- // 2]
- // DT* nram_transpose_store = nram_in_r;
-
- __bang_transpose(nram_para_store, nram_out_r, 2,
- max_para_ldst_num * large_radix);
- } else {
- // [2, para_ldst_num, large_radix] -> [2, para_ldst_num,
- // large_radix]
- // TODO(zrg): redundant move
- __memcpy(nram_para_store, nram_out_r,
- para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM);
- __memcpy(nram_para_store + max_para_ldst_num * large_radix,
- nram_out_i, para_ldst_num * large_radix * sizeof(DT),
- NRAM2NRAM);
- }
-
- // if (last_stage) {
- // // MLULOG("last_stage. \n");
-
- // // __memcpy(nram_transpose_temp + (compute_id * 2) * large_radix,
- // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // // __memcpy(nram_transpose_temp
- // // + (compute_id * 2 + 1) * large_radix,
- // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
-
- // __memcpy(nram_transpose_temp + (compute_id * 2) * large_radix,
- // nram_out_r, large_radix * sizeof(DT) * 2, NRAM2NRAM);
- // __bang_transpose(
- // nram_para_store + (compute_id * 2) * large_radix,
- // nram_transpose_temp + (compute_id * 2) * large_radix, 2,
- // large_radix);
-
- // } else {
- // // MLULOG("not last_stage. \n");
- // __memcpy(nram_para_store + compute_id * large_radix, nram_out_r,
- // large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_para_store +
- // (compute_id + max_para_ldst_num) * large_radix,
- // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
- // }
- // MLULOG("last_stage. \n");
- }
- }
- }
-
- __sync();
- }
-}
-
template
__mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
DT *output, DT *input, int large_in_stride, int section_num,
@@ -506,8 +47,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
large_radix = small_factors[1];
tw_offset = small_factors[2];
printf("tw_offset: %d\n\n", tw_offset);
- const int half_butterfly_num = section_num/2 + 1;
-
+ const int half_butterfly_num = section_num / 2 + 1;
max_para_ldst_num = (half_butterfly_num < small_factors[3])
? half_butterfly_num
@@ -535,7 +75,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
// max_para_ldst_num = ((7232) / large_radix > 0) ? (7232) / large_radix : 1;
const DT *small_twiddles = twiddles + tw_offset * 2; // complex
-
+
// assign nram space
int nram_buf_offset = 0;
@@ -580,18 +120,8 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
(DT *)nram_buf + nram_buf_offset + max_radix * max_radix * 2};
nram_buf_offset += max_radix * max_radix * 4; // complex
- // nram space used:
- // sizeof(DT) * 2 * large_radix * (max_para_ldst_num * 6 + 1) + sizeof(DT) * 2
- // * (max_radix * max_radix)
- // + sizeof(DT) * 2 * large_radix * max_para_ldst_num * 4
DT *nram_scratch = (DT *)nram_buf + nram_buf_offset;
- // DT *nram_temp_r = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += large_radix * max_para_ldst_num;
-
- // DT *nram_temp_i = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += large_radix * max_para_ldst_num;
-
__memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2,
SRAM2NRAM);
@@ -646,15 +176,12 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
}
}
-
// pipeline: store-stage
if (repeat_id >= 2) {
if (last_stage) {
if (half_butterfly_num == 1) {
// store only real part
- printf(" taskId: %d, store only real part, odis: %ld , size: %ld\n",
- taskId, (size_t)output_batch,
- size_t(sizeof(DT) * large_radix));
+
__memcpy_async(output_batch - 2 * odist, nram_para_store_ping,
sizeof(DT) * large_radix, NRAM2GDRAM);
@@ -679,10 +206,9 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
// pipeline: compute-stage
if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
-
- for (int j = 0; j < 32; j++) {
- printf("694 _nram_tw [%d]: (%f).\n", j, _nram_tw[j]);
- }
DT *nram_in_r = nram_para_store_pong;
DT *nram_in_i = nram_para_store_pong + large_radix * max_para_ldst_num;
@@ -704,14 +230,10 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
nram_in_r + upper_radix * para_num + para_num * large_radix, 1,
lower_radix * para_num);
- __memcpy(nram_in_r + upper_radix * para_num + para_num * large_radix,
- nram_out_i, para_num * lower_radix * sizeof(DT), NRAM2NRAM);
-
-
-
+ __bang_mul_scalar(
+ nram_in_r + upper_radix * para_num + para_num * large_radix,
+ nram_out_i, -1, para_num * lower_radix);
} else {
-
-
__bang_transpose(nram_in_r, nram_para_load_pong, large_radix, 2);
__bang_rotate180(nram_out_r, nram_in_r + 1, 1,
@@ -723,10 +245,9 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
__bang_rotate180(nram_out_i, nram_in_r + 1 + large_radix, 1,
(large_radix - large_radix / 2 - 1));
- __bang_mul_scalar(
- nram_in_r + (large_radix / 2 + 1) + large_radix, nram_out_i, -1,
- (large_radix - large_radix / 2 - 1));
-
+ __bang_mul_scalar(nram_in_r + (large_radix / 2 + 1) + large_radix,
+ nram_out_i, -1,
+ (large_radix - large_radix / 2 - 1));
}
for (int compute_id = 0; compute_id < para_num;
@@ -738,7 +259,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
small_in_stride = small_factors[7];
small_stage_count = _small_stage_count;
-
// first stage
if (ld_dft_radix[0] != radix && ld_dft_radix[1] != radix) {
ld_dft_radix[1] = ld_dft_radix[0];
@@ -933,17 +453,13 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
FFT_SWAP_PTR(nram_dftmtx[0], nram_dftmtx[1]);
}
-
switch (radix) {
default:
if (last_stage) {
-
if (nram_in_r == nram_para_store_pong) {
FFT_SWAP_PTR(nram_para_store_pong, nram_para_load_pong);
}
-
-
computeGenericButterflyLaststageMat(
nram_para_store_pong, nram_out_i, nram_in_r, nram_in_i,
nram_scratch, nram_dftmtx[0], nram_tw, small_section_num,
@@ -951,7 +467,6 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
radix);
} else {
-
computeGenericButterflyLaststageMat(
nram_para_store_pong,
nram_para_store_pong + max_para_ldst_num * large_radix,
@@ -983,29 +498,36 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpongC2R(
}
template
-__mlu_func__ void computeLargeButterflyOtherstages(
+__mlu_func__ void computeLargeButterflyOtherstagesBatchPingpongC2R(
DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
const DT *dft_matrix, int large_section_num, int large_butterfly_num,
int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- int dir, int last_stage) {
+ const int t_start, const int t_end, int dir, int last_stage) {
// return;
const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix;
- const int K_num = 64 / sizeof(DT);
- int align_K = 0;
+
int radix, small_in_stride, small_stage_count, large_radix,
_small_stage_count;
int small_section_num, small_butterfly_num, value_mul;
const int large_out_stride = large_butterfly_num;
- int tw_offset;
+ const int half_butterfly_num = large_butterfly_num / 2 + 1;
+ int tw_offset;
+ const int K_num = 64 / sizeof(DT);
+ int align_K = 0;
_small_stage_count = small_factors[0];
large_radix = small_factors[1];
tw_offset = small_factors[2];
-
+ const int upper_radix = (large_radix + 1) / 2;
+ const int lower_radix = large_radix - upper_radix;
const DT *small_twiddles = _twiddles + tw_offset * 2; // complex
- const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix;
+ // const int max_para_ldst_num = (6144 + large_radix - 1) / large_radix;
+ // int max_para_ldst_num = (6400 + large_radix - 1) / large_radix;
+ int max_para_ldst_num = (half_butterfly_num < small_factors[3])
+ ? half_butterfly_num
+ : small_factors[3];
// int para_ldst_num;
// TODO(zrg): save nram space.
@@ -1017,19 +539,7 @@ __mlu_func__ void computeLargeButterflyOtherstages(
// FFT_CPX_T *nram_out = &nram_in[large_radix];
// FFT_CPX_T *nram_buf = &nram_in[large_radix * 2];
int nram_buf_offset = 0;
- DT *nram_in_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_in_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_out_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
-
- DT *nram_out_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * max_para_ldst_num;
- // parallel load/store space
FFT_CPX_T nram_para_load_in_ping = {
(DT *)nram_buf + nram_buf_offset,
(DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
@@ -1040,12 +550,7 @@ __mlu_func__ void computeLargeButterflyOtherstages(
(DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
- FFT_CPX_T nram_para_load_tw_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- FFT_CPX_T nram_para_load_tw_pong = {
+ FFT_CPX_T nram_para_load_tw = {
(DT *)nram_buf + nram_buf_offset,
(DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
@@ -1060,34 +565,11 @@ __mlu_func__ void computeLargeButterflyOtherstages(
(DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
- // overlap nram_in
- FFT_CPX_T nram_transpose_temp;
- // temp out-space before transpose
- // if last-stage:
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else:
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
- nram_transpose_temp = {
- (DT *)nram_in_r,
- (DT *)nram_in_r + large_radix * ((int)last_stage) +
- large_radix * (1 - (int)last_stage) * max_para_ldst_num};
// nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
DT *_nram_tw = (DT *)nram_buf + nram_buf_offset;
nram_buf_offset += large_radix * 2; // complex
- // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix]
- // FFT_CPX_T nram_transpose_load = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
int ld_dft_radix = -1;
const int max_radix = 64;
DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
@@ -1103,348 +585,151 @@ __mlu_func__ void computeLargeButterflyOtherstages(
nram_buf_offset += large_radix * max_para_ldst_num * 4; // complex
- // size: (large_radix - 1) * max_para_ldst_num
- // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num];
- // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num];
-
- int Fin_stride = 0, Fout_stride = 0;
- int sec_count;
- int repeat_num =
- (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num;
- for (sec_count = 0; sec_count < large_section_num; ++sec_count) {
- for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) {
- // small_twiddles = _small_twiddles;
-
- // pipeline: load-stage
-
- if (repeat_id < repeat_num) {
- // MLULOG("pipeline: load-stage.\n");
- int i = max_para_ldst_num * repeat_id;
- FFT_CPX_T nram_para_load_in = (repeat_id % 2 == 0)
- ? nram_para_load_in_ping
- : nram_para_load_in_pong;
-
- FFT_CPX_T nram_para_load_tw = (repeat_id % 2 == 0)
- ? nram_para_load_tw_ping
- : nram_para_load_tw_pong;
-
- int para_load_num = (max_para_ldst_num > (large_butterfly_num - i))
- ? (large_butterfly_num - i)
- : max_para_ldst_num;
- if (para_load_num != 1) {
- __memcpy_async(nram_para_load_in.r, input + Fin_stride + i,
- sizeof(DT) * para_load_num, GDRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i,
- sizeof(DT) * para_load_num, GDRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
- sizeof(DT) * para_load_num, SRAM2NRAM,
- sizeof(DT) * para_load_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- __memcpy_async(
- nram_para_load_tw.i,
- cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i,
- sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- }
- }
+ // overlap nram_in
+ FFT_CPX_T nram_transpose_temp;
+ // temp out-space before transpose
+ // if last-stage:
+ // compute_id 0 r
+ // compute_id 0 i
+ // compute_id 1 r
+ // compute_id 1 i
+ // else:
+ // compute_id 0 r
+ // compute_id 1 i
+ // compute_id 0 r
+ // compute_id 1 i
+ nram_transpose_temp = {
+ (DT *)nram_buf + nram_buf_offset,
+ (DT *)nram_buf + nram_buf_offset + large_radix * ((int)last_stage) +
+ large_radix * (1 - (int)last_stage) * max_para_ldst_num};
- // pipeline: store-stage
- if (repeat_id >= 2) {
- // MLULOG("pipeline: store-stage.\n");
- int i = max_para_ldst_num * (repeat_id - 2);
+ // size: (large_radix - 1) * max_para_ldst_num
+ // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num];
+ // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num];
- int para_store_num = (max_para_ldst_num > (large_butterfly_num - i))
- ? (large_butterfly_num - i)
- : max_para_ldst_num;
+ // int Fin_stride = 0, Fout_stride = 0;
+ int sec_count;
+ // int repeat_num =
- FFT_CPX_T nram_para_store =
- (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong;
+ int repeat_num = (t_end - t_start);
+ const int odist = last_stage ? nfft : nfft << 1;
+ const int idist = (nfft / 2 + 1) << 1;
+ input += t_start * idist;
+ output += t_start * odist;
- if (last_stage) {
- // __memcpy_async(
- // output + (Fout_stride + i * large_radix) * 2,
- // nram_para_store.r,
- // sizeof(DT) * 2 * para_store_num * large_radix, NRAM2GDRAM);
-
- __memcpy_async(output + (Fout_stride + i) * 2, nram_para_store.r,
- sizeof(DT) * 2 * para_store_num, NRAM2GDRAM,
- large_out_stride * 2 * sizeof(DT),
- sizeof(DT) * 2 * para_store_num, large_radix - 1);
- } else {
- // // real
- // __memcpy_async(output + Fout_stride + i * large_radix,
- // nram_para_store.r,
- // para_store_num * large_radix * sizeof(DT),
- // NRAM2GDRAM);
- // // imag
- // __memcpy_async(output + Fout_stride + i * large_radix + nfft,
- // nram_para_store.i,
- // para_store_num * large_radix * sizeof(DT),
- // NRAM2GDRAM);
- // real
- __memcpy_async(output + Fout_stride + i, nram_para_store.r,
- para_store_num * sizeof(DT), NRAM2GDRAM,
- large_out_stride * sizeof(DT),
- sizeof(DT) * para_store_num, large_radix - 1);
- // imag
- __memcpy_async(output + Fout_stride + i + nfft, nram_para_store.i,
- para_store_num * sizeof(DT), NRAM2GDRAM,
- large_out_stride * sizeof(DT),
- sizeof(DT) * para_store_num, large_radix - 1);
- }
- }
- // __sync();
- // pipeline: compute-stage
+ for (int butterfly_id = 0; butterfly_id < half_butterfly_num;
+ butterfly_id += max_para_ldst_num) {
+ for (sec_count = 0; sec_count < large_section_num; ++sec_count) {
+ DT *output_batch = output;
+ DT *input_batch = input;
+ int para_num = (max_para_ldst_num > (half_butterfly_num - butterfly_id))
+ ? (half_butterfly_num - butterfly_id)
+ : max_para_ldst_num;
- if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
- // MLULOG("pipeline: compute-stage.\n");
- int i = max_para_ldst_num * (repeat_id - 1);
-
- FFT_CPX_T nram_para_load_in = (repeat_id % 2 != 0)
- ? nram_para_load_in_ping
- : nram_para_load_in_pong;
-
- FFT_CPX_T nram_para_load_tw = (repeat_id % 2 != 0)
- ? nram_para_load_tw_ping
- : nram_para_load_tw_pong;
-
- FFT_CPX_T nram_para_store =
- (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong;
-
- int para_ldst_num = (max_para_ldst_num > (large_butterfly_num - i))
- ? (large_butterfly_num - i)
- : max_para_ldst_num;
-
- // __bang_transpose(nram_transpose_load, nram_para_load, large_radix,
- // 2 * para_ldst_num);
-
- // rotation-large
- __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_ldst_num,
- nram_para_load_tw.r, para_ldst_num * (large_radix - 1));
- __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_ldst_num,
- nram_para_load_tw.i, para_ldst_num * (large_radix - 1));
- __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_ldst_num,
- nram_para_load_tw.i, para_ldst_num * (large_radix - 1));
- __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_ldst_num,
- nram_para_load_tw.r, para_ldst_num * (large_radix - 1));
-
- __bang_sub(nram_para_load_in.r + para_ldst_num, CPX_MUL_RR, CPX_MUL_II,
- para_ldst_num * (large_radix - 1));
- __bang_add(nram_para_load_in.i + para_ldst_num, CPX_MUL_RI, CPX_MUL_IR,
- para_ldst_num * (large_radix - 1));
-
- // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r,
- // large_radix, para_ldst_num);
- // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i,
- // large_radix, para_ldst_num);
-
- // for (int compute_id = 0; compute_id < para_ldst_num; compute_id++) {
- for (int compute_id = 0; compute_id < para_ldst_num;
- compute_id += para_ldst_num) {
- // load real & imag
+ for (int repeat_id = 0; repeat_id < repeat_num + 2;
+ ++repeat_id, input_batch += idist, output_batch += odist) {
+ // small_twiddles = _small_twiddles;
- radix = small_factors[4];
- small_section_num = small_factors[5];
- small_in_stride = small_factors[7];
- small_stage_count = _small_stage_count;
+ // pipeline: load-stage
- // __memcpy(nram_in_r,
- // nram_transpose_load + compute_id * large_radix * 2,
- // large_radix * sizeof(DT) * 2, NRAM2NRAM);
+ if (repeat_id < repeat_num) {
+ if (1) {
+ __memcpy_async(
+ nram_para_load_in_ping.r,
+ input_batch + sec_count * half_butterfly_num + butterfly_id,
+ sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num,
+ large_in_stride * sizeof(DT), upper_radix - 1);
+ __memcpy_async(nram_para_load_in_ping.i,
+ input_batch + nfft + sec_count * half_butterfly_num +
+ butterfly_id,
+ sizeof(DT) * para_num, GDRAM2NRAM,
+ sizeof(DT) * para_num, large_in_stride * sizeof(DT),
+ upper_radix - 1);
- // first stage
- // if(0)
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
+ __memcpy_async(
+ nram_para_load_in_ping.r + upper_radix * para_num,
+ input_batch + sec_count * half_butterfly_num + butterfly_id +
+ (half_butterfly_num - butterfly_id - para_num + 1),
+ sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num,
+ large_in_stride * sizeof(DT), lower_radix - 1);
+ __memcpy_async(
+ nram_para_load_in_ping.i + upper_radix * para_num,
+ input_batch + nfft + sec_count * half_butterfly_num +
+ butterfly_id +
+ (half_butterfly_num - butterfly_id - para_num + 1),
+ sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num,
+ large_in_stride * sizeof(DT), lower_radix - 1);
- if (dft_table[entry].radix == -1) {
- break;
- }
+ if (repeat_id == 0 && sec_count == 0) {
+ __memcpy_async(
+ nram_para_load_tw.r, cur_large_twiddles + butterfly_id,
+ sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num,
+ large_out_stride * sizeof(DT), large_radix - 2);
+ __memcpy_async(
+ nram_para_load_tw.i,
+ cur_large_twiddles + half_butterfly_num * (large_radix - 1) +
+ butterfly_id,
+ sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num,
+ large_out_stride * sizeof(DT), large_radix - 2);
}
}
+ }
- switch (radix) {
- default:
- // computeGenericButterflyFirststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
-
- // para_ldst_num = 1
- // in: [radix, butterfly_num]
- // butterfly: [radix, radix] * [radix, butterfly_num]
- // out_butterfly: [radix, butterfly_num]
- // out: [butterfly_num, radix]
-
- // para_ldst_num != 1
- // in: [radix, butterfly_num, para_ldst_num] == [large_radix,
- // para_ldst_num] butterfly: [radix, radix] * [radix,
- // butterfly_num, para_ldst_num] out_butterfly: [radix,
- // butterfly_num, para_ldst_num] == [radix, butterfly_num *
- // para_ldst_num] out: [butterfly_num, para_ldst_num, radix]
-
- computeGenericButterflyFirststageMat(
- nram_out_r, nram_out_i, nram_para_load_in.r,
- nram_para_load_in.i, nram_scratch, nram_dftmtx,
- small_section_num * para_ldst_num,
- small_section_num * para_ldst_num, 1, dir, radix);
- break;
+ // pipeline: store-stage
+ if (repeat_id >= 2) {
+ if (last_stage) {
+ __memcpy_async(output_batch - odist * 2 +
+ (sec_count * large_radix * half_butterfly_num +
+ butterfly_id) *
+ 2,
+ nram_para_store_ping.r, sizeof(DT) * 2 * para_num,
+ NRAM2GDRAM, large_out_stride * 2 * sizeof(DT),
+ sizeof(DT) * 2 * para_num, large_radix - 1);
+ } else {
+ // real
+ __memcpy_async(output_batch - odist * 2 +
+ sec_count * large_radix * half_butterfly_num +
+ butterfly_id,
+ nram_para_store_ping.r, para_num * sizeof(DT),
+ NRAM2GDRAM, large_out_stride * sizeof(DT),
+ sizeof(DT) * para_num, large_radix - 1);
+ // imag
+ __memcpy_async(output_batch - odist * 2 +
+ sec_count * large_radix * half_butterfly_num +
+ butterfly_id + nfft,
+ nram_para_store_ping.i, para_num * sizeof(DT),
+ NRAM2GDRAM, large_out_stride * sizeof(DT),
+ sizeof(DT) * para_num, large_radix - 1);
}
+ }
+ // __sync();
+ // pipeline: compute-stage
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
-
- // [radix, small_section_num, para_ldst_num] ->
- // [small_section_num, para_ldst_num, radix] ->
- // [para_ldst_num, small_section_num, radix] ->
- // [small_section_num, radix, para_ldst_num] ==
- // [large_radix, para_ldst_num]
-
- small_stage_count--;
- if (small_stage_count == 0) {
- // nram to gdram
-
- // if (last_stage) {
- // // [2, para_ldst_num, large_radix] -> [para_ldst_num,
- // large_radix,
- // // 2]
- // // DT* nram_transpose_store = nram_in_r;
-
- // __bang_transpose(nram_in_r, nram_out_r, 2,
- // max_para_ldst_num * large_radix);
-
- // } else {
- // // [2, para_ldst_num, large_radix] -> [2, para_ldst_num,
- // // large_radix]
- // // TODO(zrg): redundant move
- // __memcpy(nram_in_r, nram_out_r,
- // para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_in_i,
- // nram_out_i, para_ldst_num * large_radix * sizeof(DT),
- // NRAM2NRAM);
- // }
-
- // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2]
- if (last_stage) {
- __memcpy(nram_transpose_temp.r +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_ldst_num - 1);
-
- __memcpy(nram_transpose_temp.i +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_ldst_num - 1);
-
- __bang_transpose(nram_para_store.r, nram_transpose_temp.r,
- para_ldst_num * 2, large_radix);
- } else {
- __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num,
- large_radix);
- __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num,
- large_radix);
- }
-
- continue;
- }
+ if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
+ // MLULOG("pipeline: compute-stage.\n");
+ // int i = max_para_ldst_num * (repeat_id - 1);
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // DT* nram_transpose_store = nram_in_r;
-
- // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num;
- // para_ldst_id++) {
- // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix,
- // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num -
- // 1);
-
- // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix,
- // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num -
- // 1);
- // }
+ DT *nram_in_r = nram_para_load_in_pong.r;
+ DT *nram_in_i = nram_para_load_in_pong.i;
- // after first stage: [butterfly_num, para_ldst_num, radix]
- // other in: [para_ldst_num, butterfly_num, radix] == [para_ldst_num,
- // large_radix]
- TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- small_section_num, para_ldst_num, radix, DT)
-
- // TODO(zrg) : add not last-stage
- // if (small_stage_count == 0) {
- // // if last-stage: stride = large_radix * 2
- // // compute_id 0 r
- // // compute_id 0 i
- // // compute_id 1 r
- // // compute_id 1 i
- // // else: stride = large_radix
- // // compute_id 0 r
- // // compute_id 1 i
- // // compute_id 0 r
- // // compute_id 1 i
-
- // // [radix, small_section_num, para_ldst_num] ->
- // // [small_section_num, para_ldst_num, radix] ->
- // // [para_ldst_num, small_section_num, radix] ->
- // // [small_section_num, radix, para_ldst_num] ==
- // // [large_radix, para_ldst_num]
-
- // __memcpy(nram_transpose_temp.r +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- // para_ldst_num - 1);
-
- // __memcpy(nram_transpose_temp.i +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- // para_ldst_num - 1);
-
- // // __memcpy(nram_transpose_temp.r +
- // // compute_id * large_radix * (1 + (int)last_stage),
- // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // // __memcpy(nram_transpose_temp.i +
- // // compute_id * large_radix * (1 + (int)last_stage),
- // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
-
- // // __bang_transpose(nram_transpose_temp.r, nram_transpose_temp.r,
- // // max_para_ldst_num * 2, large_radix);
- // continue;
- // }
+ DT *nram_out_r = nram_para_store_pong.r;
+ DT *nram_out_i = nram_para_store_pong.i;
- // DT *sram_tw = (DT *)sram_buffer;
- DT *nram_tw = _nram_tw;
- value_mul = 8;
+ for (int compute_id = 0; compute_id < para_num;
+ compute_id += para_num) {
+ // load real & imag
- for (; small_stage_count > 1; small_stage_count--) {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
+ radix = small_factors[4];
+ small_section_num = small_factors[5];
+ small_in_stride = small_factors[7];
+ small_stage_count = _small_stage_count;
- // value_mul = (_small_stage_count - small_stage_count + 1) * 4;
- // // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul++];
- // copy GDRAM2SRAM
+ // __memcpy(nram_in_r,
+ // nram_transpose_load + compute_id * large_radix * 2,
+ // large_radix * sizeof(DT) * 2, NRAM2NRAM);
+ // first stage
+ // if(0)
if (ld_dft_radix != radix) {
ld_dft_radix = radix;
for (int entry = 0;; entry++) {
@@ -1462,2040 +747,329 @@ __mlu_func__ void computeLargeButterflyOtherstages(
}
}
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- small_twiddles += small_butterfly_num * (radix - 1) * 2;
- }
-
switch (radix) {
- // case 2:
- // // computeRadix2ButterflyOtherstages(Fout, Fin, section_num,
- // // section_num, 1, dir);
- // break;
- // case 3:
- // computeRadix3ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
default:
- // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
+ // computeGenericButterflyFirststage(Fout, buffer, twiddles,
// radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyOtherstagesMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_ldst_num, small_in_stride, dir,
- radix);
- break;
- }
-
- nram_tw += small_butterfly_num * (radix - 1) * 2;
- } // for (stage_count)
-
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
-
- // MLULOG("butterfly id: %d\n", i);
- // last stage
- {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // copy GDRAM2SRAM
-
- // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul];
+ MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- }
+ // para_num = 1
+ // in: [radix, butterfly_num]
+ // butterfly: [radix, radix] * [radix, butterfly_num]
+ // out_butterfly: [radix, butterfly_num]
+ // out: [butterfly_num, radix]
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx,
- &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
+ // para_num != 1
+ // in: [radix, butterfly_num, para_num] == [large_radix,
+ // para_num] butterfly: [radix, radix] * [radix,
+ // butterfly_num, para_num] out_butterfly: [radix,
+ // butterfly_num, para_num] == [radix, butterfly_num *
+ // para_num] out: [butterfly_num, para_num, radix]
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
- switch (radix) {
- // case 2:
- // break;
- // case 3:
- // computeRadix3ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- default:
- // computeGenericButterflyLaststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyLaststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_ldst_num, small_in_stride, dir,
- radix);
+ computeGenericButterflyFirststageMat(
+ nram_out_r, nram_out_i, nram_para_load_in_pong.r,
+ nram_para_load_in_pong.i, nram_scratch, nram_dftmtx,
+ small_section_num * para_num, small_section_num * para_num,
+ 1, dir, radix);
break;
}
- // if last-stage: stride = large_radix * 2
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else: stride = large_radix
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
- // __memcpy(nram_transpose_temp.r +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_transpose_temp.i +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
+ // for (int j = 0; j < large_radix; j++) {
+ // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
+ // }
- if (last_stage) {
- __memcpy(nram_transpose_temp.r +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_ldst_num - 1);
-
- __memcpy(nram_transpose_temp.i +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_ldst_num - 1);
-
- __bang_transpose(nram_para_store.r, nram_transpose_temp.r,
- para_ldst_num * 2, large_radix);
- } else {
- __bang_transpose(nram_para_store.r, nram_out_r, para_ldst_num,
- large_radix);
- __bang_transpose(nram_para_store.i, nram_out_i, para_ldst_num,
- large_radix);
- }
+ // [radix, small_section_num, para_num] ->
+ // [small_section_num, para_num, radix] ->
+ // [para_num, small_section_num, radix] ->
+ // [small_section_num, radix, para_num] ==
+ // [large_radix, para_num]
- // __bang_transpose(nram_para_store, nram_transpose_temp.r,
- // max_para_ldst_num * 2, large_radix);
- }
- }
- }
- // __sync();
+ small_stage_count--;
+ if (small_stage_count == 0) {
+ // nram to gdram
- __sync();
- }
- Fin_stride += large_butterfly_num;
- Fout_stride += large_radix * large_butterfly_num;
- }
-}
-
-template
-__mlu_func__ void computeLargeButterflyLaststage(
- DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
- const DT *dft_matrix, int large_section_num, int large_butterfly_num,
- int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- int dir) {
- computeLargeButterflyOtherstages(output, input, cur_large_twiddles, _twiddles,
- dft_matrix, large_section_num,
- large_butterfly_num, large_in_stride,
- nram_buf, small_factors, nfft, dir, 1);
-}
-
-template
-__mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong(
- DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
- const DT *dft_matrix, int large_section_num, int large_butterfly_num,
- int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- const int t_start, const int t_end, int dir, int last_stage) {
- // return;
- const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix;
-
- int radix, small_in_stride, small_stage_count, large_radix,
- _small_stage_count;
- int small_section_num, small_butterfly_num, value_mul;
-
- const int large_out_stride = large_butterfly_num;
- int tw_offset;
- const int K_num = 64 / sizeof(DT);
- int align_K = 0;
- _small_stage_count = small_factors[0];
- large_radix = small_factors[1];
- tw_offset = small_factors[2];
-
- const DT *small_twiddles = _twiddles + tw_offset * 2; // complex
-
- // const int max_para_ldst_num = (6144 + large_radix - 1) / large_radix;
- // int max_para_ldst_num = (6400 + large_radix - 1) / large_radix;
- int max_para_ldst_num = (large_butterfly_num < small_factors[3])
- ? large_butterfly_num
- : small_factors[3];
-
- // int para_ldst_num;
- // TODO(zrg): save nram space.
- // __nram__ DT nram_space[MAX_BUTTERFLY_ON_CHIP * 2];
- // 0 1 2 3 4 5
- // 0 1 3
- // DT *nram_buf_end = (DT*)&((uint8_t*)nram_buf)[NRAM_BUFFER_SIZE];
- // FFT_CPX_T *nram_in = (FFT_CPX_T *)nram_buffer;
- // FFT_CPX_T *nram_out = &nram_in[large_radix];
- // FFT_CPX_T *nram_buf = &nram_in[large_radix * 2];
- int nram_buf_offset = 0;
-
- FFT_CPX_T nram_para_load_in_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- FFT_CPX_T nram_para_load_in_pong = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- FFT_CPX_T nram_para_load_tw = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- // FFT_CPX_T nram_para_load_tw_ping = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- // FFT_CPX_T nram_para_load_tw_pong = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- FFT_CPX_T nram_para_store_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- FFT_CPX_T nram_para_store_pong = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- DT *_nram_tw = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * 2; // complex
-
- // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix]
- // FFT_CPX_T nram_transpose_load = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * max_para_ldst_num};
- // nram_buf_offset += large_radix * max_para_ldst_num * 2; // complex
-
- int ld_dft_radix = -1;
- const int max_radix = 64;
- DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += max_radix * max_radix * 2; // complex
-
- DT *nram_scratch = (DT *)nram_buf + nram_buf_offset;
-
- // temp overlap with "nram_scratch"
- DT *CPX_MUL_RR = nram_scratch;
- DT *CPX_MUL_RI = &CPX_MUL_RR[large_radix * max_para_ldst_num];
- DT *CPX_MUL_IR = &CPX_MUL_RI[large_radix * max_para_ldst_num];
- DT *CPX_MUL_II = &CPX_MUL_IR[large_radix * max_para_ldst_num];
-
- nram_buf_offset += large_radix * max_para_ldst_num * 4; // complex
-
- // overlap nram_in
- FFT_CPX_T nram_transpose_temp;
- // temp out-space before transpose
- // if last-stage:
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else:
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
- nram_transpose_temp = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * ((int)last_stage) +
- large_radix * (1 - (int)last_stage) * max_para_ldst_num};
-
- // size: (large_radix - 1) * max_para_ldst_num
- // DT *scratch_tw_r = &CPX_MUL_II[large_radix * max_para_ldst_num];
- // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * max_para_ldst_num];
-
- // int Fin_stride = 0, Fout_stride = 0;
- int sec_count;
- // int repeat_num =
- // (large_butterfly_num + max_para_ldst_num - 1) / max_para_ldst_num;
- int repeat_num = (t_end - t_start);
- input += t_start * (nfft << 1);
- output += t_start * (nfft << 1);
-
- for (int butterfly_id = 0; butterfly_id < large_butterfly_num;
- butterfly_id += max_para_ldst_num) {
- for (sec_count = 0; sec_count < large_section_num; ++sec_count) {
- DT *output_batch = output;
- DT *input_batch = input;
- int para_num = (max_para_ldst_num > (large_butterfly_num - butterfly_id))
- ? (large_butterfly_num - butterfly_id)
- : max_para_ldst_num;
-
- for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id,
- input_batch += (nfft << 1), output_batch += (nfft << 1)) {
- // small_twiddles = _small_twiddles;
-
- // pipeline: load-stage
-
- if (repeat_id < repeat_num) {
- if (para_num != 1) {
- __memcpy_async(
- nram_para_load_in_ping.r,
- input_batch + sec_count * large_butterfly_num + butterfly_id,
- sizeof(DT) * para_num, GDRAM2NRAM, sizeof(DT) * para_num,
- large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(nram_para_load_in_ping.i,
- input_batch + nfft +
- sec_count * large_butterfly_num + butterfly_id,
- sizeof(DT) * para_num, GDRAM2NRAM,
- sizeof(DT) * para_num, large_in_stride * sizeof(DT),
- large_radix - 1);
- if (repeat_id == 0 && sec_count == 0) {
- __memcpy_async(
- nram_para_load_tw.r, cur_large_twiddles + butterfly_id,
- sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- __memcpy_async(
- nram_para_load_tw.i,
- cur_large_twiddles + large_butterfly_num * (large_radix - 1) +
- butterfly_id,
- sizeof(DT) * para_num, SRAM2NRAM, sizeof(DT) * para_num,
- large_out_stride * sizeof(DT), large_radix - 2);
- }
- }
- }
-
- // pipeline: store-stage
- if (repeat_id >= 2) {
- if (last_stage) {
- __memcpy_async(output_batch - (nfft << 2) +
- (sec_count * large_radix * large_butterfly_num +
- butterfly_id) *
- 2,
- nram_para_store_ping.r, sizeof(DT) * 2 * para_num,
- NRAM2GDRAM, large_out_stride * 2 * sizeof(DT),
- sizeof(DT) * 2 * para_num, large_radix - 1);
- } else {
- // real
- __memcpy_async(output_batch - (nfft << 2) +
- sec_count * large_radix * large_butterfly_num +
- butterfly_id,
- nram_para_store_ping.r, para_num * sizeof(DT),
- NRAM2GDRAM, large_out_stride * sizeof(DT),
- sizeof(DT) * para_num, large_radix - 1);
- // imag
- __memcpy_async(output_batch - (nfft << 2) +
- sec_count * large_radix * large_butterfly_num +
- butterfly_id + nfft,
- nram_para_store_ping.i, para_num * sizeof(DT),
- NRAM2GDRAM, large_out_stride * sizeof(DT),
- sizeof(DT) * para_num, large_radix - 1);
- }
- }
- // __sync();
- // pipeline: compute-stage
-
- if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
- // MLULOG("pipeline: compute-stage.\n");
- // int i = max_para_ldst_num * (repeat_id - 1);
-
- DT *nram_in_r = nram_para_load_in_pong.r;
- DT *nram_in_i = nram_para_load_in_pong.i;
-
- DT *nram_out_r = nram_para_store_pong.r;
- DT *nram_out_i = nram_para_store_pong.i;
-
- // rotation-large
- __bang_mul(CPX_MUL_RR, nram_para_load_in_pong.r + para_num,
- nram_para_load_tw.r, para_num * (large_radix - 1));
- __bang_mul(CPX_MUL_II, nram_para_load_in_pong.i + para_num,
- nram_para_load_tw.i, para_num * (large_radix - 1));
- __bang_mul(CPX_MUL_RI, nram_para_load_in_pong.r + para_num,
- nram_para_load_tw.i, para_num * (large_radix - 1));
- __bang_mul(CPX_MUL_IR, nram_para_load_in_pong.i + para_num,
- nram_para_load_tw.r, para_num * (large_radix - 1));
-
- __bang_sub(nram_para_load_in_pong.r + para_num, CPX_MUL_RR,
- CPX_MUL_II, para_num * (large_radix - 1));
- __bang_add(nram_para_load_in_pong.i + para_num, CPX_MUL_RI,
- CPX_MUL_IR, para_num * (large_radix - 1));
-
- // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r,
- // large_radix, para_num);
- // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i,
- // large_radix, para_num);
-
- // for (int compute_id = 0; compute_id < para_num; compute_id++)
- // {
- for (int compute_id = 0; compute_id < para_num;
- compute_id += para_num) {
- // load real & imag
-
- radix = small_factors[4];
- small_section_num = small_factors[5];
- small_in_stride = small_factors[7];
- small_stage_count = _small_stage_count;
-
- // __memcpy(nram_in_r,
- // nram_transpose_load + compute_id * large_radix * 2,
- // large_radix * sizeof(DT) * 2, NRAM2NRAM);
-
- // first stage
- // if(0)
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx,
- &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- default:
- // computeGenericButterflyFirststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
-
- // para_num = 1
- // in: [radix, butterfly_num]
- // butterfly: [radix, radix] * [radix, butterfly_num]
- // out_butterfly: [radix, butterfly_num]
- // out: [butterfly_num, radix]
-
- // para_num != 1
- // in: [radix, butterfly_num, para_num] == [large_radix,
- // para_num] butterfly: [radix, radix] * [radix,
- // butterfly_num, para_num] out_butterfly: [radix,
- // butterfly_num, para_num] == [radix, butterfly_num *
- // para_num] out: [butterfly_num, para_num, radix]
-
- computeGenericButterflyFirststageMat(
- nram_out_r, nram_out_i, nram_para_load_in_pong.r,
- nram_para_load_in_pong.i, nram_scratch, nram_dftmtx,
- small_section_num * para_num, small_section_num * para_num,
- 1, dir, radix);
- break;
- }
-
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
-
- // [radix, small_section_num, para_num] ->
- // [small_section_num, para_num, radix] ->
- // [para_num, small_section_num, radix] ->
- // [small_section_num, radix, para_num] ==
- // [large_radix, para_num]
-
- small_stage_count--;
- if (small_stage_count == 0) {
- // nram to gdram
-
- // if (last_stage) {
- // // [2, para_num, large_radix] -> [para_num,
- // large_radix,
- // // 2]
- // // DT* nram_transpose_store = nram_in_r;
-
- // __bang_transpose(nram_in_r, nram_out_r, 2,
- // max_para_num * large_radix);
-
- // } else {
- // // [2, para_num, large_radix] -> [2, para_num,
- // // large_radix]
- // // TODO(zrg): redundant move
- // __memcpy(nram_in_r, nram_out_r,
- // para_num * large_radix * sizeof(DT),
- // NRAM2NRAM);
- // __memcpy(nram_in_i,
- // nram_out_i, para_num * large_radix *
- // sizeof(DT), NRAM2NRAM);
- // }
-
- // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2]
- if (last_stage) {
- __memcpy(nram_transpose_temp.r +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_num - 1);
-
- __memcpy(nram_transpose_temp.i +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_num - 1);
-
- __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r,
- para_num * 2, large_radix);
- } else {
- if (nram_out_r == nram_para_store_pong.r) {
- FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r)
- FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i)
- }
- __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num,
- large_radix);
- __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num,
- large_radix);
- }
-
- continue;
- }
-
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // DT* nram_transpose_store = nram_in_r;
-
- // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num;
- // para_ldst_id++) {
- // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix,
- // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num
- // - 1);
-
- // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix,
- // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num
- // - 1);
- // }
-
- // after first stage: [butterfly_num, para_ldst_num, radix]
- // other in: [para_ldst_num, butterfly_num, radix] ==
- // [para_ldst_num, large_radix]
- TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- small_section_num, para_num, radix, DT)
-
- // TODO(zrg) : add not last-stage
- // if (small_stage_count == 0) {
- // // if last-stage: stride = large_radix * 2
- // // compute_id 0 r
- // // compute_id 0 i
- // // compute_id 1 r
- // // compute_id 1 i
- // // else: stride = large_radix
- // // compute_id 0 r
- // // compute_id 1 i
- // // compute_id 0 r
- // // compute_id 1 i
-
- // // [radix, small_section_num, para_ldst_num] ->
- // // [small_section_num, para_ldst_num, radix] ->
- // // [para_ldst_num, small_section_num, radix] ->
- // // [small_section_num, radix, para_ldst_num] ==
- // // [large_radix, para_ldst_num]
-
- // __memcpy(nram_transpose_temp.r +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) *
- // large_radix, para_ldst_num - 1);
-
- // __memcpy(nram_transpose_temp.i +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) *
- // large_radix, para_ldst_num - 1);
-
- // // __memcpy(nram_transpose_temp.r +
- // // compute_id * large_radix * (1 +
- // (int)last_stage),
- // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // // __memcpy(nram_transpose_temp.i +
- // // compute_id * large_radix * (1 +
- // (int)last_stage),
- // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
-
- // // __bang_transpose(nram_transpose_temp.r,
- // nram_transpose_temp.r,
- // // max_para_ldst_num * 2, large_radix);
- // continue;
- // }
-
- // DT *sram_tw = (DT *)sram_buffer;
- DT *nram_tw = _nram_tw;
- value_mul = 8;
-
- for (; small_stage_count > 1; small_stage_count--) {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- // value_mul = (_small_stage_count - small_stage_count + 1) * 4;
- // // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul++];
- // copy GDRAM2SRAM
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(
- nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- small_twiddles += small_butterfly_num * (radix - 1) * 2;
- }
-
- switch (radix) {
- // case 2:
- // // computeRadix2ButterflyOtherstages(Fout, Fin,
- // section_num,
- // // section_num, 1, dir);
- // break;
- // case 3:
- // computeRadix3ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- default:
- // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyOtherstagesMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- nram_scratch, nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_num, small_in_stride, dir,
- radix);
- break;
- }
-
- nram_tw += small_butterfly_num * (radix - 1) * 2;
- } // for (stage_count)
-
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
-
- // MLULOG("butterfly id: %d\n", i);
- // last stage
- {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // copy GDRAM2SRAM
-
- // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul];
-
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- }
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(
- nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
- switch (radix) {
- // case 2:
- // break;
- // case 3:
- // computeRadix3ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- default:
- // computeGenericButterflyLaststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyLaststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- nram_scratch, nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_num, small_in_stride, dir,
- radix);
- break;
- }
-
- // if last-stage: stride = large_radix * 2
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else: stride = large_radix
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
- // __memcpy(nram_transpose_temp.r +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_transpose_temp.i +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
-
- if (last_stage) {
- __memcpy(nram_transpose_temp.r +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_num - 1);
-
- __memcpy(nram_transpose_temp.i +
- compute_id * large_radix * (1 + (int)last_stage),
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_num - 1);
-
- __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r,
- para_num * 2, large_radix);
- } else {
- if (nram_out_r == nram_para_store_pong.r) {
- FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r)
- FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i)
- }
-
- __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num,
- large_radix);
- __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num,
- large_radix);
- }
-
- // __bang_transpose(nram_para_store, nram_transpose_temp.r,
- // max_para_ldst_num * 2, large_radix);
- }
- }
- }
- // __sync();
-
- __sync();
- FFT_SWAP_PTR(nram_para_load_in_ping.r, nram_para_load_in_pong.r)
- FFT_SWAP_PTR(nram_para_load_in_ping.i, nram_para_load_in_pong.i)
- FFT_SWAP_PTR(nram_para_store_ping.r, nram_para_store_pong.r)
- FFT_SWAP_PTR(nram_para_store_ping.i, nram_para_store_pong.i)
- }
- }
- // Fin_stride += large_butterfly_num;
- // Fout_stride += large_radix * large_butterfly_num;
- }
-}
-
- // Last-stage entry point of the batched ping-pong large butterfly.
- //
- // Thin wrapper: forwards every argument unchanged to
- // computeLargeButterflyOtherstagesBatchPingpong and appends a trailing
- // literal 1.
- // NOTE(review): the callee's signature is not visible from here; the
- // trailing 1 is presumably its `last_stage` flag -- confirm against the
- // callee declaration.
- //
- // DT is the element type shared by the output/input buffers, the twiddle
- // tables and the DFT matrix; it is deduced from the pointer arguments.
- template <typename DT>
- __mlu_func__ void computeLargeButterflyLaststageBatchPingpong(
- DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
- const DT *dft_matrix, int large_section_num, int large_butterfly_num,
- int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- const int t_start, const int t_end, int dir) {
- computeLargeButterflyOtherstagesBatchPingpong(
- output, input, cur_large_twiddles, _twiddles, dft_matrix,
- large_section_num, large_butterfly_num, large_in_stride, nram_buf,
- small_factors, nfft, t_start, t_end, dir, 1);
- }
-
-template
-__mlu_func__ void computeLargeButterflyFirststageColumn(
- DT *output, DT *input, int large_in_stride, int section_num,
- const DT *twiddles, const DT *dft_matrix, void *nram_buf,
- const int *small_factors, int dir, int nfft, int last_stage,
- const int para_batch, const int nb) {
- // constant
- // const int para_batch = 3;
- const int K_num = 64 / sizeof(DT);
- int align_K = 0;
- const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix;
- // test
- // for(int i =0; i<3; i++){
- // MLULOG("entry: %d, dft_table.radix: %d, dft_table.offset: %d.\n",
- // i, dft_table[i].radix, dft_table[i].offset);
- // }
- // network info
- int radix, small_in_stride, small_stage_count, large_radix,
- _small_stage_count;
- int small_section_num, small_butterfly_num, value_mul;
- int tw_offset;
- // int max_radix = small_factors[4];
- _small_stage_count = small_factors[0];
- large_radix = small_factors[1];
- tw_offset = small_factors[2];
-
- // for (int i=2; i<= _small_stage_count; i++) {
-
- // max_radix = max(small_factors[i*4], max_radix);
-
- // }
-
- // load compute store
- // (0) load 0 ping sync()
- // (1) compute 0 ping load 1 pong sync()
- // (2) store 0 compute 1 pong load 2 ping sync()
- // (3) store 1 compute 2 load 3 sync()
-
- // compute last-large-stage (nram_out_r,nram_out_i) [2, large_radix]->
- // transpose -> [large_radix, 2]
-
- // complex array -> real array, imag array -> complex array
- // first-large-stage complex -> real array, imag array
- // other-large-stage none
- // last-large-stage real array, imag array -> complex
- const DT *small_twiddles = twiddles + tw_offset * 2; // complex
-
- // assign nram space
- int nram_buf_offset = 0;
-
- // parallel load/store space
- // sizeof(DT) * 2 * large_radix * para_batch * 4
- DT *nram_para_load_ping = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- DT *nram_para_load_pong = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- DT *nram_para_store_ping = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- DT *nram_para_store_pong = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix]
- // DT *nram_transpose_load = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // FFT_CPX_T nram_transpose_temp;
- // temp out-space before transpose
- // if last-stage:
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else:
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
-
- DT *_nram_tw = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * 2; // complex
-
- // load dftmtx sample
- int ld_dft_radix = -1;
- const int max_radix = 64;
- DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += max_radix * max_radix * 2; // complex
-
- // const int ld_dft_radix = 16;
- // DT *nram_dftmtx8 = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += 8 * 8 * 2; // complex
-
- // for (int entry = 0;; entry++) {
- // if (dft_table[entry].radix == 8) {
- // __memcpy_async(nram_dftmtx8, &dft_matrix[dft_table[entry].offset * 2],
- // sizeof(DT) * 2 * 8 * 8, GDRAM2NRAM);
- // break;
- // }
- // if (dft_table[entry].radix == -1) {
- // break;
- // }
- // }
-
- // nram space used:
- // sizeof(DT) * 2 * large_radix * (para_batch * 6 + 1) + sizeof(DT) * 2
- // * (max_radix * max_radix)
- // + sizeof(DT) * 2 * large_radix * para_batch * 4
- DT *nram_scratch = (DT *)nram_buf + nram_buf_offset;
-
- // DT *nram_transpose_temp = nram_scratch;
- // overlap nram_scratch
- // nram_transpose_temp = {
- // (DT *)nram_scratch,
- // (DT *)nram_scratch + large_radix * ((int)last_stage) +
- // large_radix * (1 - (int)last_stage) * para_batch};
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- __memcpy_async(_nram_tw, small_twiddles, large_radix * sizeof(DT) * 2,
- SRAM2NRAM);
-
- // return;
- // ceil
- int repeat_num = section_num;
- // MLULOG("repeat_num column: %d\n", repeat_num);
- for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) {
- // pipeline: load-stage
-
- if (repeat_id < repeat_num) {
- // MLULOG("pipeline: load-stage.\n");
- int i = repeat_id;
-
- // DT *nram_dftmtx =
- // (repeat_id % 2 == 0) ? nram_dftmtx_ping : nram_dftmtx_pong;
-
- // if (section_num == 1) {
- // __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix,
- // GDRAM2NRAM);
- // } else {
- // // gather load
- // // 2d memcpy
- // // 0 1 2 3 4 ... 1023
- // // GDRAM -> NRAM
- // // 8bytes radix-1024
- // // 64bytes
-
- // __memcpy_async(nram_para_load, input + i * 2,
- // sizeof(DT) * 2 * para_batch, GDRAM2NRAM,
- // sizeof(DT) * 2 * para_batch,
- // large_in_stride * sizeof(DT) * 2, large_radix - 1);
- // }
-
- // if(0)
- // __memcpy_async(nram_para_load, input, sizeof(DT) * 2 * large_radix *
- // para_batch,
- // GDRAM2NRAM);
- // if(0)
- __memcpy_async(nram_para_load_ping, input + i * 2 * nb,
- sizeof(DT) * 2 * para_batch, GDRAM2NRAM,
- sizeof(DT) * 2 * para_batch,
- nb * large_in_stride * sizeof(DT) * 2, large_radix - 1);
- }
-
- // pipeline: store-stage
-
- if (repeat_id >= 2) {
- // MLULOG("pipeline: store-stage.\n");
- int i = (repeat_id - 2);
-
- if (last_stage) {
- // if(0)
- // __memcpy_async(output + i * large_radix * 2 * nb1,
- // nram_para_store_ping,
- // para_batch * sizeof(DT) * 2, NRAM2GDRAM,
- // nb1 * 2 * sizeof(DT), para_batch * sizeof(DT) * 2,
- // large_radix - 1);
- __memcpy_async(output + i * large_radix * 2 * nb, nram_para_store_ping,
- para_batch * sizeof(DT) * 2, NRAM2GDRAM,
- nb * 2 * sizeof(DT), para_batch * sizeof(DT) * 2,
- large_radix - 1);
- // __memcpy_async(output + i * large_radix * 2 * para_batch,
- // nram_para_store_ping,
- // 2* para_batch * large_radix * sizeof(DT), NRAM2GDRAM);
- } else {
- // // real
- // __memcpy_async(output + i * large_radix * nb1, nram_para_store,
- // para_batch * sizeof(DT), NRAM2GDRAM, nb1 * sizeof(DT),
- // para_batch * sizeof(DT), large_radix - 1);
- // // imag
- // __memcpy_async(output + i * large_radix * nb1 + nb0 * nb1,
- // nram_para_store + para_batch * large_radix,
- // para_batch * sizeof(DT), NRAM2GDRAM, nb1 * sizeof(DT),
- // para_batch * sizeof(DT), large_radix - 1);
-
- // real
- __memcpy_async(output + i * large_radix * para_batch,
- nram_para_store_ping,
- para_batch * large_radix * sizeof(DT), NRAM2GDRAM);
- // imag
- __memcpy_async(
- output + i * large_radix * para_batch + nfft * para_batch,
- nram_para_store_ping + para_batch * large_radix,
- para_batch * large_radix * sizeof(DT), NRAM2GDRAM);
- }
- }
-
- // pipeline: compute-stage
-
- if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
- // int i = (repeat_id - 1);
-
- // // [large_radix, para_batch, 2] -> [para_batch, 2, large_radix]
- // __bang_transpose(nram_transpose_load, nram_para_load, large_radix,
- // 2 * para_batch);
-
- // [large_radix, para_batch, 2] -> [2, para_batch, large_radix]
- // overlap nram_out_r
- // DT *nram_transpose_load = nram_out_r;
- // __bang_transpose(nram_transpose_load, nram_para_load,
- // large_radix * para_batch, 2);
- // // [large_radix, para_batch] -> [para_batch, large_radix]
- // __bang_transpose(nram_in_r, nram_transpose_load, large_radix,
- // para_batch);
- // __bang_transpose(nram_in_i,
- // nram_transpose_load + large_radix * para_batch,
- // large_radix, para_batch);
-
- DT *nram_in_r = nram_para_store_pong;
- DT *nram_in_i = nram_para_store_pong + large_radix * para_batch;
-
- DT *nram_out_r = nram_para_load_pong;
- DT *nram_out_i = nram_para_load_pong + large_radix * para_batch;
-
- // DT *nram_transpose_load = nram_in_r;
- __bang_transpose(nram_in_r, nram_para_load_pong, large_radix * para_batch,
- 2);
- // [large_radix, para_batch] -> [para_batch, large_radix]
- // __bang_transpose(nram_in_r, nram_transpose_load, large_radix,
- // para_batch);
- // __bang_transpose(nram_in_i,
- // nram_transpose_load + large_radix * para_batch,
- // large_radix, para_batch);
-
- for (int compute_id = 0; compute_id < para_batch;
- compute_id += para_batch) {
- // load real & imag
-
- radix = small_factors[4];
- small_section_num = small_factors[5];
- small_in_stride = small_factors[7];
- small_stage_count = _small_stage_count;
-
- // __memcpy(nram_in_r,
- // nram_transpose_load + compute_id * large_radix * 2,
- // large_radix * sizeof(DT) * 2, NRAM2NRAM);
-
- // first stage
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- default:
- MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
- computeGenericButterflyFirststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, small_section_num * para_batch,
- small_section_num * para_batch, 1, dir, radix);
- break;
- }
- // for (int j = 0; j < ld_dft_radix * align_K; j++) {
- // printf("nram_dftmtx [%d][%d]: (%f, %f).\n", (ld_dft_radix *
- // align_K)nram_dftmtx[j], nram_dftmtx[j + ld_dft_radix * align_K]);
- // }
-
- // [radix, small_section_num, para_batch] ->
- // [small_section_num, para_batch, radix] -> [para_batch,
- // small_section_num, radix]
-
- // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix,
- // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_batch * radix * sizeof(DT), small_section_num - 1);
-
- // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix,
- // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_batch * radix * sizeof(DT), small_section_num - 1);
- // __sync_move();
-
- small_stage_count--;
- if (small_stage_count == 0) {
- // nram to gdram
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix);
- __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix);
-
- if (nram_out_r == nram_para_store_pong) {
- FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong)
- }
-
- if (last_stage) {
- // [2, para_batch, large_radix] -> [para_batch, large_radix,
- // 2]
- // DT* nram_transpose_store = nram_in_r;
-
- __bang_transpose(nram_para_store_pong, nram_out_r, 2,
- para_batch * large_radix);
- } else {
- // [2, para_batch, large_radix] -> [2, para_batch,
- // large_radix]
- // TODO(zrg): redundant move
- __memcpy(nram_para_store_pong, nram_out_r,
- para_batch * large_radix * sizeof(DT), NRAM2NRAM);
- __memcpy(nram_para_store_pong + para_batch * large_radix,
- nram_out_i, para_batch * large_radix * sizeof(DT),
- NRAM2NRAM);
- }
-
- continue;
- }
-
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- small_section_num, para_batch, radix, DT)
-
- value_mul = 8;
- // DT *sram_tw = (DT *)sram_buffer;
- DT *nram_tw = _nram_tw;
-
- for (; small_stage_count > 1; small_stage_count--) {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
-
- // value_mul = (_small_stage_count - small_stage_count + 1) << 2;
-
- // // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul++];
- // value_mul += 4;
- // copy GDRAM2SRAM
-
- // if (compute_id == 0 && repeat_id == 1 && 0) {
- // __memcpy(nram_tw, small_twiddles,
- // small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- // GDRAM2NRAM);
- // small_twiddles += small_butterfly_num * (radix - 1) * 2;
- // }
-
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
-
- switch (radix) {
- default:
- // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- MLULOG("computeGenericButterflyOtherstagesMat: %d.\n", radix);
- computeGenericButterflyOtherstagesMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num, small_butterfly_num,
- para_batch, small_in_stride, dir, radix);
- break;
- }
+ // if (last_stage) {
+ // // [2, para_num, large_radix] -> [para_num,
+ // large_radix,
+ // // 2]
+ // // DT* nram_transpose_store = nram_in_r;
- nram_tw += small_butterfly_num * (radix - 1) * 2;
- } // for (stage_count)
+ // __bang_transpose(nram_in_r, nram_out_r, 2,
+ // max_para_num * large_radix);
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
+ // } else {
+ // // [2, para_num, large_radix] -> [2, para_num,
+ // // large_radix]
+ // // TODO(zrg): redundant move
+ // __memcpy(nram_in_r, nram_out_r,
+ // para_num * large_radix * sizeof(DT),
+ // NRAM2NRAM);
+ // __memcpy(nram_in_i,
+ // nram_out_i, para_num * large_radix *
+ // sizeof(DT), NRAM2NRAM);
+ // }
- // MLULOG("butterfly id: %d\n", i);
- // last stage
- {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
+ // [nfft, 2] -> [2, nfft] -> [2, nfft] -> [nfft, 2]
+ if (last_stage) {
+ __memcpy(nram_transpose_temp.r +
+ compute_id * large_radix * (1 + (int)last_stage),
+ nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
+ sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
+ para_num - 1);
- // copy GDRAM2SRAM
+ __memcpy(nram_transpose_temp.i +
+ compute_id * large_radix * (1 + (int)last_stage),
+ nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
+ sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
+ para_num - 1);
- // update parameter
- // value_mul = _small_stage_count << 2;
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul];
+ __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r,
+ para_num * 2, large_radix);
+ } else {
+ // rotation-large
+ __bang_mul(CPX_MUL_RR, nram_out_r + para_num,
+ nram_para_load_tw.r, para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_II, nram_out_i + para_num,
+ nram_para_load_tw.i, para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_RI, nram_out_r + para_num,
+ nram_para_load_tw.i, para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_IR, nram_out_i + para_num,
+ nram_para_load_tw.r, para_num * (large_radix - 1));
+
+ __bang_sub(nram_out_r + para_num, CPX_MUL_RR, CPX_MUL_II,
+ para_num * (large_radix - 1));
+ __bang_add(nram_out_i + para_num, CPX_MUL_RI, CPX_MUL_IR,
+ para_num * (large_radix - 1));
- // if (compute_id == 0 && repeat_id == 1 && 0) {
- // __memcpy(nram_tw, small_twiddles,
- // small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- // GDRAM2NRAM);
- // }
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
+ if (nram_out_r == nram_para_store_pong.r) {
+ FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r)
+ FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i)
+ }
+ __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num,
+ large_radix);
+ __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num,
+ large_radix);
}
- if (dft_table[entry].radix == -1) {
- break;
- }
+ continue;
}
- }
-
- switch (radix) {
- case 2:
- break;
-
- default:
- MLULOG("computeGenericButterflyLaststageMat: %d.\n", radix);
- computeGenericButterflyLaststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num, small_butterfly_num,
- para_batch, small_in_stride, dir, radix);
- MLULOG("computeGenericButterflyLaststageMat: %d End.\n", radix);
- break;
- }
-
- // [2, para_batch, large_radix] -> [2, large_radix, para_batch]
-
- if (last_stage) {
- // TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r,
- // nram_in_i, 2, para_batch, large_radix, DT)
-
- // [2, para_batch, large_radix] -> [para_batch, large_radix,
- // 2]
- // DT* nram_transpose_store = nram_in_r;
FFT_SWAP_PTR(nram_out_r, nram_in_r);
FFT_SWAP_PTR(nram_out_i, nram_in_i);
- __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix);
- __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix);
-
- if (nram_out_r == nram_para_store_pong) {
- FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong)
- }
-
- __bang_transpose(nram_para_store_pong, nram_out_r, 2,
- para_batch * large_radix);
- // [2, para_batch, large_radix] -> [para_batch, 2, large_radix] ->
- // [large_radix, para_batch, 2]
- // __bang_transpose(nram_para_store, nram_out_r, para_batch * 2,
- // large_radix);
- } else {
- // [2, para_batch, large_radix] -> [2, para_batch,
- // large_radix]
- // TODO(zrg): test
- if (nram_out_r == nram_para_store_pong) {
- FFT_SWAP_PTR(nram_para_load_pong, nram_para_store_pong)
- }
-
- __bang_transpose(nram_para_store_pong, nram_out_r, para_batch,
- large_radix);
- __bang_transpose(nram_para_store_pong + para_batch * large_radix,
- nram_out_i, para_batch, large_radix);
-
- // __memcpy(nram_para_store, nram_out_r,
- // para_batch * large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_para_store + para_batch * large_radix, nram_out_i,
- // para_batch * large_radix * sizeof(DT), NRAM2NRAM);
- }
- }
- }
- }
-
- __sync();
- FFT_SWAP_PTR(nram_para_load_ping, nram_para_load_pong)
- FFT_SWAP_PTR(nram_para_store_ping, nram_para_store_pong)
- }
-}
-
-template
-__mlu_func__ void computeLargeButterflyOtherstagesColumn(
- DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
- const DT *dft_matrix, int large_section_num, int large_butterfly_num,
- int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- int dir, int last_stage, int para_batch, int nb) {
- // return;
- const dft_table_entry *dft_table = (const dft_table_entry *)dft_matrix;
-
- int radix, small_in_stride, small_stage_count, large_radix,
- _small_stage_count;
- int small_section_num, small_butterfly_num, value_mul;
-
- const int large_out_stride = large_butterfly_num;
- int tw_offset;
-
- _small_stage_count = small_factors[0];
- large_radix = small_factors[1];
- tw_offset = small_factors[2];
-
- const int K_num = 64 / sizeof(DT);
- int align_K = 0;
-
- const DT *small_twiddles = _twiddles + tw_offset * 2; // complex
- // const DT *small_twiddles; // complex
-
- // MLULOG("small_section_num: %d.\n\n\n", small_section_num);
- // max num for parallel load/store
- // const int max_para_ldst_num = 2187 / large_radix;
- // const int max_para_ldst_num = (2187 + large_radix - 1) / large_radix;
- // const int max_para_ldst_num = (512 + large_radix - 1) / large_radix;
- // const int max_para_ldst_num = 20;
- // const int max_para_ldst_num = (4096 + large_radix - 1) / large_radix;
-
- // int para_ldst_num;
- // TODO(zrg): save nram space.
- // __nram__ DT nram_space[MAX_BUTTERFLY_ON_CHIP * 2];
- // 0 1 2 3 4 5
- // 0 1 3
- // DT *nram_buf_end = (DT*)&((uint8_t*)nram_buf)[NRAM_BUFFER_SIZE];
- // FFT_CPX_T *nram_in = (FFT_CPX_T *)nram_buffer;
- // FFT_CPX_T *nram_out = &nram_in[large_radix];
- // FFT_CPX_T *nram_buf = &nram_in[large_radix * 2];
- int nram_buf_offset = 0;
- DT *nram_in_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch;
-
- DT *nram_in_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch;
-
- DT *nram_out_r = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch;
-
- DT *nram_out_i = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * para_batch;
-
- // parallel load/store space
- FFT_CPX_T nram_para_load_in_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- FFT_CPX_T nram_para_load_in_pong = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // FFT_CPX_T nram_para_load_tw_ping = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // FFT_CPX_T nram_para_load_tw_pong = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- FFT_CPX_T nram_para_load_tw_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + (large_radix - 1)};
- nram_buf_offset += (large_radix - 1) * 2; // complex
-
- FFT_CPX_T nram_para_load_tw_pong = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + (large_radix - 1)};
- nram_buf_offset += (large_radix - 1) * 2; // complex
-
- FFT_CPX_T nram_para_store_ping = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- FFT_CPX_T nram_para_store_pong = {
- (DT *)nram_buf + nram_buf_offset,
- (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // overlap nram_in
- FFT_CPX_T nram_transpose_temp;
- // temp out-space before transpose
- // if last-stage:
- // compute_id 0 r
- // compute_id 0 i
- // compute_id 1 r
- // compute_id 1 i
- // else:
- // compute_id 0 r
- // compute_id 1 i
- // compute_id 0 r
- // compute_id 1 i
- nram_transpose_temp = {(DT *)nram_in_r,
- (DT *)nram_in_r + large_radix * ((int)last_stage) +
- large_radix * (1 - (int)last_stage) * para_batch};
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- DT *_nram_tw = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += large_radix * 2; // complex
-
- // transpose space: [radix, 2 * parrallel] -> [parrallel * 2, radix]
- // FFT_CPX_T nram_transpose_load = {
- // (DT *)nram_buf + nram_buf_offset,
- // (DT *)nram_buf + nram_buf_offset + large_radix * para_batch};
- // nram_buf_offset += large_radix * para_batch * 2; // complex
-
- // load dftmtx sample
- // const int ld_dft_radix = 16;
- // DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
- // nram_buf_offset += ld_dft_radix * ld_dft_radix * 2; // complex
-
- // for (int entry = 0;; entry++) {
- // if (dft_table[entry].radix == ld_dft_radix) {
- // __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- // sizeof(DT) * 2 * ld_dft_radix * ld_dft_radix, GDRAM2NRAM);
- // break;
- // }
-
- // if (dft_table[entry].radix == -1) {
- // break;
- // }
- // }
- int ld_dft_radix = -1;
- const int max_radix = 64;
- DT *nram_dftmtx = (DT *)nram_buf + nram_buf_offset;
- nram_buf_offset += max_radix * max_radix * 2; // complex
-
- DT *nram_scratch = (DT *)nram_buf + nram_buf_offset;
-
- // temp overlap with "nram_scratch"
- DT *CPX_MUL_RR = nram_scratch;
- DT *CPX_MUL_RI = &CPX_MUL_RR[(large_radix - 1) * para_batch];
- DT *CPX_MUL_IR = &CPX_MUL_RI[(large_radix - 1) * para_batch];
- DT *CPX_MUL_II = &CPX_MUL_IR[(large_radix - 1) * para_batch];
-
- nram_buf_offset += (large_radix - 1) * para_batch * 4; // complex
-
- // size: (large_radix - 1) * para_batch
- // DT *scratch_tw_r = &CPX_MUL_II[large_radix * para_batch];
- // DT *scratch_tw_i = &scratch_tw_r[(large_radix - 1) * para_batch];
-
- // if (nram_buf == NULL) {
- // MLULOG("nram_buf: NULL.\n");
- // }
- // if (input == NULL) {
- // MLULOG("input: NULL.\n");
- // }
- // if (output == NULL) {
- // MLULOG("output: NULL.\n");
- // }
-
- // if (cur_large_twiddles == NULL) {
- // MLULOG("twiddles: NULL.\n");
- // }
-
- // __nram__ DT *nram_buf = &nram_space[MAX_BUTTERFLY_ON_CHIP*2];
-
- // DT *odd_extra_buffer = buffer + nfft*2; // for in_place temp buffer
-
- // const int para_num = 1;
-
- int Fin_stride = 0, Fout_stride = 0;
- int sec_count;
- int repeat_num = large_butterfly_num;
-
- for (sec_count = 0; sec_count < large_section_num; ++sec_count) {
- for (int repeat_id = 0; repeat_id < repeat_num + 2; ++repeat_id) {
- // small_twiddles = _small_twiddles;
-
- // pipeline: load-stage
-
- if (repeat_id < repeat_num) {
- // MLULOG("pipeline: load-stage.\n");
- int i = repeat_id;
- FFT_CPX_T nram_para_load_in = (repeat_id % 2 == 0)
- ? nram_para_load_in_ping
- : nram_para_load_in_pong;
-
- FFT_CPX_T nram_para_load_tw = (repeat_id % 2 == 0)
- ? nram_para_load_tw_ping
- : nram_para_load_tw_pong;
-
- if (para_batch != 1 || 1) {
- __memcpy_async(
- nram_para_load_in.r, input + (Fin_stride + i) * para_batch,
- sizeof(DT) * para_batch, GDRAM2NRAM, sizeof(DT) * para_batch,
- para_batch * large_in_stride * sizeof(DT), large_radix - 1);
- __memcpy_async(
- nram_para_load_in.i,
- input + para_batch * nfft + (Fin_stride + i) * para_batch,
- sizeof(DT) * para_batch, GDRAM2NRAM, sizeof(DT) * para_batch,
- para_batch * large_in_stride * sizeof(DT), large_radix - 1);
-
- // __memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
- // sizeof(DT), SRAM2NRAM, sizeof(DT),
- // large_out_stride * sizeof(DT), large_radix - 2);
- // __memcpy_async(
- // nram_para_load_tw.i,
- // cur_large_twiddles + large_butterfly_num * (large_radix - 1) +
- // i, sizeof(DT), SRAM2NRAM, sizeof(DT), large_out_stride *
- // sizeof(DT), large_radix - 2);
- __memcpy_async(nram_para_load_tw.r,
- cur_large_twiddles + i * (large_radix - 1),
- sizeof(DT) * (large_radix - 1) * 2, SRAM2NRAM);
- __memcpy_async(nram_para_load_tw.i,
- cur_large_twiddles +
- large_butterfly_num * (large_radix - 1) +
- i * (large_radix - 1),
- sizeof(DT) * (large_radix - 1), SRAM2NRAM);
- }
- }
-
- // pipeline: store-stage
- if (repeat_id >= 2) {
- // MLULOG("pipeline: store-stage.\n");
- int i = (repeat_id - 2);
-
- // int para_store_num = (max_para_ldst_num > (large_butterfly_num - i))
- // ? (large_butterfly_num - i)
- // : max_para_ldst_num;
-
- FFT_CPX_T nram_para_store =
- (repeat_id % 2 == 0) ? nram_para_store_ping : nram_para_store_pong;
-
- if (last_stage) {
- // __memcpy_async(
- // output + (Fout_stride + i * large_radix) * 2,
- // nram_para_store.r,
- // sizeof(DT) * 2 * para_store_num * large_radix, NRAM2GDRAM);
-
- __memcpy_async(output + (Fout_stride + i) * 2 * nb, nram_para_store.r,
- sizeof(DT) * 2 * para_batch, NRAM2GDRAM,
- nb * large_out_stride * 2 * sizeof(DT),
- sizeof(DT) * 2 * para_batch, large_radix - 1);
- } else {
- // // real
- // __memcpy_async(output + Fout_stride + i * large_radix,
- // nram_para_store.r,
- // para_store_num * large_radix * sizeof(DT),
- // NRAM2GDRAM);
- // // imag
- // __memcpy_async(output + Fout_stride + i * large_radix + nfft,
- // nram_para_store.i,
- // para_store_num * large_radix * sizeof(DT),
- // NRAM2GDRAM);
- // real
- __memcpy_async(output + (Fout_stride + i) * para_batch,
- nram_para_store.r, para_batch * sizeof(DT), NRAM2GDRAM,
- para_batch * large_out_stride * sizeof(DT),
- sizeof(DT) * para_batch, large_radix - 1);
- // imag
- __memcpy_async(
- output + (Fout_stride + i) * para_batch + nfft * para_batch,
- nram_para_store.i, para_batch * sizeof(DT), NRAM2GDRAM,
- para_batch * large_out_stride * sizeof(DT),
- sizeof(DT) * para_batch, large_radix - 1);
- }
- }
- // __sync();
- // pipeline: compute-stage
-
- if (repeat_id >= 1 && repeat_id < repeat_num + 1) {
- // MLULOG("pipeline: compute-stage.\n");
- // int i = (repeat_id - 1);
-
- FFT_CPX_T nram_para_load_in = (repeat_id % 2 != 0)
- ? nram_para_load_in_ping
- : nram_para_load_in_pong;
-
- FFT_CPX_T nram_para_load_tw = (repeat_id % 2 != 0)
- ? nram_para_load_tw_ping
- : nram_para_load_tw_pong;
-
- FFT_CPX_T nram_para_store =
- (repeat_id % 2 != 0) ? nram_para_store_ping : nram_para_store_pong;
-
- // __bang_transpose(nram_transpose_load, nram_para_load, large_radix,
- // 2 * para_ldst_num);
-
- // rotation-large
- // __bang_mul(CPX_MUL_RR, nram_para_load_in.r + para_batch,
- // nram_para_load_tw.r, para_batch * (large_radix - 1));
- // __bang_mul(CPX_MUL_II, nram_para_load_in.i + para_batch,
- // nram_para_load_tw.i, para_batch * (large_radix - 1));
- // __bang_mul(CPX_MUL_RI, nram_para_load_in.r + para_batch,
- // nram_para_load_tw.i, para_batch * (large_radix - 1));
- // __bang_mul(CPX_MUL_IR, nram_para_load_in.i + para_batch,
- // nram_para_load_tw.r, para_batch * (large_radix - 1));
-
- // __bang_sub(nram_para_load_in.r + para_batch, CPX_MUL_RR, CPX_MUL_II,
- // para_batch * (large_radix - 1));
- // __bang_add(nram_para_load_in.i + para_batch, CPX_MUL_RI, CPX_MUL_IR,
- // para_batch * (large_radix - 1));
- if (1) {
- for (int i = 1; i < large_radix; i++) {
- // __memcpy(&Fin.r[nram_in_offset],
- // &nram_in_r[butterfly_num * section_num * i],
- // butterfly_num * section_num * sizeof(DT), NRAM2NRAM,
- // butterfly_num * section_num * sizeof(DT),
- // butterfly_num * section_num * radix * sizeof(DT),
- // para_large_butterfly - 1);
-
- // __memcpy(&Fin.i[nram_in_offset],
- // &nram_in_i[butterfly_num * section_num * i],
- // butterfly_num * section_num * sizeof(DT), NRAM2NRAM,
- // butterfly_num * section_num * sizeof(DT),
- // butterfly_num * section_num * radix * sizeof(DT),
- // para_large_butterfly - 1);
-
- __bang_mul_scalar(&CPX_MUL_RR[(i - 1) * para_batch],
- nram_para_load_in.r + para_batch * i,
- nram_para_load_tw.r[(i - 1)], para_batch);
- __bang_mul_scalar(&CPX_MUL_RI[(i - 1) * para_batch],
- nram_para_load_in.r + para_batch * i,
- nram_para_load_tw.i[(i - 1)], para_batch);
- __bang_mul_scalar(&CPX_MUL_IR[(i - 1) * para_batch],
- nram_para_load_in.i + para_batch * i,
- nram_para_load_tw.r[(i - 1)], para_batch);
- __bang_mul_scalar(&CPX_MUL_II[(i - 1) * para_batch],
- nram_para_load_in.i + para_batch * i,
- nram_para_load_tw.i[(i - 1)], para_batch);
-
- // __bang_cycle_mul(&CPX_MUL_RR[(i - 1) * para_batch],
- // nram_para_load_in.r + para_batch * i,
- // &nram_para_load_tw.r[(i - 1) *
- // large_butterfly_num], para_batch, 1);
- // __bang_cycle_mul(&CPX_MUL_RI[(i - 1) * para_batch],
- // nram_para_load_in.r + para_batch * i,
- // &nram_para_load_tw.i[(i - 1) *
- // large_butterfly_num], para_batch, 1);
- // __bang_cycle_mul(&CPX_MUL_IR[(i - 1) * para_batch],
- // nram_para_load_in.i + para_batch * i,
- // &nram_para_load_tw.r[(i - 1) *
- // large_butterfly_num], para_batch, 1);
- // __bang_cycle_mul(&CPX_MUL_II[(i - 1) * para_batch],
- // nram_para_load_in.i + para_batch * i,
- // &nram_para_load_tw.i[(i - 1) *
- // large_butterfly_num], para_batch, 1);
- }
- __bang_sub(nram_para_load_in.r + para_batch, CPX_MUL_RR, CPX_MUL_II,
- para_batch * (large_radix - 1));
- __bang_add(nram_para_load_in.i + para_batch, CPX_MUL_RI, CPX_MUL_IR,
- para_batch * (large_radix - 1));
- }
- // __bang_transpose(nram_transpose_load.r, nram_para_load_in.r,
- // large_radix, para_batch);
- // __bang_transpose(nram_transpose_load.i, nram_para_load_in.i,
- // large_radix, para_batch);
-
- // for (int compute_id = 0; compute_id < para_batch; compute_id++) {
- for (int compute_id = 0; compute_id < para_batch;
- compute_id += para_batch) {
- // load real & imag
-
- radix = small_factors[4];
- small_section_num = small_factors[5];
- small_in_stride = small_factors[7];
- small_stage_count = _small_stage_count;
-
- // __memcpy(nram_in_r,
- // nram_transpose_load + compute_id * large_radix * 2,
- // large_radix * sizeof(DT) * 2, NRAM2NRAM);
-
- // first stage
- // if(0)
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
- }
-
- if (dft_table[entry].radix == -1) {
- break;
- }
- }
- }
+ // DT* nram_transpose_store = nram_in_r;
- switch (radix) {
- default:
- // computeGenericButterflyFirststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- MLULOG("computeGenericButterflyFirststageMat: %d.\n", radix);
+ // after first stage: [butterfly_num, para_ldst_num, radix]
+ // other in: [para_ldst_num, butterfly_num, radix] ==
+ // [para_ldst_num, large_radix]
+ TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
+ small_section_num, para_num, radix, DT)
- computeGenericButterflyFirststageMat(
- nram_out_r, nram_out_i, nram_para_load_in.r,
- nram_para_load_in.i, nram_scratch, nram_dftmtx,
- small_section_num * para_batch,
- small_section_num * para_batch, 1, dir, radix);
- break;
- }
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
+ // DT *sram_tw = (DT *)sram_buffer;
+ DT *nram_tw = _nram_tw;
+ value_mul = 8;
- // [radix, small_section_num, para_ldst_num] ->
- // [small_section_num, para_ldst_num, radix] ->
- // [para_ldst_num, small_section_num, radix] ->
- // [small_section_num, radix, para_ldst_num] ==
- // [large_radix, para_ldst_num]
+ for (; small_stage_count > 1; small_stage_count--) {
+ FFT_SWAP_PTR(nram_out_r, nram_in_r);
+ FFT_SWAP_PTR(nram_out_i, nram_in_i);
- small_stage_count--;
- if (small_stage_count == 0) {
- // FFT_SWAP_PTR(nram_out_r, nram_in_r);
- // FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // __bang_transpose(nram_out_r, nram_in_r, para_batch, large_radix);
- // __bang_transpose(nram_out_i, nram_in_i, para_batch, large_radix);
+ // value_mul = (_small_stage_count - small_stage_count + 1) * 4;
+ // // update parameter
+ radix = small_factors[value_mul++];
+ small_section_num = small_factors[value_mul++];
+ small_butterfly_num = small_factors[value_mul++];
+ small_in_stride = small_factors[value_mul++];
+ // load this radix's DFT matrix SRAM2NRAM unless it is already loaded
- // nram to gdram
+ if (ld_dft_radix != radix) {
+ ld_dft_radix = radix;
+ for (int entry = 0;; entry++) {
+ if (dft_table[entry].radix == ld_dft_radix) {
+ align_K = K_num * ((radix + K_num - 1) / K_num);
+ __memcpy(
+ nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
+ sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
+ break;
+ }
- // if (last_stage) {
- // // [2, para_ldst_num, large_radix] -> [para_ldst_num,
- // large_radix,
- // // 2]
- // // DT* nram_transpose_store = nram_in_r;
+ if (dft_table[entry].radix == -1) {
+ break;
+ }
+ }
+ }
- // __bang_transpose(nram_in_r, nram_out_r, 2,
- // max_para_ldst_num * large_radix);
-
- // } else {
- // // [2, para_ldst_num, large_radix] -> [2, para_ldst_num,
- // // large_radix]
- // // TODO(zrg): redundant move
- // __memcpy(nram_in_r, nram_out_r,
- // para_ldst_num * large_radix * sizeof(DT), NRAM2NRAM);
- // __memcpy(nram_in_i,
- // nram_out_i, para_ldst_num * large_radix * sizeof(DT),
- // NRAM2NRAM);
- // }
+ if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
+ __memcpy(nram_tw, small_twiddles,
+ small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
+ SRAM2NRAM);
+ small_twiddles += small_butterfly_num * (radix - 1) * 2;
+ }
- if (last_stage) {
- __memcpy(nram_transpose_temp.r + compute_id * large_radix * 2,
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_batch - 1);
-
- __memcpy(nram_transpose_temp.i + compute_id * large_radix * 2,
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_batch - 1);
-
- __bang_transpose(nram_para_store.r, nram_transpose_temp.r,
- para_batch * 2, large_radix);
- } else {
- __bang_transpose(nram_para_store.r, nram_out_r, para_batch,
- large_radix);
- __bang_transpose(nram_para_store.i, nram_out_i, para_batch,
- large_radix);
- }
+ switch (radix) {
+ default:
+ // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
+ // radix, section_num, butterfly_num, in_stride, 0, dir);
+ computeGenericButterflyOtherstagesMat(
+ nram_out_r, nram_out_i, nram_in_r, nram_in_i,
+ nram_scratch, nram_dftmtx, nram_tw, small_section_num,
+ small_butterfly_num, para_num, small_in_stride, dir,
+ radix);
+ break;
+ }
- continue;
- }
+ nram_tw += small_butterfly_num * (radix - 1) * 2;
+ } // for (stage_count)
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // DT* nram_transpose_store = nram_in_r;
-
- // for (int para_ldst_id = 0; para_ldst_id < para_ldst_num;
- // para_ldst_id++) {
- // __memcpy(nram_out_r + para_ldst_id * small_section_num * radix,
- // nram_in_r + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num -
- // 1);
-
- // __memcpy(nram_out_i + para_ldst_id * small_section_num * radix,
- // nram_in_i + para_ldst_id * radix, sizeof(DT) * radix,
- // NRAM2NRAM, sizeof(DT) * radix,
- // para_ldst_num * radix * sizeof(DT), small_section_num -
- // 1);
- // }
+ // for (int j = 0; j < large_radix; j++) {
+ // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
+ // }
- TRANSPOSE_XYZ2YXZ_PAIR(nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- small_section_num, para_batch, radix, DT)
-
- // TODO(zrg) : add not last-stage
- // if (small_stage_count == 0) {
- // // if last-stage: stride = large_radix * 2
- // // compute_id 0 r
- // // compute_id 0 i
- // // compute_id 1 r
- // // compute_id 1 i
- // // else: stride = large_radix
- // // compute_id 0 r
- // // compute_id 1 i
- // // compute_id 0 r
- // // compute_id 1 i
-
- // // [radix, small_section_num, para_ldst_num] ->
- // // [small_section_num, para_ldst_num, radix] ->
- // // [para_ldst_num, small_section_num, radix] ->
- // // [small_section_num, radix, para_ldst_num] ==
- // // [large_radix, para_ldst_num]
-
- // __memcpy(nram_transpose_temp.r +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- // para_ldst_num - 1);
-
- // __memcpy(nram_transpose_temp.i +
- // compute_id * large_radix * (1 + (int)last_stage),
- // nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- // sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- // para_ldst_num - 1);
-
- // // __memcpy(nram_transpose_temp.r +
- // // compute_id * large_radix * (1 + (int)last_stage),
- // // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
- // // __memcpy(nram_transpose_temp.i +
- // // compute_id * large_radix * (1 + (int)last_stage),
- // // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
-
- // // __bang_transpose(nram_transpose_temp.r, nram_transpose_temp.r,
- // // max_para_ldst_num * 2, large_radix);
- // continue;
- // }
+ // MLULOG("butterfly id: %d\n", i);
+ // last stage
+ {
+ FFT_SWAP_PTR(nram_out_r, nram_in_r);
+ FFT_SWAP_PTR(nram_out_i, nram_in_i);
+ // twiddles / DFT matrix for this stage are copied SRAM2NRAM below
- // DT *sram_tw = (DT *)sram_buffer;
- DT *nram_tw = _nram_tw;
- value_mul = 8;
+ // update parameter
+ radix = small_factors[value_mul++];
+ small_section_num = small_factors[value_mul++];
+ small_butterfly_num = small_factors[value_mul++];
+ small_in_stride = small_factors[value_mul];
- for (; small_stage_count > 1; small_stage_count--) {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
+ if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
+ __memcpy(nram_tw, small_twiddles,
+ small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
+ SRAM2NRAM);
+ }
- // value_mul = (_small_stage_count - small_stage_count + 1) * 4;
- // // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul++];
- // copy GDRAM2SRAM
+ if (ld_dft_radix != radix) {
+ ld_dft_radix = radix;
+ for (int entry = 0;; entry++) {
+ if (dft_table[entry].radix == ld_dft_radix) {
+ align_K = K_num * ((radix + K_num - 1) / K_num);
+ __memcpy(
+ nram_dftmtx, &dft_matrix[dft_table[entry].offset * 2],
+ sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
+ break;
+ }
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx,
- &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
+ if (dft_table[entry].radix == -1) {
+ break;
+ }
}
-
- if (dft_table[entry].radix == -1) {
+ }
+ switch (radix) {
+ // case 2:
+ // break;
+ // case 3:
+ // computeRadix3ButterflyLaststage(
+ // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
+ // nram_scratch, nram_tw, small_section_num,
+ // small_butterfly_num, small_in_stride, dir);
+ // break;
+ // case 9:
+ // computeRadix9ButterflyLaststage(
+ // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
+ // nram_scratch, nram_tw, small_section_num,
+ // small_butterfly_num, small_in_stride, dir);
+ // break;
+ default:
+ // computeGenericButterflyLaststage(Fout, buffer, twiddles,
+ // radix, section_num, butterfly_num, in_stride, 0, dir);
+ computeGenericButterflyLaststageMat(
+ nram_out_r, nram_out_i, nram_in_r, nram_in_i,
+ nram_scratch, nram_dftmtx, nram_tw, small_section_num,
+ small_butterfly_num, para_num, small_in_stride, dir,
+ radix);
break;
- }
}
- }
-
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- small_twiddles += small_butterfly_num * (radix - 1) * 2;
- }
-
- switch (radix) {
- // case 2:
- // // computeRadix2ButterflyOtherstages(Fout, Fin, section_num,
- // // section_num, 1, dir);
- // break;
- // case 3:
- // computeRadix3ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyOtherstages(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- default:
- // computeGenericButterflyOtherstages(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyOtherstagesMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_batch, small_in_stride, dir,
- radix);
- break;
- }
- nram_tw += small_butterfly_num * (radix - 1) * 2;
- } // for (stage_count)
-
- // for (int j = 0; j < large_radix; j++) {
- // MLULOG("output i: (%f, %f).\n", nram_out_r[j], nram_out_i[j]);
- // }
+ // if last-stage: stride = large_radix * 2
+ // compute_id 0 r
+ // compute_id 0 i
+ // compute_id 1 r
+ // compute_id 1 i
+ // else: stride = large_radix
+ // compute_id 0 r
+ // compute_id 1 i
+ // compute_id 0 r
+ // compute_id 1 i
+ // __memcpy(nram_transpose_temp.r +
+ // compute_id * large_radix * (1 + (int)last_stage),
+ // nram_out_r, large_radix * sizeof(DT), NRAM2NRAM);
+ // __memcpy(nram_transpose_temp.i +
+ // compute_id * large_radix * (1 + (int)last_stage),
+ // nram_out_i, large_radix * sizeof(DT), NRAM2NRAM);
- // MLULOG("butterfly id: %d\n", i);
- // last stage
- {
- FFT_SWAP_PTR(nram_out_r, nram_in_r);
- FFT_SWAP_PTR(nram_out_i, nram_in_i);
- // copy GDRAM2SRAM
+ // rotation-large
+ __bang_mul(CPX_MUL_RR, nram_out_r + para_num, nram_para_load_tw.r,
+ para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_II, nram_out_i + para_num, nram_para_load_tw.i,
+ para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_RI, nram_out_r + para_num, nram_para_load_tw.i,
+ para_num * (large_radix - 1));
+ __bang_mul(CPX_MUL_IR, nram_out_i + para_num, nram_para_load_tw.r,
+ para_num * (large_radix - 1));
+
+ __bang_sub(nram_out_r + para_num, CPX_MUL_RR, CPX_MUL_II,
+ para_num * (large_radix - 1));
+ __bang_add(nram_out_i + para_num, CPX_MUL_RI, CPX_MUL_IR,
+ para_num * (large_radix - 1));
- // update parameter
- radix = small_factors[value_mul++];
- small_section_num = small_factors[value_mul++];
- small_butterfly_num = small_factors[value_mul++];
- small_in_stride = small_factors[value_mul];
+ if (last_stage) {
+ __memcpy(nram_transpose_temp.r +
+ compute_id * large_radix * (1 + (int)last_stage),
+ nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
+ sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
+ para_num - 1);
- if (sec_count == 0 && compute_id == 0 && repeat_id == 1) {
- __memcpy(nram_tw, small_twiddles,
- small_butterfly_num * (radix - 1) * sizeof(DT) * 2,
- SRAM2NRAM);
- }
+ __memcpy(nram_transpose_temp.i +
+ compute_id * large_radix * (1 + (int)last_stage),
+ nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
+ sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
+ para_num - 1);
- if (ld_dft_radix != radix) {
- ld_dft_radix = radix;
- for (int entry = 0;; entry++) {
- if (dft_table[entry].radix == ld_dft_radix) {
- align_K = K_num * ((radix + K_num - 1) / K_num);
- __memcpy(nram_dftmtx,
- &dft_matrix[dft_table[entry].offset * 2],
- sizeof(DT) * 2 * ld_dft_radix * align_K, SRAM2NRAM);
- break;
+ __bang_transpose(nram_para_store_pong.r, nram_transpose_temp.r,
+ para_num * 2, large_radix);
+ } else {
+ if (nram_out_r == nram_para_store_pong.r) {
+ FFT_SWAP_PTR(nram_para_load_in_pong.r, nram_para_store_pong.r)
+ FFT_SWAP_PTR(nram_para_load_in_pong.i, nram_para_store_pong.i)
}
- if (dft_table[entry].radix == -1) {
- break;
- }
+ __bang_transpose(nram_para_store_pong.r, nram_out_r, para_num,
+ large_radix);
+ __bang_transpose(nram_para_store_pong.i, nram_out_i, para_num,
+ large_radix);
}
- }
- switch (radix) {
- // case 2:
- // break;
- // case 3:
- // computeRadix3ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- // case 9:
- // computeRadix9ButterflyLaststage(
- // nram_out_r, nram_out_i, nram_in_r, nram_in_i,
- // nram_scratch, nram_tw, small_section_num,
- // small_butterfly_num, small_in_stride, dir);
- // break;
- default:
- // computeGenericButterflyLaststage(Fout, buffer, twiddles,
- // radix, section_num, butterfly_num, in_stride, 0, dir);
- computeGenericButterflyLaststageMat(
- nram_out_r, nram_out_i, nram_in_r, nram_in_i, nram_scratch,
- nram_dftmtx, nram_tw, small_section_num,
- small_butterfly_num, para_batch, small_in_stride, dir,
- radix);
- break;
- }
- if (last_stage) {
- // [2, para_batch, large_radix] -> [large_radix, para_batch, 2]
- __memcpy(nram_transpose_temp.r + compute_id * large_radix * 2,
- nram_out_r, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_batch - 1);
-
- __memcpy(nram_transpose_temp.i + compute_id * large_radix * 2,
- nram_out_i, sizeof(DT) * large_radix, NRAM2NRAM,
- sizeof(DT) * large_radix * 2, sizeof(DT) * large_radix,
- para_batch - 1);
-
- __bang_transpose(nram_para_store.r, nram_transpose_temp.r,
- para_batch * 2, large_radix);
- } else {
- __bang_transpose(nram_para_store.r, nram_out_r, para_batch,
- large_radix);
- __bang_transpose(nram_para_store.i, nram_out_i, para_batch,
- large_radix);
+ // __bang_transpose(nram_para_store, nram_transpose_temp.r,
+ // max_para_ldst_num * 2, large_radix);
}
-
- // __bang_transpose(nram_para_store, nram_transpose_temp.r,
- // max_para_ldst_num * 2, large_radix);
}
}
- }
+ // __sync();
- __sync();
+ __sync();
+ FFT_SWAP_PTR(nram_para_load_in_ping.r, nram_para_load_in_pong.r)
+ FFT_SWAP_PTR(nram_para_load_in_ping.i, nram_para_load_in_pong.i)
+ FFT_SWAP_PTR(nram_para_store_ping.r, nram_para_store_pong.r)
+ FFT_SWAP_PTR(nram_para_store_ping.i, nram_para_store_pong.i)
+ }
}
- Fin_stride += large_butterfly_num;
- Fout_stride += large_radix * large_butterfly_num;
+ // Fin_stride += large_butterfly_num;
+ // Fout_stride += large_radix * large_butterfly_num;
}
}
template
-__mlu_func__ void computeLargeButterflyLaststageColumn(
+__mlu_func__ void computeLargeButterflyLaststageBatchPingpongC2R(
DT *output, DT *input, const DT *cur_large_twiddles, const DT *_twiddles,
const DT *dft_matrix, int large_section_num, int large_butterfly_num,
int large_in_stride, void *nram_buf, const int *small_factors, int nfft,
- int dir, int para_batch, int nb) {
- computeLargeButterflyOtherstagesColumn(
+ const int t_start, const int t_end, int dir) {
+ computeLargeButterflyOtherstagesBatchPingpongC2R(
output, input, cur_large_twiddles, _twiddles, dft_matrix,
large_section_num, large_butterfly_num, large_in_stride, nram_buf,
- small_factors, nfft, dir, 1, para_batch, nb);
+ small_factors, nfft, t_start, t_end, dir, 1);
}
diff --git a/kernels/fft/irfft/irfft.h b/kernels/fft/irfft/irfft.h
index ccfd02896..b4f305793 100644
--- a/kernels/fft/irfft/irfft.h
+++ b/kernels/fft/irfft/irfft.h
@@ -33,8 +33,8 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
const std::string api);
mluOpStatus_t setIRFFT1dReserveArea_v2(mluOpHandle_t handle,
- mluOpFFTPlan_t fft_plan,
- const std::string api);
+ mluOpFFTPlan_t fft_plan,
+ const std::string api);
mluOpStatus_t execIRFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
const void *input, const float scale_factor,
diff --git a/kernels/fft/irfft/irfft_host.cpp b/kernels/fft/irfft/irfft_host.cpp
index 95f510a9b..ecee58a08 100644
--- a/kernels/fft/irfft/irfft_host.cpp
+++ b/kernels/fft/irfft/irfft_host.cpp
@@ -546,8 +546,8 @@ mluOpStatus_t setIRFFT1dReserveArea(mluOpHandle_t handle,
}
mluOpStatus_t setIRFFT1dReserveArea_v2(mluOpHandle_t handle,
- mluOpFFTPlan_t fft_plan,
- const std::string api) {
+ mluOpFFTPlan_t fft_plan,
+ const std::string api) {
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
VLOG(5) << "Into configure IRFFT1d ReserveArea Addrs (zrg)";