Skip to content

Commit

Permalink
[Feature](mluOpExecFFT): c2r other with bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
squidruge committed Jun 28, 2024
1 parent 6b0b6ee commit 08df5ca
Show file tree
Hide file tree
Showing 7 changed files with 524 additions and 2,947 deletions.
2 changes: 1 addition & 1 deletion kernels/fft/c2c_fft/c2c_fft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2112,7 +2112,7 @@ mluOpStatus_t execIRFFT1d_v2(mluOpHandle_t handle,
// fft_plan->mlu_addrs.input);
// INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
// }

status = execFFTc2r1d(handle, fft_plan, scale_factor);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

Expand Down
65 changes: 43 additions & 22 deletions kernels/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,13 +603,13 @@ mluOpStatus_t MLUOP_WIN_API fftFactor(const int _n, int *facbuf,
}
break;

case (32 * 17):
if (n % 32 == 0) {
r = 32;
} else if ((n % 17) == 0) {
r = 17;
}
break;
case (32 * 17):
if (n % 32 == 0) {
r = 32;
} else if ((n % 17) == 0) {
r = 17;
}
break;

case 600:
if (n % 30 == 0) {
Expand Down Expand Up @@ -726,19 +726,19 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
if (is_row_major) {
switch (_n) {
case (32 * 17):
r = 32 * 17;
r = 32 * 17;
break;

case (200):
r = 200;
r = 200;
break;

case (600):
r = 600;
r = 600;
break;

case (256):
r = 256;
r = 256;
break;

case 1024:
Expand Down Expand Up @@ -868,7 +868,29 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
}
}
n /= r;
in_stride = _n / r;

switch (fft_plan->fft_type) {
// r2c
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
if ((n * r) != _n) {
in_stride = (((out_stride / 2) + 1) * section_num) / r;

} else {
in_stride = _n / r;
}
}

case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
in_stride = _n / r;
}
default:
break;
}

section_num = n;
stage_num++;

Expand Down Expand Up @@ -991,7 +1013,7 @@ mluOpStatus_t MLUOP_WIN_API calParallelNumLowBound(mluOpFFTPlan_t fft_plan,
switch (fft_plan->fft_type) {
// r2c
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
Expand Down Expand Up @@ -2226,13 +2248,12 @@ mluOpStatus_t MLUOP_WIN_API mluOpSetFFTReserveArea(mluOpHandle_t handle,
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
if (fft_plan->rank == 1) {
if(fft_plan->prime) {
status = setFFT1dReserveArea(handle, fft_plan, api);
}else{
status = setFFT1dReserveArea_v2(handle, fft_plan, api);
if (fft_plan->prime) {
status = setFFT1dReserveArea(handle, fft_plan, api);
} else {
status = setFFT1dReserveArea_v2(handle, fft_plan, api);
}


} else if (fft_plan->rank == 2) {
// status = setFFT1dReserveArea(handle, fft_plan, api);
status = setFFT2dReserveArea(handle, fft_plan, api);
Expand All @@ -2244,10 +2265,10 @@ status = setFFT1dReserveArea_v2(handle, fft_plan, api);
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
if (fft_plan->rank == 1) {
if(fft_plan->prime) {
status = setIRFFT1dReserveArea(handle, fft_plan, api);
}else{
status = setIRFFT1dReserveArea_v2(handle, fft_plan, api);
if (fft_plan->prime) {
status = setIRFFT1dReserveArea(handle, fft_plan, api);
} else {
status = setIRFFT1dReserveArea_v2(handle, fft_plan, api);
}
} else if (fft_plan->rank == 2) {
// status = setFFT1dReserveArea(handle, fft_plan, api);
Expand Down
41 changes: 20 additions & 21 deletions kernels/fft/fft_optm_device/fft_c2c_stockham_nram.h
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ __mlu_func__ void computeLargeButterflyFirststageBatchPingpong(
small_butterfly_num, para_num, small_in_stride, dir,
radix);
} else {
// TODO(zrg):check
// TODO(zrg): check
computeGenericButterflyLaststageMat(
nram_para_store_pong,
nram_para_store_pong + max_para_ldst_num * large_radix,
Expand Down Expand Up @@ -1117,25 +1117,24 @@ __mlu_func__ void computeLargeButterflyOtherstages(
int para_load_num = (max_para_ldst_num > (large_butterfly_num - i))
? (large_butterfly_num - i)
: max_para_ldst_num;
if (para_load_num != 1) {
__memcpy_async(nram_para_load_in.r, input + Fin_stride + i,
sizeof(DT) * para_load_num, GDRAM2NRAM,
sizeof(DT) * para_load_num,
large_in_stride * sizeof(DT), large_radix - 1);
__memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i,
sizeof(DT) * para_load_num, GDRAM2NRAM,
sizeof(DT) * para_load_num,
large_in_stride * sizeof(DT), large_radix - 1);
__memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
sizeof(DT) * para_load_num, SRAM2NRAM,
sizeof(DT) * para_load_num,
large_out_stride * sizeof(DT), large_radix - 2);
__memcpy_async(
nram_para_load_tw.i,
cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i,
sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num,
large_out_stride * sizeof(DT), large_radix - 2);
}

__memcpy_async(nram_para_load_in.r, input + Fin_stride + i,
sizeof(DT) * para_load_num, GDRAM2NRAM,
sizeof(DT) * para_load_num, large_in_stride * sizeof(DT),
large_radix - 1);
__memcpy_async(nram_para_load_in.i, input + nfft + Fin_stride + i,
sizeof(DT) * para_load_num, GDRAM2NRAM,
sizeof(DT) * para_load_num, large_in_stride * sizeof(DT),
large_radix - 1);
__memcpy_async(nram_para_load_tw.r, cur_large_twiddles + i,
sizeof(DT) * para_load_num, SRAM2NRAM,
sizeof(DT) * para_load_num,
large_out_stride * sizeof(DT), large_radix - 2);
__memcpy_async(
nram_para_load_tw.i,
cur_large_twiddles + large_butterfly_num * (large_radix - 1) + i,
sizeof(DT) * para_load_num, SRAM2NRAM, sizeof(DT) * para_load_num,
large_out_stride * sizeof(DT), large_radix - 2);
}

// pipeline: store-stage
Expand Down Expand Up @@ -1763,7 +1762,7 @@ __mlu_func__ void computeLargeButterflyOtherstagesBatchPingpong(
// pipeline: load-stage

if (repeat_id < repeat_num) {
if (para_num != 1) {
if (1) {
__memcpy_async(
nram_para_load_in_ping.r,
input_batch + sec_count * large_butterfly_num + butterfly_id,
Expand Down
105 changes: 44 additions & 61 deletions kernels/fft/fft_optm_device/fft_c2r_stockham_gdram.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
taskId, repeat_num, remain_num, t_len, t_start, t_end);

int radix, section_num, butterfly_num, in_stride, stage_count, value_mul,
small_factors_offset;
out_stride, small_factors_offset;

int *small_factors;
int last_stage;
Expand All @@ -78,25 +78,30 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
const int nfft = factors[1];

// first stage
radix = factors[5 + 0];
section_num = factors[5 + 1];
in_stride = factors[5 + 3];
small_factors_offset = factors[5 + 4];
radix = factors[5 * _stage_count + 0];
section_num = factors[5 * _stage_count + 1];
butterfly_num = factors[5 * _stage_count + 3];
out_stride = factors[5 * _stage_count + 3];
in_stride = butterfly_num;
small_factors_offset = factors[5 * _stage_count + 4];

for (int loop_stage = 2; loop_stage < _stage_count; loop_stage++) {
cur_radix = factors[5 * loop_stage];
butterfly_num_stage = factors[5 * loop_stage + 2];
twiddles += (cur_radix - 1) * (butterfly_num_stage / 2);
}

// small_factors = factors + small_factors_offset;

stage_count = _stage_count;
last_stage = (stage_count == 1);


stage_count = 1;
last_stage = (_stage_count == 1);

if (__is_mpu()) {
__memcpy_async(sram_factors, factors, FFT_MAXFACTORS * sizeof(int),
GDRAM2SRAM);
if (twiddles_size) {
__memcpy_async(sram_twiddles, twiddles, twiddles_size * sizeof(DT),
GDRAM2SRAM);

}

const dft_table_entry *dft_table_gdram =
Expand Down Expand Up @@ -166,19 +171,19 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
}

// sram_large_tw
stage_count--;
if (stage_count == 0) {
// continue;

if (stage_count == _stage_count) {
return;
}
stage_count++;

// if (__is_mpu()) {
// return;
// }

// sram_large_tw
value_mul = 10;
for (; stage_count > 1; stage_count--) {
for (; stage_count < _stage_count; stage_count++) {
// fft_swap_ptr<DT>(&buffer, &output);
// FFT_SWAP_PTR(buffer, output);
FFT_SWAP_PTR(buffer, output);
Expand All @@ -198,43 +203,32 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
FFT_SWAP_PTR(odd_extra_buffer, output);
}

// value_mul = (_stage_count - stage_count + 1) * 5;
value_mul = (_stage_count - stage_count + 1) * 5;

// update parameter
radix = factors[value_mul++];
section_num = factors[value_mul++];
butterfly_num = factors[value_mul++];
in_stride = factors[value_mul++];
small_factors_offset = factors[value_mul++];

radix = factors[value_mul];
section_num = factors[value_mul + 1];
butterfly_num = factors[value_mul + 2];
in_stride = butterfly_num;
out_stride = factors[value_mul + 3];
small_factors_offset = factors[value_mul + 4];
twiddles -= (radix - 1) * (butterfly_num / 2 + 1);

small_factors = factors + small_factors_offset;

if (__is_ipu()) {
// MLULOG("other stage radix: %d \n", radix);

if (repeat_num > 0 || taskId < remain_num) {
if (6000 / radix > repeat_num && 0) {
for (int t = t_start; t < t_end; t++) {
DT *output_batch = output + t * (nfft << 1);
DT *buffer_batch = buffer + t * (nfft << 1);

computeLargeButterflyOtherstages<DT>(
output_batch, buffer_batch, (DT *)twiddles, _twiddles,
sram_dftmtx, section_num, butterfly_num, in_stride,
(void *)nram_buf, small_factors, nfft, direction, 0);

// __sync();
}
} else {
computeLargeButterflyOtherstagesBatchPingpong<DT>(
output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx,
section_num, butterfly_num, in_stride, (void *)nram_buf,
small_factors, nfft, t_start, t_end, direction, 0);
}
computeLargeButterflyOtherstagesBatchPingpongC2R<DT>(
output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num,
butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft,
t_start, t_end, direction, 0);
}
}
twiddles += butterfly_num * (radix - 1) * 2; // 2 for complex
} // for (stage_count)
twiddles += (butterfly_num / 2 - 1) * (radix - 1) * 2; // 2 for complex
} // for (stage_count)

// __mlu_shared__ DT *sram_tw[2048]; // radix-1024
// __mlu_shared__ DT *sram_tw[64]; // radix-1024
Expand All @@ -248,32 +242,21 @@ __mlu_func__ void computeMutiStageOnchipC2R(DT *input, DT *output, int *factors,
FFT_SWAP_PTR(buffer, output);

// update parameter
radix = factors[value_mul++];
section_num = factors[value_mul++];
butterfly_num = factors[value_mul++];
in_stride = factors[value_mul++];
small_factors_offset = factors[value_mul];
radix = factors[5];
section_num = factors[6];
butterfly_num = factors[7];
in_stride = butterfly_num;
out_stride = factors[8];
small_factors_offset = factors[9];

small_factors = factors + small_factors_offset;

if (__is_ipu()) {
if (repeat_num > 0 || taskId < remain_num) {
if (0) {
for (int t = t_start; t < t_end; t++) {
DT *output_batch = output + t * (nfft << 1);
DT *buffer_batch = buffer + t * (nfft << 1);

computeLargeButterflyLaststage<DT>(
output_batch, buffer_batch, (DT *)twiddles, _twiddles,
sram_dftmtx, section_num, butterfly_num, in_stride,
(void *)nram_buf, small_factors, nfft, direction);
}
} else {
computeLargeButterflyLaststageBatchPingpong(
output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx,
section_num, butterfly_num, in_stride, (void *)nram_buf,
small_factors, nfft, t_start, t_end, direction);
}
computeLargeButterflyLaststageBatchPingpongC2R(
output, buffer, (DT *)twiddles, _twiddles, sram_dftmtx, section_num,
butterfly_num, in_stride, (void *)nram_buf, small_factors, nfft,
t_start, t_end, direction);
}
}
}
Expand Down
Loading

0 comments on commit 08df5ca

Please sign in to comment.