
Commit

fix ang bugs
dglr committed Jul 19, 2024
1 parent 7f745a0 commit c86edf7
Showing 6 changed files with 745 additions and 397 deletions.
74 changes: 43 additions & 31 deletions kernels/cholesky/cholesky.cpp
@@ -1,5 +1,7 @@
#include "cholesky.h"



mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t input_desc, size_t* size, float** workspace)
{
PARAM_CHECK("mluOpCholesky", input_desc != NULL);
@@ -18,8 +20,8 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t in
PARAM_CHECK("mluOpCholesky", dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

int type_size = (dtype == MLUOP_DTYPE_FLOAT) ? 4 : 8;
int size_a = 0, lda = 0, size_c = 0, ldc = 0;
int batch_size = 1;
long int size_a = 0, lda = 0, size_c = 0, ldc = 0;
long int batch_size = 1;
int dim = input_desc->dim;
if(dim == 2)
{
@@ -33,17 +35,18 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t in

if (dtype == MLUOP_DTYPE_FLOAT)
{
*size = 0;
*size = size_a*size_a*sizeof(float)*2*batch_size;
}
else
{
*size = size_a*size_a*sizeof(float)*2*batch_size;
printf("size:%ul\n",(int)(*size));

}
printf("workspace size:%ul\n",(int)(*size));
if(*size>0)
{
CHECK_RETURN("mluOpCholesky",
complex_malloc(*size, workspace));
workspace_malloc(*size, workspace));
}
return MLUOP_STATUS_SUCCESS;
}
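
The two hunks above change the float path from requesting no workspace (*size = 0) to reserving the same buffer as the complex path, 2 * size_a * size_a floats per matrix in the batch, and widen the size variables from int to long int so the product cannot overflow. A minimal standalone sketch of that size rule, assuming a square size_a x size_a matrix per batch (the helper name is hypothetical, not part of the library):

#include <cstddef>

// Hypothetical helper illustrating the size rule used by mluOpGetCholeskyWorkspace:
// both the float and complex-float branches reserve 2 * size_a * size_a floats per batch.
static size_t cholesky_workspace_bytes(long size_a, long batch_size) {
    // 64-bit arithmetic mirrors the int -> long int widening in the hunks above.
    return static_cast<size_t>(size_a) * static_cast<size_t>(size_a)
           * sizeof(float) * 2 * static_cast<size_t>(batch_size);
}
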
@@ -83,14 +86,13 @@ calculate_body(mluOpHandle_t handle,int batch_size, const mluOpTensorDescriptor_
float* work_space_h;
CNRT_CHECK(cnrtMalloc((void **)&work_space, NB*NB*sizeof(float)));
CNRT_CHECK(cnrtMemset(work_space, 0, NB*NB*sizeof(float)));
work_space_h = (float*)malloc(batch_size*2*lda*lda*sizeof(float));
work_space_h = (float*)malloc(((unsigned long)batch_size)*2*lda*lda*sizeof(float));
PARAM_CHECK("mluOpCholesky", lda >= size_a);
PARAM_CHECK("mluOpCholesky", ldc >= size_c);

cnrtQueue_t queue;
mluOpGetQueue(handle,&queue);


int jb;
const float s_one = 1.0;
const float s_neg_one = -1.0;
@@ -100,29 +102,30 @@ calculate_body(mluOpHandle_t handle,int batch_size, const mluOpTensorDescriptor_
if(upper == true)
{
CHECK_RETURN("mluOpCholesky",
transpose(batch_size,size_a,size_a,d_input,d_output,handle));
transpose(batch_size,size_a,size_a,d_input,d_output,handle,dtype,workspace));
}
else
{
CNRT_CHECK(cnrtMemcpy(d_output, d_input, type_size*size_a*lda*batch_size, CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(cnrtMemcpy(d_output, d_input, type_size*size_a*lda*((unsigned long)batch_size), CNRT_MEM_TRANS_DIR_DEV2DEV));
}
}
else
{

CHECK_RETURN("mluOpCholesky",
transpose(batch_size,size_a*size_a,2,d_input,d_output,handle));
transpose(batch_size,size_a*size_a,2,d_input,d_output,handle,MLUOP_DTYPE_FLOAT,workspace));
}

cnrtQueueSync(queue);
int stride = size_a*lda;



if(dtype == MLUOP_DTYPE_FLOAT)
{

int row = is_row_major ? lda : size_a;
int nb = NB;
set_half_zero(batch_size, stride, d_output, lda, lda, handle);
cnrtQueueSync(queue);
for(int j = 0; j < row; j+=nb)
{
jb = std::min(nb, row-j);
@@ -152,7 +155,9 @@ calculate_body(mluOpHandle_t handle,int batch_size, const mluOpTensorDescriptor_
{
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a,size_a,d_output,d_output,handle));
transpose(batch_size, size_a,size_a,d_output,workspace,handle,dtype,workspace));
cnrtQueueSync(queue);
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*((unsigned long)batch_size), CNRT_MEM_TRANS_DIR_DEV2DEV));
}
}
else
@@ -195,20 +200,35 @@ calculate_body(mluOpHandle_t handle,int batch_size, const mluOpTensorDescriptor_
OFFSET_ROW(r_start,j+jb,j),OFFSET_ROW(i_start,j+jb,j),lda, handle));
cnrtQueueSync(queue);
}
}

}

CHECK_RETURN("mluOpCholesky",
transpose(batch_size,2,size_a*size_a,d_output,workspace,handle));
transpose(batch_size,2,size_a*size_a,d_output,workspace,handle,MLUOP_DTYPE_FLOAT,workspace));
cnrtQueueSync(queue);
if(batch_size > 16)



if(upper)
{
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*16, CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(cnrtMemcpy(d_output+type_size/4*size_a*lda*16, workspace+type_size/4*size_a*lda*16, type_size*size_a*lda*(batch_size-16), CNRT_MEM_TRANS_DIR_DEV2DEV));
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
transpose(batch_size, size_a,size_a,workspace,d_output,handle,dtype,workspace));
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
conj_complex(batch_size, size_a,size_a,d_output,d_output,handle));
cnrtQueueSync(queue);
}
else
{
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*batch_size, CNRT_MEM_TRANS_DIR_DEV2DEV));
if(batch_size > 16)
{
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*16, CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(cnrtMemcpy(d_output+type_size/4*size_a*lda*16, workspace+type_size/4*size_a*lda*16, type_size*size_a*lda*((unsigned long)batch_size-16), CNRT_MEM_TRANS_DIR_DEV2DEV));
}
else
{
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*((unsigned long)batch_size), CNRT_MEM_TRANS_DIR_DEV2DEV));
}
}


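In the complex branch, the upper == true path above produces the upper factor by transposing the computed lower factor and then applying conj_complex, i.e. U = L^H. A host-side sketch of that last step, written with std::complex for readability (the device kernels appear to operate on separate real/imaginary planes, so this is illustrative only):

#include <complex>
#include <vector>

// Illustrative only: form the upper Cholesky factor from the lower one, U = conj(L)^T.
void upper_from_lower(const std::vector<std::complex<float>>& L,
                      std::vector<std::complex<float>>& U, int n) {
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            U[i * n + j] = std::conj(L[j * n + i]);
}
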
@@ -271,19 +291,11 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
ldc = output_desc->dims[2];
}

float* last_addr = d_input+batch_size*size_a*lda*2;
float* last_addr = d_input+((unsigned long)batch_size)*size_a*lda*2;
float* temp_addr = last_addr - 10;


float* work_space_h;
work_space_h = (float*)malloc(100*sizeof(float));
cnrtMemcpy(work_space_h, temp_addr, 10*sizeof(float), CNRT_MEM_TRANS_DIR_DEV2HOST);
printf("last 10 input:\n");
for(int i = 0; i < 10;i++)
{
printf("%8.3f ",work_space_h[i]);
}
printf("\n");



int type_size = (dtype == MLUOP_DTYPE_FLOAT) ? 4 : 8;
Expand All @@ -292,7 +304,7 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
int stride = 2*size_a*lda;
calculate_body(handle, 16, input_desc,d_input, output_desc, d_output, upper, workspace);
cnrtQueueSync(queue);
calculate_body(handle, batch_size-16, input_desc,d_input+16*stride, output_desc, d_output+16*stride, upper, workspace);
calculate_body(handle, ((unsigned long)batch_size)-16, input_desc,d_input+16*stride, output_desc, d_output+16*stride, upper, workspace);
}
else
{
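
The for(int j = 0; j < row; j += nb) loop in calculate_body walks the matrix in NB-wide panels using the mlu_spotrf_rectile / strsm / ssyrk / sgemm kernels declared in cholesky.h below; the body of the loop is collapsed in this diff. For orientation, a textbook right-looking blocked Cholesky of one lower-triangular, row-major matrix looks like this on the host (a reference sketch of the algorithm only, not of the kernel call sequence):

#include <algorithm>
#include <cmath>

// Reference sketch: right-looking blocked Cholesky, lower factor, row-major storage,
// leading dimension lda, block size nb. A[i*lda + j] holds element (i, j).
void blocked_cholesky_lower(float* A, int n, int lda, int nb) {
    for (int j = 0; j < n; j += nb) {
        int jb = std::min(nb, n - j);
        // 1) Unblocked Cholesky of the jb x jb diagonal block.
        for (int k = j; k < j + jb; ++k) {
            for (int p = j; p < k; ++p)
                A[k * lda + k] -= A[k * lda + p] * A[k * lda + p];
            A[k * lda + k] = std::sqrt(A[k * lda + k]);
            for (int i = k + 1; i < j + jb; ++i) {
                for (int p = j; p < k; ++p)
                    A[i * lda + k] -= A[i * lda + p] * A[k * lda + p];
                A[i * lda + k] /= A[k * lda + k];
            }
        }
        // 2) Triangular solve (TRSM) for the panel below the diagonal block.
        for (int i = j + jb; i < n; ++i) {
            for (int k = j; k < j + jb; ++k) {
                for (int p = j; p < k; ++p)
                    A[i * lda + k] -= A[i * lda + p] * A[k * lda + p];
                A[i * lda + k] /= A[k * lda + k];
            }
        }
        // 3) Symmetric rank-jb update (SYRK/GEMM) of the trailing submatrix.
        for (int i = j + jb; i < n; ++i)
            for (int k = j + jb; k <= i; ++k)
                for (int p = j; p < j + jb; ++p)
                    A[i * lda + k] -= A[i * lda + p] * A[k * lda + p];
    }
}
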
16 changes: 11 additions & 5 deletions kernels/cholesky/cholesky.h
@@ -23,37 +23,43 @@
#define CNB (16)
#define REC_NB (16)
#define POTF_NB ((REC_NB)/4)
#define CREC_NB (8)
#define CREC_NB (16)
#define CPOTF_NB ((CREC_NB)/4)
// #define CPOTF_NB ((CREC_NB))
#define __CNRT_FUNC_TYPE__ CNRT_FUNC_TYPE_UNION1
#define TASK_NUM (4)
#define NB (32)

#define CLUSTER_NUM 1
#define M (TASK_NUM * POTF_NB)
#define M (TASK_NUM * POTF_NB) // POTF tile side length
#define ZERO 0.0
#define SHARED_MEM_SIZE (((M*POTF_NB/TASK_NUM * 4)+(POTF_NB * POTF_NB)))
#define OFFSET_ROW(A, i, j) A + ((i) * (lda) + (j))
#define OFFSET_B_ROW(B, i, j) B + ((i) * (ldb) + (j))


mluOpStatus_t mlu_spotrf_rectile(int batch, int stride, bool trans, bool uplo, int n, int recnb, float* dA, int ldda, int gbstep, mluOpHandle_t handle);
// void mluOpCholesky(bool trans, bool uplo, int n, float* dA, float* dC, int ldda);

mluOpStatus_t ssyrk(int batch, int stride, bool upper, bool trans,int n, int k, float* d_a, int ldda, float* d_c, int lddc, mluOpHandle_t handle);

mluOpStatus_t sgemm(int batch, bool trans_a, bool trans_b, int m, int n, int k, float alpha, float beta, float* d_a,int lda, int stride_a, float* d_b, int ldb, int stride_b, float* d_c, int ldc, int stride_c, mluOpHandle_t handle);


//side:true->right
// false->left
mluOpStatus_t strsm(int batch, int stride, bool upper, bool trans, int m, int n, float* d_a, int ldda, float* d_b, int lddb, mluOpHandle_t handle);

mluOpStatus_t transpose(int batch, int m, int n,float* d_input,float* d_output, mluOpHandle_t handle);
mluOpStatus_t transpose(int batch, int m, int n,float* d_input,float* d_output, mluOpHandle_t handle,mluOpDataType_t type, float* workspace);

mluOpStatus_t conj_complex(int batch, int m, int n,float* d_input,float* d_output, mluOpHandle_t handle);

mluOpStatus_t mlu_cpotrf_rectile(int batch, int stride, int n, int recnb, float* drA, float* diA, int lda, mluOpHandle_t handle);

mluOpStatus_t cgemm(int batch, bool trans_a, bool trans_b, int m, int n, int k, float alpha, float beta, float* d_ra, float* d_ia, int lda, int stride_a, float* d_rb, float* d_ib, int ldb, int stride_b, float* d_rc, float* d_ic, int ldc, int stride_c, mluOpHandle_t handle);

mluOpStatus_t complex_malloc(size_t size, float** workspace);
mluOpStatus_t workspace_malloc(size_t size, float** workspace);

// mluOpStatus_t complex_set_half_zero(int batch, int stride, float* d_a, int m, int ld);

mluOpStatus_t set_half_zero(int batch,int stride,float* d_a, int lda, int m, mluOpHandle_t handle);

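
Combined with the cholesky.cpp changes above, the expected call sequence is to query and allocate the workspace with mluOpGetCholeskyWorkspace (which now allocates through the renamed workspace_malloc) and then run the factorization. A sketch under that assumption; the wrapper function is hypothetical and mluOpCholesky's full parameter list is collapsed in this diff, so its call is only indicated in a comment:

#include "cholesky.h"
// (mluOpGetCholeskyWorkspace itself is declared in the library's public header.)

// Hypothetical wrapper showing the intended call order only.
mluOpStatus_t run_cholesky_example(mluOpTensorDescriptor_t input_desc) {
    size_t workspace_size = 0;
    float* d_workspace = nullptr;  // device buffer, allocated inside the query via workspace_malloc
    mluOpStatus_t status =
        mluOpGetCholeskyWorkspace(input_desc, &workspace_size, &d_workspace);
    if (status != MLUOP_STATUS_SUCCESS) return status;
    // ... on success, pass the handle, tensor descriptors, device pointers and
    //     d_workspace to mluOpCholesky to run the factorization ...
    return status;
}
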
