Skip to content

Commit

Permalink
[Feature](mluOpExecFFT): add c2r stride
Browse files Browse the repository at this point in the history
  • Loading branch information
squidruge committed Jun 22, 2024
1 parent 1c081e6 commit a3a5ebe
Show file tree
Hide file tree
Showing 4 changed files with 471 additions and 10 deletions.
225 changes: 224 additions & 1 deletion kernels/fft/c2c_fft/c2c_fft_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,12 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,

switch (fft_plan->fft_type) {
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF: {
CPX_TYPE_SIZE = 2 * 2;
} break;
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
CPX_TYPE_SIZE = 4 * 2;
}; break;
Expand Down Expand Up @@ -735,6 +737,25 @@ mluOpStatus_t setFFT2dReserveArea(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
fft_plan->dft_matrix_2d,
CPX_TYPE_SIZE * _n0 * _n0, cnrtMemcpyHostToDev));
}; break;
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
// C2R
size_t reservespace_offset = 0;
fft_plan->mlu_addrs.dft_matrix =
(uint8_t *)fft_plan->reservespace_addr + reservespace_offset;
reservespace_offset += CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1;

fft_plan->mlu_addrs.dft_matrix_2d =
(uint8_t *)fft_plan->reservespace_addr + reservespace_offset;
reservespace_offset += CPX_TYPE_SIZE * _n0 * _n0;

CNRT_CHECK(cnrtMemcpy(
fft_plan->mlu_addrs.dft_matrix, fft_plan->dft_matrix,
CPX_TYPE_SIZE * (_n1 / 2 + 1) * _n1, cnrtMemcpyHostToDev));
CNRT_CHECK(cnrtMemcpy(fft_plan->mlu_addrs.dft_matrix_2d,
fft_plan->dft_matrix_2d,
CPX_TYPE_SIZE * _n0 * _n0, cnrtMemcpyHostToDev));
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
status = MLUOP_STATUS_NOT_SUPPORTED;
Expand Down Expand Up @@ -934,15 +955,17 @@ static void configureFFT2dWorkspaceAddrs(mluOpHandle_t handle,

switch (fft_plan->fft_type) {
case CNFFT_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_HALF2COMPLEX_HALF: {
CPX_TYPE_SIZE = 2 * 2;
} break;
case CNFFT_FLOAT2COMPLEX_FLOAT:
case CNFFT_COMPLEX_FLOAT2FLOAT:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
CPX_TYPE_SIZE = 4 * 2;
}; break;
default: {
LOG(ERROR) << make_plan_api << ": invalid c2c 2d fft type.";
LOG(ERROR) << make_plan_api << ": invalid 2d fft type.";
return;
}
}
Expand Down Expand Up @@ -1786,6 +1809,28 @@ mluOpStatus_t execFFTr2c2d(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
return status;
}

// Executes a 2D C2R (complex-to-real) FFT according to the plan's strategy.
// Only CNFFT_FUNC_MANY_DIST1_2D is implemented here: a column-direction
// matmul stage followed by a row-direction matmul stage (see
// computeFFT2dMatMulColumnC2R / computeFFT2dMatMulRowC2R).
// NOTE(review): any other strategy falls through and returns
// MLUOP_STATUS_SUCCESS without doing work — confirm the planner never
// selects an unhandled strategy for C2R, or return NOT_SUPPORTED here.
mluOpStatus_t execFFTc2r2d(mluOpHandle_t handle, mluOpFFTPlan_t fft_plan,
                           const float scale_factor, int direction) {
  std::string api = "[execFFTc2r2d]";

  VLOG(5) << "launch c2r fft2d";
  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;

  // TODO(squidruge): CNFFT_FUNC_TWO_LEVEL_STOCKHAM is not implemented for
  // C2R yet (the original placeholder branch was empty and has been removed).

  if (fft_plan->fft_strategy == CNFFT_FUNC_MANY_DIST1_2D) {
    // Stage 1: transform along the first (column, n0) axis.
    status =
        computeFFT2dMatMulColumnC2R(handle, fft_plan, scale_factor, direction);
    INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

    // Stage 2: transform along the second (row, n1) axis, producing the
    // real-valued output.
    status =
        computeFFT2dMatMulRowC2R(handle, fft_plan, scale_factor, direction);
    INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  }
  return status;
}

mluOpStatus_t execFFT1d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
const void *input, const float scale_factor,
void *workspace, void *output, int direction) {
Expand Down Expand Up @@ -1850,6 +1895,11 @@ mluOpStatus_t execFFT2d(mluOpHandle_t handle, const mluOpFFTPlan_t fft_plan,
// R2C
status = execFFTr2c2d(handle, fft_plan, scale_factor, direction);
} break;
case CNFFT_COMPLEX_HALF2HALF:
case CNFFT_COMPLEX_FLOAT2FLOAT: {
        // C2R
status = execFFTc2r2d(handle, fft_plan, scale_factor, direction);
} break;
case CNFFT_COMPLEX_HALF2COMPLEX_HALF:
case CNFFT_COMPLEX_FLOAT2COMPLEX_FLOAT: {
// C2C
Expand Down Expand Up @@ -2392,3 +2442,176 @@ mluOpStatus_t computeFFT2dMatMulRowR2C(mluOpHandle_t handle,

return status;
}

// Column-direction (n0 axis) stage of the 2D C2R FFT, implemented as a dense
// matmul followed by a conjugate-symmetry merge kernel.
// Input layout (per the original author's note): [2][n0][2][n1][batch].
// Reads fft_plan->mlu_addrs.input, writes buffer_out, then merges
// buffer_in/buffer_out via kernelFFTBatchConjMergeC2R for the row stage.
// NOTE(review): scale_factor and direction are currently unused in this
// stage — presumably the DFT matrix already encodes direction/scaling;
// confirm against the plan-creation code.
mluOpStatus_t computeFFT2dMatMulColumnC2R(mluOpHandle_t handle,
                                          mluOpFFTPlan_t fft_plan,
                                          const float scale_factor,
                                          int direction) {
  std::string api = "[computeFFT2dMatMulColumnC2R]";
  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;

  mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
  int batch = fft_plan->batch;
  int n0 = fft_plan->n[0];
  int n1 = fft_plan->n[1];

  void *dft_matrix_addr = fft_plan->mlu_addrs.dft_matrix_2d;
  void *in_addr = fft_plan->mlu_addrs.input;
  void *out_addr = fft_plan->mlu_addrs.buffer_out;

  // out[n0 * 2][(n1/2+1)*2][batch] = W[n0 * 2][n0] * In[n0][(n1/2+1)*2][batch]
  const int m = n0 * 2, k = n0, n = (n1 / 2 + 1) * 2 * batch;

  // Create matmul operand descriptors (A = DFT matrix, B = input, C = output).
  mluOpTensorDescriptor_t a_desc = nullptr;
  mluOpTensorDescriptor_t b_desc = nullptr;
  mluOpTensorDescriptor_t c_desc = nullptr;
  status = mluOpCreateTensorDescriptor(&a_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpCreateTensorDescriptor(&b_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpCreateTensorDescriptor(&c_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  // Describe all operands as 2D arrays with matching on-chip compute dtype.
  int64_t a_dims[2] = {m, k};
  int64_t b_dims[2] = {k, n};
  int64_t c_dims[2] = {m, n};

  status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       2, a_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(a_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       2, b_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(b_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       2, c_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(c_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
                                    cnnl_handle);  // convert to cnnl_handle

  // convert to cnnl_tensor_descriptor
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc);

  c_desc->onchip_dtype = in_e_dtype;
  float alpha = 1.0;
  float beta = 0.0;

  CALL_CNNL(cnnlMatMul(cnnl_handle, false, false, &alpha, cnnl_a_desc,
                       dft_matrix_addr, cnnl_b_desc, in_addr, &beta,
                       cnnl_c_desc, out_addr));

  // destroy cnnl descriptor
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);

  DESTROY_CNNL_HANDLE(cnnl_handle);

  // Release the mluOp descriptors created above (they were previously leaked).
  status = mluOpDestroyTensorDescriptor(a_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpDestroyTensorDescriptor(b_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpDestroyTensorDescriptor(c_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  cnrtQueueSync(handle->queue);

  cnrtDim3_t k_dim;
  cnrtFunctionType_t k_type;
  policyFunc(handle, &k_dim, &k_type);

  // Reconstruct the full spectrum along n0 from the half-spectrum using
  // conjugate (Hermitian) symmetry before the row-direction stage.
  kernelFFTBatchConjMergeC2R(
      k_dim, k_type, handle->queue, fft_plan->mlu_addrs.buffer_in,
      fft_plan->mlu_addrs.buffer_out, (n1 / 2 + 1) * batch, n0, in_e_dtype);
  cnrtQueueSync(handle->queue);

  return status;
}

// Row-direction (n1 axis) stage of the 2D C2R FFT, implemented as a batched
// matmul (one matmul per n0 slice, broadcast over the DFT matrix).
// Reads buffer_in (output of the column stage + conj merge), writes the
// final real result to fft_plan->mlu_addrs.output.
// NOTE(review): scale_factor and direction are currently unused in this
// stage — presumably encoded in the DFT matrix; confirm at plan creation.
mluOpStatus_t computeFFT2dMatMulRowC2R(mluOpHandle_t handle,
                                       mluOpFFTPlan_t fft_plan,
                                       const float scale_factor,
                                       int direction) {
  std::string api = "[computeFFT2dMatMulRowC2R]";
  mluOpStatus_t status = MLUOP_STATUS_SUCCESS;

  mluOpDataType_t in_e_dtype = fft_plan->execution_dtype;
  int batch = fft_plan->batch;
  int n0 = fft_plan->n[0];
  int n1 = fft_plan->n[1];

  void *dft_matrix_addr = fft_plan->mlu_addrs.dft_matrix;
  void *in_addr = fft_plan->mlu_addrs.buffer_in;
  void *out_addr = fft_plan->mlu_addrs.output;

  // out[n0][n1][batch] = W[n1][(n1/2+1)*2] * In[n0][(n1/2+1)*2][batch]
  // (A is {m, k} = {n1, (n1/2+1)*2} below; the original comment's operand
  // order did not match the descriptor shapes.)
  const int m = n1, k = (n1 / 2 + 1) * 2, n = batch;

  // Create matmul operand descriptors (A = DFT matrix, B = input, C = output).
  mluOpTensorDescriptor_t a_desc = nullptr;
  mluOpTensorDescriptor_t b_desc = nullptr;
  mluOpTensorDescriptor_t c_desc = nullptr;
  status = mluOpCreateTensorDescriptor(&a_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpCreateTensorDescriptor(&b_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpCreateTensorDescriptor(&c_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  // A is 2D (shared across the batch dim n0); B and C are 3D.
  int64_t a_dims[2] = {m, k};
  int64_t b_dims[3] = {n0, k, n};
  int64_t c_dims[3] = {n0, m, n};

  status = mluOpSetTensorDescriptor_v2(a_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       2, a_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(a_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptor_v2(b_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       3, b_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(b_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptor_v2(c_desc, MLUOP_LAYOUT_ARRAY, in_e_dtype,
                                       3, c_dims);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpSetTensorDescriptorOnchipDataType(c_desc, in_e_dtype);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle,
                                    cnnl_handle);  // convert to cnnl_handle

  // convert to cnnl_tensor_descriptor
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(a_desc, cnnl_a_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(b_desc, cnnl_b_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(c_desc, cnnl_c_desc);

  c_desc->onchip_dtype = in_e_dtype;
  float alpha = 1.0;
  float beta = 0.0;

  // Broadcast A over the n0 batch dimension of B.
  CALL_CNNL(cnnlBatchMatMulBCast(cnnl_handle, false, false, cnnl_a_desc,
                                 dft_matrix_addr, cnnl_b_desc, in_addr, NULL, 0,
                                 cnnl_c_desc, out_addr));

  // destroy cnnl descriptor
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);

  DESTROY_CNNL_HANDLE(cnnl_handle);

  // Release the mluOp descriptors created above (they were previously leaked).
  status = mluOpDestroyTensorDescriptor(a_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpDestroyTensorDescriptor(b_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);
  status = mluOpDestroyTensorDescriptor(c_desc);
  INTERNAL_CHECK(api, status == MLUOP_STATUS_SUCCESS);

  cnrtQueueSync(handle->queue);

  return status;
}
Loading

0 comments on commit a3a5ebe

Please sign in to comment.