Skip to content

Commit

Permalink
[Feature](mluOpExecFFT): update c2c 1d stride
Browse files Browse the repository at this point in the history
  • Loading branch information
squidruge committed Jun 24, 2024
1 parent d5c541f commit 2ef4dfe
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 179 deletions.
155 changes: 48 additions & 107 deletions kernels/fft/c2c_fft/c2c_fft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,47 +974,40 @@ static void configureFFT1dWorkspaceAddrs_v2(mluOpHandle_t handle,
mluOpFFTPlan_t fft_plan,
void *input, void *workspace,
void *output) {
VLOG(5) << "Into configure FFT1d Workspace Addrs (zrg)";
VLOG(5) << "Into configure FFT1d Workspace Addrs";
const std::string make_plan_api = "[configureFFT1dWorkspaceAddrs_v2]";
size_t workspace_size = 0;
size_t reservespace_size = 0;

size_t CPX_TYPE_SIZE = 0;
// c2c
mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
mluOpDataType_t out_c_dtype = fft_plan->output_dtype;

switch (fft_plan->fft_type) {
case CNFFT_COMPLEX_HALF2COMPLEX_HALF: {
CPX_TYPE_SIZE = 2 * 2;
} break;
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
CPX_TYPE_SIZE = 4 * 2;
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid c2c 1d fft type.";
return;
// return MLUOP_STATUS_BAD_PARAM;
}
}
size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);
size_t out_c_dtype_size = mluOpDataTypeBytes(out_c_dtype);

int batch = fft_plan->batch;
int nfft = fft_plan->n[0];

size_t buffer_size = batch * sizeof(CPX_TYPE_SIZE) * nfft;
size_t twiddles_size = sizeof(CPX_TYPE_SIZE) * nfft * 2;
size_t buffer_size = batch * in_c_dtype * nfft;

// mlu_addrs
// fft_plan->mlu_addrs.input = workspace;
// fft_plan->mlu_addrs.output = fft_plan->mlu_addrs.input + buffer_size;
// fft_plan->mlu_addrs.buffer = fft_plan->mlu_addrs.output + buffer_size;
size_t offset = 0;
fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + offset;
offset += buffer_size;

fft_plan->mlu_addrs.input = input;
fft_plan->mlu_addrs.output = output;
// fft_plan->mlu_addrs.buffer_in = (uint8_t *)workspace;
// fft_plan->mlu_addrs.buffer_out = (uint8_t *)workspace + buffer_size;
// fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace + 2 * buffer_size;

fft_plan->mlu_addrs.buffer_buf = (uint8_t *)workspace;
if (fft_plan->is_input_contiguous || fft_plan->is_batch_contiguous) {
fft_plan->mlu_addrs.input = input;
} else {
fft_plan->mlu_addrs.input = (uint8_t *)workspace + offset;
offset += batch * in_c_dtype * nfft;
}

// fft_plan->mlu_addrs.twiddles = mlu_runtime_.allocate(reservespace_size);
if (fft_plan->is_output_contiguous || fft_plan->is_batch_contiguous) {
fft_plan->mlu_addrs.output = output;
} else {
fft_plan->mlu_addrs.output = (uint8_t *)workspace + offset;
offset += batch * out_c_dtype * nfft;
}
}

// zrg
Expand Down Expand Up @@ -1198,7 +1191,8 @@ static void configureFFT1dWorkspaceAddrs(mluOpHandle_t handle,
// output : in input_contiguous_addr
static mluOpStatus_t makeFFT1dContiguousInput(mluOpHandle_t handle,
mluOpFFTPlan_t fft_plan,
const void *input) {
const void *input,
void *input_contiguous) {
std::string api = "[mluOpExecFFT]";
VLOG(5) << "into makeFFT1dContiguousInput";
auto status = MLUOP_STATUS_SUCCESS;
Expand All @@ -1216,8 +1210,7 @@ static mluOpStatus_t makeFFT1dContiguousInput(mluOpHandle_t handle,
dims, strides);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = mluOpContiguous(handle, input_desc, input,
fft_plan->matmul_addrs.input_contiguous_addr);
status = mluOpContiguous(handle, input_desc, input, input_contiguous);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = mluOpDestroyTensorDescriptor(input_desc);
Expand Down Expand Up @@ -1697,7 +1690,8 @@ static mluOpStatus_t transposeFFT1dOutput(mluOpHandle_t handle,

static mluOpStatus_t makeFFT1dContiguousOutput(mluOpHandle_t handle,
mluOpFFTPlan_t fft_plan,
void *output) {
void *output,
void *output_contiguous) {
std::string api = "[mluOpExecFFT]";
VLOG(5) << "into makeFFT1dContiguousOutput";
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
Expand All @@ -1723,7 +1717,8 @@ static mluOpStatus_t makeFFT1dContiguousOutput(mluOpHandle_t handle,
out_c_dtype, out_dim_num, dims, strides);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

void *copy_src_addr = fft_plan->matmul_addrs.output_contiguous_addr;
// void *copy_src_addr = fft_plan->matmul_addrs.output_contiguous_addr;
void *copy_src_addr = output_contiguous;
DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
cnnl_handle); // convert to cnnl_handle
// convert to cnnl_tensor_descriptor
Expand All @@ -1742,83 +1737,17 @@ static mluOpStatus_t makeFFT1dContiguousOutput(mluOpHandle_t handle,
return status;
}

// only for CNFFT_FUNC_COOLEY_TUKEY and CNFFT_FUNC_STOCKHAM
// input : matmul real result in matmul_re_mul_re_addr
// matmul imag result in matmul_re_mul_im_addr
// workspace: internal_workspace_addr
// output : output real result in output_contiguous_addr
mluOpStatus_t execFFTc2c1d(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
const float scale_factor, int direction) {
std::string api = "[execFFTc2c1d]";

VLOG(5) << "launch c2c fft1d";
// TODO(niyuming) launch merge kernel
// int core_num = handle->core_num_per_cluster;
mluOpStatus_t status = MLUOP_STATUS_SUCCESS;
// cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
// int task_type = mluop::runtime::getJobLimitCapability(handle);
// int task_num = 1;

// switch (task_type) {
// default:
// task_num = core_num;
// break;
// case (int)CNRT_FUNC_TYPE_UNION2:
// task_num = core_num * 2;
// break;
// case (int)CNRT_FUNC_TYPE_UNION4:
// task_num = core_num * 4;
// break;
// case (int)CNRT_FUNC_TYPE_UNION8:
// task_num = core_num * 8;
// break;
// case (int)CNRT_FUNC_TYPE_UNION16:
// task_num = core_num * 16;
// break;
// }
// int task_num = core_num * 4;
// unsigned int dimx = task_num;
// cnrtDim3_t k_dim = {dimx, 1, 1};
// cnrtFunctionType_t k_type = (cnrtFunctionType_t)dimx;
// kernelFFT1dButterflyRow(k_dim, k_type, handle->queue, fft_plan, direction,
// FFT_IFFT);

// int core_num = handle->core_num_per_cluster;
// cnrtFunctionType_t k_type = CNRT_FUNC_TYPE_UNION1;
// int task_type = mluop::runtime::getJobLimitCapability(handle);
// int task_num = 1;

// switch (task_type) {
// default:
// task_num = core_num;
// break;
// case (int)CNRT_FUNC_TYPE_UNION2:
// task_num = core_num * 2;
// break;
// case (int)CNRT_FUNC_TYPE_UNION4:
// task_num = core_num * 4;
// break;
// case (int)CNRT_FUNC_TYPE_UNION8:
// task_num = core_num * 8;
// break;
// case (int)CNRT_FUNC_TYPE_UNION16:
// task_num = core_num * 16;
// break;
// }

// unsigned int dimx = task_num;
// cnrtDim3_t k_dim = {dimx, 1, 1};
// k_type = (cnrtFunctionType_t)dimx;
// // std::cout<<"\n\nkernelFFT1dButterflyRow\n\n"<<std::endl;
// // std::cout<<"\n\ntask_num\n\n"<<task_num<<std::endl;
// kernelFFT1dButterflyRow(k_dim, k_type, handle->queue, fft_plan, direction,
// FFT_IFFT);

// VLOG(5) << "launch mrege rfft1d output";

cnrtDim3_t k_dim;
cnrtFunctionType_t k_type;
policyFunc(handle, &k_dim, &k_type);
if (fft_plan->istride == 1) {
if (!fft_plan->is_batch_contiguous) {
kernelFFT1dButterflyRow(k_dim, k_type, handle->queue, fft_plan, direction,
FFT_IFFT);
} else {
Expand Down Expand Up @@ -1957,7 +1886,8 @@ mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
configureFFT1dMatmulWorkspaceAddrs(handle, fft_plan, (void *)input,
workspace, output);

status = makeFFT1dContiguousInput(handle, fft_plan, input);
status = makeFFT1dContiguousInput(
handle, fft_plan, input, fft_plan->matmul_addrs.input_contiguous_addr);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = padFFT1dContiguousInput(handle, fft_plan);
Expand All @@ -1979,16 +1909,27 @@ mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
status = transposeFFT1dOutput(handle, fft_plan);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

status = makeFFT1dContiguousOutput(handle, fft_plan, output);
status = makeFFT1dContiguousOutput(
handle, fft_plan, output,
fft_plan->matmul_addrs.output_contiguous_addr);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
} else {
// direction: 0(forward) 1(backward)

configureFFT1dWorkspaceAddrs_v2(handle, fft_plan, (void *)input, workspace,
output);
status = execFFTc2c1d(handle, fft_plan, scale_factor, direction);
if (!fft_plan->is_input_contiguous && !fft_plan->is_batch_contiguous) {
status = makeFFT1dContiguousInput(handle, fft_plan, input,
fft_plan->mlu_addrs.input);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}

status = execFFTc2c1d(handle, fft_plan, scale_factor, direction);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

if (!fft_plan->is_output_contiguous && !fft_plan->is_batch_contiguous) {
status = makeFFT1dContiguousOutput(handle, fft_plan, output,
fft_plan->mlu_addrs.output);
INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
}
}

return status;
Expand Down
55 changes: 18 additions & 37 deletions kernels/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,8 @@ mluOpStatus_t MLUOP_WIN_API fftFactor(const int _n, int *facbuf,
}

mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
const int _n, int *facbuf) {
const int _n, int *facbuf,
int is_row_major) {
int n = _n;
// if ((facbuf == NULL) || (n <= 0))
// {
Expand All @@ -683,10 +684,8 @@ mluOpStatus_t MLUOP_WIN_API fftTwoStepFactor(mluOpFFTPlan_t fft_plan,
int large_radix = 1;
int small_factors_offset = 22 * 5;

int row_major = (fft_plan->istride == 1);

while (n > 1) {
if (row_major) {
if (is_row_major) {
switch (_n) {
case (32 * 17):
if (n % 32 == 0) {
Expand Down Expand Up @@ -1234,45 +1233,24 @@ mluOpAllocateC2C1D(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
size_t workspace_size = 0;
size_t reservespace_size = 0;

size_t CPX_TYPE_SIZE = 0;

switch (fft_plan->fft_type) {
case CNFFT_COMPLEX_HALF2COMPLEX_HALF: {
CPX_TYPE_SIZE = 2 * 2;
} break;
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
CPX_TYPE_SIZE = 4 * 2;
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid c2c 1d fft type.";
return MLUOP_STATUS_BAD_PARAM;
}
}
mluOpDataType_t in_c_dtype = fft_plan->input_dtype;
size_t in_c_dtype_size = mluOpDataTypeBytes(in_c_dtype);

int batch = fft_plan->batch;

size_t buffer_size = batch * sizeof(CPX_TYPE_SIZE) * nfft;

workspace_size = buffer_size * 3;
size_t buffer_size = batch * in_c_dtype_size * nfft;

// reservespace_size = batch * sizeof(mluOpFFTPlan_t) + sizeof(int) *
// (FFT_MAXFACTORS) /* factors */
// + sizeof(CPX_TYPE_SIZE) * nfft * 2 /* twiddles
// */
// );
workspace_size = buffer_size;
workspace_size += (fft_plan->is_input_contiguous) ? 0 : buffer_size;
workspace_size += (fft_plan->is_output_contiguous) ? 0 : buffer_size;

size_t twiddles_size = sizeof(CPX_TYPE_SIZE) * nfft * 2;
size_t twiddles_size = in_c_dtype_size * nfft * 2;
reservespace_size = sizeof(int) * (FFT_MAXFACTORS) /* factors */
+ twiddles_size * 2 + DFT_TABLE_SIZE * 2; /* twiddles */

fft_plan->workspace_size = workspace_size;
fft_plan->reservespace_size = reservespace_size;

// std::cout << "workspace_size: " << workspace_size << "bytes" << std::endl;
// std::cout << "reservespace_size: " << reservespace_size << "bytes" <<
// std::endl; CNAME(openfft_generate_twiddles)(st->twiddles, st->factors,
// nfft, st->dir);

return MLUOP_STATUS_SUCCESS;
}

Expand Down Expand Up @@ -1410,17 +1388,20 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C1D(
const int rank, const int *n) {
// reservespace_addr_ = mlu_runtime_.allocate(reservespace_size_)
// st = CNAME(openfft_allocate_c2c_plan_1d)(nfft, fin, fout, dir);
fft_plan->is_batch_contiguous =
(fft_plan->idist == 1 && fft_plan->odist == 1);

// std::cout<< "mluOpAllocateC2C1D"<<std::endl;
mluOpAllocateC2C1D(handle, fft_plan, input_desc, output_desc, n[0]);
// std::cout<< "mluOpAllocateC2C1D"<<std::endl;
fftTwoStepFactor(fft_plan, n[0], fft_plan->factors);
int is_row_major = !fft_plan->is_batch_contiguous;
fftTwoStepFactor(fft_plan, n[0], fft_plan->factors, is_row_major);

switch (fft_plan->fft_type) {
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT:
if (fft_plan->istride == 1) {
if (!fft_plan->is_batch_contiguous) {
fftGenerateTwiddles<float>(fft_plan->twiddles, fft_plan->twiddles_end,
fft_plan->factors, n[0], FFT_FORWARD);
fftGenerateTwiddles<float>(fft_plan->twiddles_inv,
Expand All @@ -1444,7 +1425,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C1D(
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:

if (fft_plan->istride == 1) {
if (!fft_plan->is_batch_contiguous) {
fftGenerateTwiddles<float>(fft_plan->twiddles, fft_plan->twiddles_end,
fft_plan->factors, n[0], FFT_FORWARD);
fftGenerateTwiddles<float>(fft_plan->twiddles_inv,
Expand Down Expand Up @@ -1531,8 +1512,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpMakeFFTPlanC2C2D(
}

if (fft_plan->fft_strategy == CNFFT_FUNC_TWO_LEVEL_STOCKHAM) {
fftTwoStepFactor(fft_plan, n[1], fft_plan->factors);
fftTwoStepFactor(fft_plan, n[0], fft_plan->factors_2d);
fftTwoStepFactor(fft_plan, n[1], fft_plan->factors, 1);
fftTwoStepFactor(fft_plan, n[0], fft_plan->factors_2d, 0);

switch (fft_plan->fft_type) {
case CNFFT_FLOAT2COMPLEX_FLOAT:
Expand Down
1 change: 1 addition & 0 deletions kernels/fft/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ struct mluOpFFTStruct {
int prime;
bool is_input_contiguous;
bool is_output_contiguous;
bool is_batch_contiguous;
size_t reservespace_size;
size_t workspace_size;
FFTType fft_type; // types of fft
Expand Down
Loading

0 comments on commit 2ef4dfe

Please sign in to comment.