Skip to content

Commit

Permalink
[Fix](mluOpCholesky): mv cnnl to cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
dglr committed Dec 1, 2024
1 parent 0938f55 commit 3213104
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 165 deletions.
171 changes: 171 additions & 0 deletions kernels/cholesky/cholesky.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "cholesky.h"
#include <cstdio>
#include <algorithm>
#include <string>

// calculates the required workspace size for performing the Cholesky
// decomposition on a given matrix or batch of matrices.
mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspaceSize(
Expand Down Expand Up @@ -316,3 +318,172 @@ mluOpCholesky(mluOpHandle_t handle, const mluOpTensorDescriptor_t input_desc,
output_desc, d_output, upper, (float*)workspace);
return MLUOP_STATUS_SUCCESS;
}


// m * n
// Transposes the last two axes of a batched matrix: for a (batch, m, n)
// input, writes d_output[b][j][i] = d_input[b][i][j], giving (batch, n, m).
// `type` is the element dtype of both tensors; `workspace` is caller-provided
// scratch memory assumed large enough for cnnlTranspose_v2 (size queried
// below) — TODO confirm the caller sizes it via mluOpGetCholeskyWorkspaceSize.
// Returns MLUOP_STATUS_SUCCESS on success, or the first failing status.
mluOpStatus_t transpose(int batch, int m, int n, float* d_input,
                        float* d_output, mluOpHandle_t handle,
                        mluOpDataType_t type, float* workspace) {
  // Empty matrix: nothing to transpose.
  if (m == 0) return MLUOP_STATUS_SUCCESS;

  mluOpTensorDescriptor_t trans_input_desc, trans_output_desc;
  std::string api_name = "Cholesky";
  const int input_dim = 3;

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&trans_input_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&trans_output_desc));

  int32_t transpose_input_shape[3] = {batch, m, n};
  int32_t transpose_output_shape[3] = {batch, n, m};

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             trans_input_desc, MLUOP_LAYOUT_ARRAY, type,
                             input_dim, transpose_input_shape));

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             trans_output_desc, MLUOP_LAYOUT_ARRAY, type,
                             input_dim, transpose_output_shape));

  // Keep the batch axis, swap the last two axes.
  int permute[3] = {0, 2, 1};

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(trans_input_desc, cnnl_in_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(trans_output_desc,
                                               cnnl_out_desc);

  cnnlTransposeDescriptor_t cnnl_trans_desc = NULL;
  CALL_CNNL(cnnlCreateTransposeDescriptor(&cnnl_trans_desc));
  CALL_CNNL(cnnlSetTransposeDescriptor(cnnl_trans_desc, input_dim, permute));

  size_t size = 0;
  CALL_CNNL(cnnlGetTransposeWorkspaceSize(cnnl_handle, cnnl_in_desc,
                                          cnnl_trans_desc, &size));

  CALL_CNNL(cnnlTranspose_v2(cnnl_handle, cnnl_trans_desc, cnnl_in_desc,
                             d_input, cnnl_out_desc, d_output, workspace,
                             size));

  // Release everything created above; the original version leaked all of it
  // on every call.
  CALL_CNNL(cnnlDestroyTransposeDescriptor(cnnl_trans_desc));
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_out_desc);
  DESTROY_CNNL_HANDLE(cnnl_handle);
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(trans_input_desc));
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(trans_output_desc));
  return MLUOP_STATUS_SUCCESS;
}

// Batched, strided single-precision GEMM via cnnlStrideBatchMatMul:
//   C[b] = alpha * op(A[b]) * op(B[b]) + beta * C[b]
// where op(X) is X or X^T according to trans_a/trans_b. lda/ldb/ldc are the
// leading dimensions and stride_* the per-batch element strides of each
// operand. `workspace` is caller-provided scratch memory — assumed at least
// as large as the size the heuristic reports below; TODO confirm against the
// Cholesky workspace-size computation.
// Returns MLUOP_STATUS_SUCCESS on success, or the first failing status.
mluOpStatus_t sgemm(int batch, bool trans_a, bool trans_b, int m, int n, int k,
                    float alpha, float beta, float* d_a, int lda, int stride_a,
                    float* d_b, int ldb, int stride_b, float* d_c, int ldc,
                    int stride_c, mluOpHandle_t handle, float* workspace) {
  // k == 0 means op(A)*op(B) contributes nothing; treat as a no-op.
  if (k == 0) return MLUOP_STATUS_SUCCESS;

  int32_t batch_size_arr[1] = {batch};
  int64_t stride_a_arr[1] = {stride_a};
  int64_t stride_b_arr[1] = {stride_b};
  int64_t stride_c_arr[1] = {stride_c};

  std::string api_name = "Cholesky";

  cnnlStrideBatchMatMulAlgo_t algo;
  CALL_CNNL(cnnlStrideBatchMatMulAlgoCreate(&algo));

  cnnlStrideBatchMatMulHeuristicResult_t heuristic_result;
  CALL_CNNL(cnnlCreateStrideBatchMatMulHeuristicResult(&heuristic_result));

  cnnlStrideBatchMatMulDescriptor_t stride_bmm_desc;
  CALL_CNNL(cnnlStrideBatchMatMulDescCreate(&stride_bmm_desc));
  // Full fp32 accuracy (no TF32) and a single flat batch dimension.
  int32_t allow_tf32 = 0, max_batch_dim = 1;
  CALL_CNNL(cnnlSetStrideBatchMatMulDescAttr(stride_bmm_desc,
                                             CNNL_STRIDE_BMM_ALLOW_TF32,
                                             &(allow_tf32), sizeof(int32_t)));
  CALL_CNNL(cnnlSetStrideBatchMatMulDescAttr(
      stride_bmm_desc, CNNL_STRIDE_BMM_MAX_BATCH_DIM, &(max_batch_dim),
      sizeof(int32_t)));

  mluOpTensorDescriptor_t matmul_a_desc, matmul_b_desc, matmul_c_desc;

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_a_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_b_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_c_desc));

  // Each operand is described as (batch, per-matrix element stride).
  int32_t matmul_a_shape[2] = {batch, stride_a};
  int32_t matmul_b_shape[2] = {batch, stride_b};
  int32_t matmul_c_shape[2] = {batch, stride_c};

  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_a_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_a_shape));
  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_b_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_b_shape));
  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_c_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_c_shape));

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_a_desc, cnnl_a_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_b_desc, cnnl_b_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_c_desc, cnnl_c_desc);
  // D shares C's layout: the op computes in-place into d_c.
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_c_desc, cnnl_d_desc);

  int requested_algo_count = 1, return_algo_count = 0;
  size_t workspace_size = 0;  // stay defined even if the query fails

  // FIX: both heuristic calls previously discarded their return status;
  // wrap them in CALL_CNNL like every other cnnl call in this file.
  CALL_CNNL(cnnlGetStrideBatchMatMulAlgoHeuristic_v2(
      cnnl_handle, stride_bmm_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc,
      cnnl_d_desc, trans_a, trans_b, &(alpha), &(beta), m, n, k, lda,
      ldb, ldc, ldc, batch_size_arr, stride_a_arr, stride_b_arr, stride_c_arr,
      stride_c_arr, nullptr, requested_algo_count, &heuristic_result,
      &return_algo_count));

  CALL_CNNL(cnnlGetStrideBatchMatMulHeuristicResult(heuristic_result, &algo,
                                                    &workspace_size));

  if (workspace_size > 0) {
    // %zu is the portable conversion for size_t (was %ld).
    MLULOG("sgemm workspace size: %zu\n", workspace_size);
  }

  CALL_CNNL(cnnlStrideBatchMatMul_v3(
      cnnl_handle, stride_bmm_desc, algo, trans_a, trans_b, m, n, k,
      batch_size_arr, &(alpha), cnnl_a_desc, d_a, lda, stride_a_arr,
      cnnl_b_desc, d_b, ldb, stride_b_arr, &(beta), cnnl_c_desc, d_c, ldc,
      stride_c_arr, workspace, workspace_size, cnnl_d_desc, d_c, ldc,
      stride_c_arr));

  // Release every object created above; the original leaked the algo,
  // heuristic result, bmm descriptor, and all tensor descriptors per call.
  CALL_CNNL(cnnlStrideBatchMatMulAlgoDestroy(algo));
  CALL_CNNL(cnnlDestroyStrideBatchMatMulHeuristicResult(heuristic_result));
  CALL_CNNL(cnnlStrideBatchMatMulDescDestroy(stride_bmm_desc));
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_a_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_b_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_c_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_d_desc);
  DESTROY_CNNL_HANDLE(cnnl_handle);
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(matmul_a_desc));
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(matmul_b_desc));
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(matmul_c_desc));

  return MLUOP_STATUS_SUCCESS;
}

// Element-wise complex conjugate of a (batch, m, n) complex-float tensor:
// d_output = conj(d_input). The buffers are passed as float* but are
// described (and consumed by cnnlConj) as MLUOP_DTYPE_COMPLEX_FLOAT, i.e.
// interleaved (real, imag) float pairs.
// Returns MLUOP_STATUS_SUCCESS on success, or the first failing status.
mluOpStatus_t conj_complex(int batch, int m, int n, float* d_input,
                           float* d_output, mluOpHandle_t handle) {
  // Empty matrix: nothing to conjugate.
  if (m == 0) return MLUOP_STATUS_SUCCESS;

  mluOpTensorDescriptor_t input_desc, output_desc;
  std::string api_name = "Cholesky";

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&input_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&output_desc));

  // Input and output share the exact same (batch, m, n) shape.
  int32_t shape[3] = {batch, m, n};

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             input_desc, MLUOP_LAYOUT_ARRAY,
                             MLUOP_DTYPE_COMPLEX_FLOAT, 3, shape));

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             output_desc, MLUOP_LAYOUT_ARRAY,
                             MLUOP_DTYPE_COMPLEX_FLOAT, 3, shape));

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_in_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(output_desc, cnnl_out_desc);

  CALL_CNNL(
      cnnlConj(cnnl_handle, cnnl_in_desc, d_input, cnnl_out_desc, d_output));

  // Release descriptors and the cnnl handle; the original leaked them.
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_in_desc);
  DESTROY_CNNL_TENSOR_DESCRIPTOR(cnnl_out_desc);
  DESTROY_CNNL_HANDLE(cnnl_handle);
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(input_desc));
  CHECK_RETURN(api_name, mluOpDestroyTensorDescriptor(output_desc));

  return MLUOP_STATUS_SUCCESS;
}
133 changes: 0 additions & 133 deletions kernels/cholesky/cholesky_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -811,91 +811,6 @@ mluOpStatus_t strsm_rectile(int batch, int stride, bool upper, bool trans,
return MLUOP_STATUS_SUCCESS;
}

// Batched, strided single-precision GEMM via cnnlStrideBatchMatMul:
//   C[b] = alpha * op(A[b]) * op(B[b]) + beta * C[b]
// (deleted from this .mlu in this commit; moved to cholesky.cpp).
mluOpStatus_t sgemm(int batch, bool trans_a, bool trans_b, int m, int n, int k,
                    float alpha, float beta, float* d_a, int lda, int stride_a,
                    float* d_b, int ldb, int stride_b, float* d_c, int ldc,
                    int stride_c, mluOpHandle_t handle, float* workspace) {
  // k == 0: the product contributes nothing, so this is a no-op.
  if (k == 0) return MLUOP_STATUS_SUCCESS;

  int32_t batch_size_arr[1] = {batch};
  int64_t stride_a_arr[1] = {stride_a};
  int64_t stride_b_arr[1] = {stride_b};
  int64_t stride_c_arr[1] = {stride_c};

  std::string api_name = "Cholesky";

  // NOTE(review): `queue` is fetched but never used below.
  cnrtQueue_t queue;
  mluOpGetQueue(handle, &queue);

  cnnlStrideBatchMatMulAlgo_t algo;
  CALL_CNNL(cnnlStrideBatchMatMulAlgoCreate(&algo));

  cnnlStrideBatchMatMulHeuristicResult_t heuristic_result;
  CALL_CNNL(cnnlCreateStrideBatchMatMulHeuristicResult(&heuristic_result));

  cnnlStrideBatchMatMulDescriptor_t stride_bmm_desc;
  CALL_CNNL(cnnlStrideBatchMatMulDescCreate(&stride_bmm_desc));
  // No TF32; single flat batch dimension.
  int32_t allow_tf32 = 0, max_batch_dim = 1;
  CALL_CNNL(cnnlSetStrideBatchMatMulDescAttr(stride_bmm_desc,
                                             CNNL_STRIDE_BMM_ALLOW_TF32,
                                             &(allow_tf32), sizeof(int32_t)));
  CALL_CNNL(cnnlSetStrideBatchMatMulDescAttr(
      stride_bmm_desc, CNNL_STRIDE_BMM_MAX_BATCH_DIM, &(max_batch_dim),
      sizeof(int32_t)));

  mluOpTensorDescriptor_t matmul_a_desc, matmul_b_desc, matmul_c_desc;

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_a_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_b_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&matmul_c_desc));

  // Operands described as (batch, per-matrix element stride).
  int32_t matmul_a_shape[2] = {batch, stride_a};
  int32_t matmul_b_shape[2] = {batch, stride_b};
  int32_t matmul_c_shape[2] = {batch, stride_c};

  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_a_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_a_shape));
  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_b_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_b_shape));
  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(matmul_c_desc, MLUOP_LAYOUT_ARRAY,
                                        MLUOP_DTYPE_FLOAT, 2, matmul_c_shape));

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_a_desc, cnnl_a_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_b_desc, cnnl_b_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_c_desc, cnnl_c_desc);
  // D reuses C's descriptor: result is written back into d_c.
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(matmul_c_desc, cnnl_d_desc);

  int requested_algo_count = 1, return_algo_count = 0;
  size_t workspace_size;

  // NOTE(review): return status of both heuristic calls is discarded —
  // consider wrapping in CALL_CNNL; workspace_size is uninitialized on failure.
  cnnlGetStrideBatchMatMulAlgoHeuristic_v2(
      cnnl_handle, stride_bmm_desc, cnnl_a_desc, cnnl_b_desc, cnnl_c_desc,
      cnnl_d_desc, trans_a, trans_b, &(alpha), &(beta), m, n, k, lda,
      ldb, ldc, ldc, batch_size_arr, stride_a_arr, stride_b_arr, stride_c_arr,
      stride_c_arr, nullptr, requested_algo_count, &heuristic_result,
      &return_algo_count);

  cnnlGetStrideBatchMatMulHeuristicResult(heuristic_result, &algo,
                                          &workspace_size);

  if (workspace_size > 0) {
    MLULOG("sgemm workspace size: %ld\n", workspace_size);
  }

  CALL_CNNL(cnnlStrideBatchMatMul_v3(
      cnnl_handle, stride_bmm_desc, algo, trans_a, trans_b, m, n, k,
      batch_size_arr, &(alpha), cnnl_a_desc, d_a, lda, stride_a_arr,
      cnnl_b_desc, d_b, ldb, stride_b_arr, &(beta), cnnl_c_desc, d_c, ldc,
      stride_c_arr, workspace, workspace_size, cnnl_d_desc, d_c, ldc,
      stride_c_arr));

  // NOTE(review): descriptors/algo/heuristic result are never destroyed.
  return MLUOP_STATUS_SUCCESS;
}

__mlu_global__ void batch_inverse_kernel(int batch, float* d_input,
int ld_input, int stride_input,
float* d_output, int ld_output,
Expand Down Expand Up @@ -1230,52 +1145,4 @@ mluOpStatus_t spotrf_recursion(int batch, int stride, bool trans, bool uplo,
return MLUOP_STATUS_SUCCESS;
}

// m * n
// Transposes the last two axes of a (batch, m, n) tensor into (batch, n, m)
// via cnnlTranspose_v2 (deleted from this .mlu in this commit; moved to
// cholesky.cpp). `workspace` is caller-provided scratch memory.
mluOpStatus_t transpose(int batch, int m, int n, float* d_input,
                        float* d_output, mluOpHandle_t handle,
                        mluOpDataType_t type, float* workspace) {
  // Empty matrix: nothing to transpose.
  if (m == 0) return MLUOP_STATUS_SUCCESS;
  // NOTE(review): `queue` is fetched but never used below.
  cnrtQueue_t queue;
  mluOpGetQueue(handle, &queue);

  mluOpTensorDescriptor_t trans_input_desc, trans_output_desc;
  std::string api_name = "Cholesky";
  const int input_dim = 3;

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&trans_input_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&trans_output_desc));

  int32_t transpose_input_shape[3] = {batch, m, n};
  int32_t transpose_output_shape[3] = {batch, n, m};

  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(trans_input_desc, MLUOP_LAYOUT_ARRAY,
                                        type, 3, transpose_input_shape));

  CHECK_RETURN(api_name,
               mluOpSetTensorDescriptor(trans_output_desc, MLUOP_LAYOUT_ARRAY,
                                        type, 3, transpose_output_shape));

  // Keep the batch axis, swap the last two axes.
  int permute[3] = {0, 2, 1};

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(trans_input_desc, cnnl_in_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(trans_output_desc,
                                               cnnl_out_desc);

  cnnlTransposeDescriptor_t cnnl_trans_desc = NULL;

  CALL_CNNL(cnnlCreateTransposeDescriptor(&cnnl_trans_desc));

  CALL_CNNL(cnnlSetTransposeDescriptor(cnnl_trans_desc, input_dim, permute));

  size_t size = 0;

  CALL_CNNL(cnnlGetTransposeWorkspaceSize(cnnl_handle, cnnl_in_desc,
                                          cnnl_trans_desc, &size));

  CALL_CNNL(cnnlTranspose_v2(cnnl_handle, cnnl_trans_desc, cnnl_in_desc,
                             d_input, cnnl_out_desc, d_output, workspace,
                             size));
  // NOTE(review): descriptors and cnnl handle are never destroyed here.
  return MLUOP_STATUS_SUCCESS;
}
32 changes: 0 additions & 32 deletions kernels/cholesky/complex_cholesky_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -822,35 +822,3 @@ mluOpStatus_t cpotrf_recursion(int batch, int stride, int n, int recnb,
return MLUOP_STATUS_SUCCESS;
}

// Element-wise complex conjugate of a (batch, m, n) complex-float tensor via
// cnnlConj (deleted from this .mlu in this commit; moved to cholesky.cpp).
// The float* buffers are described as interleaved complex-float data.
mluOpStatus_t conj_complex(int batch, int m, int n, float* d_input,
                           float* d_output, mluOpHandle_t handle) {
  // Empty matrix: nothing to conjugate.
  if (m == 0) return MLUOP_STATUS_SUCCESS;
  // NOTE(review): `queue` is fetched but never used below.
  cnrtQueue_t queue;
  mluOpGetQueue(handle, &queue);

  mluOpTensorDescriptor_t input_desc, output_desc;
  std::string api_name = "Cholesky";

  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&input_desc));
  CHECK_RETURN(api_name, mluOpCreateTensorDescriptor(&output_desc));

  // Input and output share the same (batch, m, n) shape.
  int32_t input_shape[3] = {batch, m, n};
  int32_t output_shape[3] = {batch, m, n};

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             input_desc, MLUOP_LAYOUT_ARRAY,
                             MLUOP_DTYPE_COMPLEX_FLOAT, 3, input_shape));

  CHECK_RETURN(api_name, mluOpSetTensorDescriptor(
                             output_desc, MLUOP_LAYOUT_ARRAY,
                             MLUOP_DTYPE_COMPLEX_FLOAT, 3, output_shape));

  DEFINE_CREATE_AND_SET_CNNL_HANDLE(handle, cnnl_handle);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(input_desc, cnnl_in_desc);
  DEFINE_CREATE_AND_SET_CNNL_TENSOR_DESCRIPTOR(output_desc, cnnl_out_desc);

  CALL_CNNL(
      cnnlConj(cnnl_handle, cnnl_in_desc, d_input, cnnl_out_desc, d_output));

  // NOTE(review): descriptors and cnnl handle are never destroyed here.
  return MLUOP_STATUS_SUCCESS;
}

0 comments on commit 3213104

Please sign in to comment.