Skip to content

Commit

Permalink
finish complex batch
Browse files Browse the repository at this point in the history
  • Loading branch information
dglr committed Jun 28, 2024
1 parent 0da4788 commit 7f745a0
Show file tree
Hide file tree
Showing 4 changed files with 751 additions and 2,081 deletions.
253 changes: 105 additions & 148 deletions kernels/cholesky/cholesky.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
#include "cholesky.h"

//dA:输入被分解方阵
//dC:cholesky分解结果方阵
//trans -> false: col major; true: row major
//uplo -> false: lower; true: upper
//ldda:leading dimension

mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t input_desc, size_t* size, float** workspace)
{
PARAM_CHECK("mluOpCholesky", input_desc != NULL);
Expand Down Expand Up @@ -36,11 +30,9 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t in
batch_size = input_desc->dims[0];
size_a = input_desc->dims[1];
}
printf("fuck you!");

if (dtype == MLUOP_DTYPE_FLOAT)
{
// *size = size_a*size_a*sizeof(float);
*size = 0;
}
else
Expand All @@ -57,30 +49,10 @@ mluOpStatus_t MLUOP_WIN_API mluOpGetCholeskyWorkspace(mluOpTensorDescriptor_t in
}

mluOpStatus_t MLUOP_WIN_API
mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,float* d_input, const mluOpTensorDescriptor_t output_desc, float* d_output,bool upper, float* workspace)
calculate_body(mluOpHandle_t handle,int batch_size, const mluOpTensorDescriptor_t input_desc,float* d_input, const mluOpTensorDescriptor_t output_desc, float* d_output,bool upper, float* workspace)
{
PARAM_CHECK("mluOpCholesky", handle != NULL);
PARAM_CHECK("mluOpCholesky", input_desc != NULL);
PARAM_CHECK("mluOpCholesky", d_input != NULL);
PARAM_CHECK("mluOpCholesky", output_desc != NULL);
PARAM_CHECK("mluOpCholesky", d_output != NULL);

PARAM_CHECK("mluOpCholesky", input_desc->dim == 2||input_desc->dim == 3);
PARAM_CHECK("mluOpCholesky", output_desc->dim == input_desc->dim);
PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[0] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[1] > 0);

if(input_desc->dim == 3)
{
PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
PARAM_CHECK("mluOpCholesky", output_desc->dims[2] > 0);
}

mluOpDataType_t dtype = input_desc->dtype;
PARAM_CHECK("mluOpCholesky", dtype == output_desc->dtype);
PARAM_CHECK("mluOpCholesky", dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);
printf("batch_size:%d\n",batch_size);


int recnb = REC_NB;
Expand All @@ -91,7 +63,6 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa

int type_size = (dtype == MLUOP_DTYPE_FLOAT) ? 4 : 8;
int size_a = 0, lda = 0, size_c = 0, ldc = 0;
int batch_size = 1;
if(dim == 2)
{
size_a = input_desc->dims[0];
Expand All @@ -101,7 +72,6 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
}
else if(dim == 3)
{
batch_size = input_desc->dims[0];
size_a = input_desc->dims[1];
lda = input_desc->dims[2];
size_c = output_desc->dims[1];
Expand All @@ -113,18 +83,13 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
float* work_space_h;
CNRT_CHECK(cnrtMalloc((void **)&work_space, NB*NB*sizeof(float)));
CNRT_CHECK(cnrtMemset(work_space, 0, NB*NB*sizeof(float)));
work_space_h = (float*)malloc(NB*NB*sizeof(float));
work_space_h = (float*)malloc(batch_size*2*lda*lda*sizeof(float));
PARAM_CHECK("mluOpCholesky", lda >= size_a);
PARAM_CHECK("mluOpCholesky", ldc >= size_c);

cnrtQueue_t queue;
mluOpGetQueue(handle,&queue);
// CNRT_CHECK(cnrtSetDevice(0));
// CNRT_CHECK(cnrtQueueCreate(&queue));

// cnrtNotifier_t start, end;
// CNRT_CHECK(cnrtNotifierCreate(&start));
// CNRT_CHECK(cnrtNotifierCreate(&end));

int jb;
const float s_one = 1.0;
Expand All @@ -150,11 +115,9 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
}

cnrtQueueSync(queue);

//TODO:检查拷贝开销

int stride = size_a*lda;
//printf original matrix


if(dtype == MLUOP_DTYPE_FLOAT)
{

Expand All @@ -168,7 +131,6 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
mlu_spotrf_rectile(batch_size,stride,is_row_major,false,jb,recnb,OFFSET_ROW(d_output,j,j),lda,j, handle));
// cnrtQueueSync(queue);
if(j+jb < row)
{
CHECK_RETURN("mluOpCholesky",
Expand Down Expand Up @@ -196,14 +158,16 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
else
{
recnb = CREC_NB;
// int nb = NB;
int nb = NB;
int nb = CNB;
int row = lda;
float* r_start = d_output; //实数首地址
float* i_start = d_output + size_a*lda;//虚数首地址
float* r_start = d_output;
float* i_start = d_output + size_a*lda;
stride *= 2;

set_half_zero(batch_size, size_a*lda, r_start, lda, lda, handle);
set_half_zero(batch_size, size_a*lda, i_start, lda, lda, handle);

set_half_zero(batch_size, stride, r_start, lda, lda, handle);
set_half_zero(batch_size, stride, i_start, lda, lda, handle);
cnrtQueueSync(queue);

for(int j = 0; j < row; j+=nb)
{
Expand All @@ -213,7 +177,7 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
cnrtQueueSync(queue);
CHECK_RETURN("mluOpCholesky",
mlu_cpotrf_rectile(batch_size,stride,jb,recnb,r_start+j*lda+j,i_start+j*lda+j,lda, handle));
// cnrtQueueSync(queue);
cnrtQueueSync(queue);
if(j+jb < row)
{
CHECK_RETURN("mluOpCholesky",
Expand All @@ -233,118 +197,111 @@ mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,floa
}
}

// printf("after transpose, d_output:\n");
// for(int i = 0; i < 2; i++)
// {
// for(int j = 0; j < lda; j++)
// {
// for(int h = 0; h < lda; h++)
// {
// cnrtMemcpy(work_space_h, d_output+i*lda*lda+j*lda+h, sizeof(float), CNRT_MEM_TRANS_DIR_DEV2HOST);
// printf("%8.3f",*work_space_h);
// }
// printf("\n");
// }
// printf("\n");
// }



printf("before finally, transpose:\n");
cnrtMemcpy(work_space_h, d_output, sizeof(float)*lda*lda*2, CNRT_MEM_TRANS_DIR_DEV2HOST);
printf("real result:\n");
for(int j = 0; j < lda; j++)
{
for(int h = 0; h < lda; h++)
CHECK_RETURN("mluOpCholesky",
transpose(batch_size,2,size_a*size_a,d_output,workspace,handle));
cnrtQueueSync(queue);
if(batch_size > 16)
{
printf("%8.3f",work_space_h[j*lda+h]);
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*16, CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(cnrtMemcpy(d_output+type_size/4*size_a*lda*16, workspace+type_size/4*size_a*lda*16, type_size*size_a*lda*(batch_size-16), CNRT_MEM_TRANS_DIR_DEV2DEV));
}
printf("\n");
}
printf("\n");
printf("imag result:\n");
for(int j = 0; j < lda; j++)
{
for(int h = 0; h < lda; h++)
else
{
printf("%8.3f",work_space_h[lda*lda+j*lda+h]);
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*batch_size, CNRT_MEM_TRANS_DIR_DEV2DEV));
}
printf("\n");


}


// CHECK_RETURN("mluOpCholesky",
// sgemm(batch_size, false,true,row-j-jb,jb,j,-1.0f,1.0f,
// OFFSET_ROW(d_output,j+jb,0),lda,stride,
// OFFSET_ROW(d_output,j,0),lda,stride,
// OFFSET_ROW(d_output,j+jb,j),lda,stride, handle));
// cnrtQueueSync(queue);

// cnrtMemcpy(work_space_h, d_output, sizeof(float)*lda*lda*2, CNRT_MEM_TRANS_DIR_DEV2HOST);
// for(int i = 0; i < 2; i++)
// {
// for(int j = 0; j < lda; j++)
// {
// for(int h = 0; h < lda; h++)
// {
// // cnrtMemcpy(work_space_h, d_output+i*lda*lda+j*lda+h, sizeof(float), CNRT_MEM_TRANS_DIR_DEV2HOST);
// printf("%8.3f",work_space_h[i*lda*lda+j*lda+h]);
// }
// printf("\n");
// }
// printf("\n");
// }

CHECK_RETURN("mluOpCholesky",
transpose(batch_size,2,size_a*size_a,d_output,workspace,handle));
cnrtQueueSync(queue);
CNRT_CHECK(cnrtMemcpy(d_output, workspace, type_size*size_a*lda*batch_size, CNRT_MEM_TRANS_DIR_DEV2DEV));

// printf("after transpose, d_a:\n");

// for(int j = 0; j < lda; j++)
// {
// for(int h = 0; h < lda; h++)
// {
// cnrtMemcpy(work_space_h, d_output+j*lda*2+h*2, sizeof(float), CNRT_MEM_TRANS_DIR_DEV2HOST);
// cnrtMemcpy((work_space_h+1), d_output+j*lda*2+h*2+1, sizeof(float), CNRT_MEM_TRANS_DIR_DEV2HOST);
// printf("%8.3f,%8.3f ",work_space_h[0],work_space_h[1]);
// }
// printf("\n");
// }
cnrtQueueSync(queue);

return MLUOP_STATUS_SUCCESS;
}


/**
 * Public entry point of the Cholesky operator.
 *
 * Validates the handle, descriptors and data pointers, then forwards the
 * actual factorization to calculate_body().
 *
 * @param handle      MLU-OPS handle; its queue is used for synchronization.
 * @param input_desc  Descriptor of the input. Must have 2 dims
 *                    (size_a, lda) or 3 dims (batch, size_a, lda), all > 0,
 *                    and dtype MLUOP_DTYPE_FLOAT or MLUOP_DTYPE_COMPLEX_FLOAT.
 * @param d_input     Input data pointer (non-NULL).
 *                    NOTE(review): presumably device memory - confirm.
 * @param output_desc Descriptor of the output. Must have the same number of
 *                    dims and the same dtype as input_desc, all dims > 0.
 * @param d_output    Output data pointer (non-NULL).
 * @param upper       Forwarded unchanged to calculate_body().
 * @param workspace   Forwarded unchanged to calculate_body().
 *
 * @return MLUOP_STATUS_SUCCESS when every calculate_body() call succeeds;
 *         otherwise the failure is reported through PARAM_CHECK /
 *         CHECK_RETURN.
 */
mluOpStatus_t MLUOP_WIN_API
mluOpCholesky(mluOpHandle_t handle,const mluOpTensorDescriptor_t input_desc,float* d_input, const mluOpTensorDescriptor_t output_desc, float* d_output,bool upper, float* workspace)
{
    PARAM_CHECK("mluOpCholesky", handle != NULL);
    PARAM_CHECK("mluOpCholesky", input_desc != NULL);
    PARAM_CHECK("mluOpCholesky", d_input != NULL);
    PARAM_CHECK("mluOpCholesky", output_desc != NULL);
    PARAM_CHECK("mluOpCholesky", d_output != NULL);

    PARAM_CHECK("mluOpCholesky", input_desc->dim == 2||input_desc->dim == 3);
    PARAM_CHECK("mluOpCholesky", output_desc->dim == input_desc->dim);
    PARAM_CHECK("mluOpCholesky", input_desc->dims[0] > 0);
    PARAM_CHECK("mluOpCholesky", input_desc->dims[1] > 0);
    PARAM_CHECK("mluOpCholesky", output_desc->dims[0] > 0);
    PARAM_CHECK("mluOpCholesky", output_desc->dims[1] > 0);

    cnrtQueue_t queue;
    mluOpGetQueue(handle,&queue);

    if(input_desc->dim == 3)
    {
        PARAM_CHECK("mluOpCholesky", input_desc->dims[2] > 0);
        PARAM_CHECK("mluOpCholesky", output_desc->dims[2] > 0);
    }

    mluOpDataType_t dtype = input_desc->dtype;
    PARAM_CHECK("mluOpCholesky", dtype == output_desc->dtype);
    PARAM_CHECK("mluOpCholesky", dtype == MLUOP_DTYPE_FLOAT || dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

    // Matrix geometry. dim was validated above to be exactly 2 or 3.
    int batch_size = 1;
    int size_a = 0;
    int lda = 0;
    if(input_desc->dim == 2)
    {
        size_a = input_desc->dims[0];
        lda = input_desc->dims[1];
    }
    else
    {
        batch_size = input_desc->dims[0];
        size_a = input_desc->dims[1];
        lda = input_desc->dims[2];
    }

    // Number of matrices handled by the first call when a large complex
    // batch is split in two.
    const int first_chunk = 16;
    const bool is_complex = (dtype == MLUOP_DTYPE_COMPLEX_FLOAT);

    if(is_complex && batch_size > first_chunk && size_a > 2000)
    {
        // Large complex batches are processed as two consecutive calls: the
        // first 16 matrices, then the remainder.
        // NOTE(review): presumably this bounds the per-call working set -
        // confirm against calculate_body().
        // Each complex matrix occupies 2*size_a*lda floats. The offset is
        // computed in size_t because 16*2*size_a*lda exceeds INT_MAX for
        // large matrices (e.g. size_a = lda = 8192).
        const size_t chunk_offset =
            (size_t)first_chunk * 2 * (size_t)size_a * (size_t)lda;

        CHECK_RETURN("mluOpCholesky",
            calculate_body(handle, first_chunk, input_desc, d_input, output_desc, d_output, upper, workspace));
        cnrtQueueSync(queue);
        CHECK_RETURN("mluOpCholesky",
            calculate_body(handle, batch_size-first_chunk, input_desc, d_input+chunk_offset, output_desc, d_output+chunk_offset, upper, workspace));
    }
    else
    {
        CHECK_RETURN("mluOpCholesky",
            calculate_body(handle, batch_size, input_desc, d_input, output_desc, d_output, upper, workspace));
    }

    // Wait for all queued work before reporting success to the caller.
    cnrtQueueSync(queue);

    return MLUOP_STATUS_SUCCESS;
}
Loading

0 comments on commit 7f745a0

Please sign in to comment.