Skip to content

Commit

Permalink
[Feature](mluOpExecFFT): update test for r2c
Browse files Browse the repository at this point in the history
  • Loading branch information
squidruge committed Jun 21, 2024
1 parent 4b24314 commit e1efa5e
Showing 1 changed file with 136 additions and 1 deletion.
137 changes: 136 additions & 1 deletion test/mlu_op_gtest/pb_gtest/src/zoo/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ void FftExecutor::workspaceMalloc() {
GTEST_CHECK(reservespace_addr_ = mlu_runtime_.allocate(reservespace_size_));
workspace_.push_back(reservespace_addr_);
}

// interface_timer_.start();
/* reserve space is the compiling time process before FFT execution */

MLUOP_CHECK(mluOpSetFFTReserveArea(handle_, fft_plan_, reservespace_addr_));
// interface_timer_.stop();
if (workspace_size_ > 0) {
Expand Down Expand Up @@ -103,7 +105,9 @@ void FftExecutor::cpuCompute() {
}

#define TEST_C2C1D_FP32 0
#define TEST_C2C2D_FP32 1
#define TEST_C2C2D_FP32 0
#define TEST_C2C2D_STRIDE_FP32 0
#define TEST_R2C2D_STRIDE_FP32 1

#if TEST_C2C1D_FP32
int n[1];
Expand Down Expand Up @@ -170,6 +174,137 @@ void FftExecutor::cpuCompute() {
fftwf_execute(fft);
fftwf_destroy_plan(fft);

#endif

#if TEST_C2C2D_STRIDE_FP32

fftwf_plan fft;

fftwf_complex *fftw_out = ((fftwf_complex *)cpu_fp32_output_[0]);
fftwf_complex *fftw_in = ((fftwf_complex *)cpu_fp32_input_[0]);

int n[2];
n[0] = parser_->getProtoNode()->fft_param().n()[0];
n[1] = parser_->getProtoNode()->fft_param().n()[1];
int howmany = count / (n[0] * n[1]);
int *inembed = n;
int *onembed = n;
// int istride = howmany;
// int ostride = howmany;
// int idist = 1;
// int odist = 1;
int istride = 1;
int ostride = 1;
int idist = (n[0] * n[1]);
int odist = (n[0] * n[1]);
// input[b * idist + (y * inembed[1] + x) * istride]
// output[b * odist + (y * onembed[1] + x) * ostride]

// for(int i = 0; i <6; i ++) {
// for(int j = 0; j <2; j ++) {

// printf("(%f, %f) ", ((float *)fftw_in)[(i*2+j)*2],((float
// *)fftw_in)[(i*2+j)*2+1]);
// }
// printf("\n");
// }

fft = fftwf_plan_many_dft(2, n, howmany, fftw_in, inembed, istride, idist,
fftw_out, onembed, ostride, odist, FFTW_FORWARD,
FFTW_ESTIMATE); // Setup fftw plan for fft
printf("fftw:\n");
printf("howmany: %d\n", howmany);
printf("n[0]: %d\n", n[0]);
printf("n[1]: %d\n", n[1]);

fftwf_execute(fft);

// for(int i = 0; i <6; i ++) {
// for(int j = 0; j <2; j ++) {

// printf("(%f, %f) ", ((float *)fftw_out)[(i*2+j)*2],((float
// *)fftw_out)[(i*2+j)*2+1]);
// }
// printf("\n");
// }

fftwf_destroy_plan(fft);

#endif

#if TEST_R2C2D_STRIDE_FP32

fftwf_plan fft;

fftwf_complex *fftw_out = ((fftwf_complex *)cpu_fp32_output_[0]);
float *fftw_in = ((float *)cpu_fp32_input_[0]);

int n[2];
n[0] = parser_->getProtoNode()->fft_param().n()[0];
n[1] = parser_->getProtoNode()->fft_param().n()[1];
int howmany = count / (n[0] * n[1]);
int inembed[2] = {n[0], n[1]};
int onembed[2] = {n[0], n[1] / 2 + 1};
// onembed[1] = n[1]/2 +1;
// int istride = howmany;
// int ostride = howmany;
// int idist = 1;
// int odist = 1;
int istride = 1;
int ostride = 1;
int idist = (n[0] * n[1]);
int odist = (n[0] * (n[1] / 2 + 1));
// input[b * idist + (y * inembed[1] + x) * istride]
// output[b * odist + (y * onembed[1] + x) * ostride]

// for(int i = 0; i <6; i ++) {
// for(int j = 0; j <2; j ++) {

// printf("(%f, %f) ", ((float *)fftw_in)[(i*2+j)*2],((float
// *)fftw_in)[(i*2+j)*2+1]);
// }
// printf("\n");
// }

fft = fftwf_plan_many_dft_r2c(2, n, howmany, fftw_in, inembed, istride, idist,
fftw_out, onembed, ostride, odist,
FFTW_ESTIMATE); // Setup fftw plan for fft
// printf("fftw:\n");
// printf("howmany: %d\n", howmany);
// printf("n[0]: %d\n", n[0]);
// printf("n[1]: %d\n", n[1]);

fftwf_execute(fft);
// fftwf_execute_dft_r2c(fft, fftw_in, fftw_out);
// for(int i = 0; i <6; i ++) {
// for(int j = 0; j <2; j ++) {

// printf("(%f, %f) ", ((float *)fftw_out)[(i*2+j)*2],((float
// *)fftw_out)[(i*2+j)*2+1]);
// }
// printf("\n");
// }

// for (int i = 0; i < n[0]; i++) {
// int ld = (n[1]/2+1)*howmany;
// for (int j = 0; j < ld; j++) {
// printf("[%d][%d]: (%f, %f) ",i, j, ((float*)fftw_out)[(i * (ld) + j)
// *2], ((float*)fftw_out)[((i * (ld) + j) *2 + 1)]);
// }
// printf("\n");
// }
// for (int i = 0; i < howmany; i++) {
// // int ld = (n[1]/2+1)*howmany;
// int ld = (n[1]/2+1)*n[0];
// for (int j = 0; j < ld; j++) {
// printf("[%d][%d]: (%f, %f) ",i, j, ((float*)fftw_out)[(i * (ld) + j)
// *2], ((float*)fftw_out)[((i * (ld) + j) *2 + 1)]);
// }
// printf("\n");
// }

fftwf_destroy_plan(fft);

#endif
}

Expand Down

0 comments on commit e1efa5e

Please sign in to comment.