Skip to content

Commit

Permalink
[Fix](mluOpCholesky): add new memcpy
Browse files Browse the repository at this point in the history
  • Loading branch information
dglr committed Dec 4, 2024
1 parent 3213104 commit a9c9db6
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 30 deletions.
28 changes: 11 additions & 17 deletions kernels/cholesky/cholesky.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ calculate_body(mluOpHandle_t handle, int batch_size,
transpose(batch_size, size_a, size_a, d_input, d_output,
handle, dtype, workspace));
} else {
CNRT_CHECK(cnrtMemcpy(d_output, d_input,
type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
KernelMyCnrtMemcpy1D(d_input, d_output,
size_a * lda * ((uint64_t)batch_size), queue, 0);
}
} else {
CHECK_RETURN("mluOpCholesky",
Expand Down Expand Up @@ -165,9 +164,8 @@ calculate_body(mluOpHandle_t handle, int batch_size,
transpose(batch_size, size_a, size_a, d_output, workspace,
handle, dtype, workspace));
cnrtQueueSync(queue);
CNRT_CHECK(cnrtMemcpy(d_output, workspace,
type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
KernelMyCnrtMemcpy1D(workspace, d_output,
size_a * lda * ((uint64_t)batch_size), queue, 0);
}
} else {
recnb = CRECNB;
Expand Down Expand Up @@ -232,18 +230,14 @@ calculate_body(mluOpHandle_t handle, int batch_size,
cnrtQueueSync(queue);
} else {
if (batch_size > 16) {
CNRT_CHECK(cnrtMemcpy(d_output, workspace,
type_size * size_a * lda * 16,
CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(
cnrtMemcpy(d_output + type_size / 4 * size_a * lda * 16,
workspace + type_size / 4 * size_a * lda * 16,
type_size * size_a * lda * ((uint64_t)batch_size - 16),
CNRT_MEM_TRANS_DIR_DEV2DEV));
KernelMyCnrtMemcpy1D(workspace,
d_output, size_a * lda * 16, queue, 0);
KernelMyCnrtMemcpy1D(workspace + type_size / 4 * size_a * lda * 16,
d_output + type_size / 4 * size_a * lda * 16,
size_a * lda * ((uint64_t)batch_size - 16), queue, 0);
} else {
CNRT_CHECK(cnrtMemcpy(d_output, workspace,
type_size * size_a * lda * ((uint64_t)batch_size),
CNRT_MEM_TRANS_DIR_DEV2DEV));
KernelMyCnrtMemcpy1D(workspace,
d_output, size_a * lda * ((uint64_t)batch_size), queue, 0);
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions kernels/cholesky/cholesky.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,13 @@ mluOpStatus_t cherk(int batch, int stride, int n, int k, float* rd_a,
float* id_a, int lda, float* rd_c, float* id_c, int ldc,
mluOpHandle_t handle, float* workspace);


// Host-side launcher for the device strided-copy kernel
// (MLUKernelMyCnrtMemcpy3D, defined in cholesky_union1.mlu).
//   k_dim, k_type : launch dimensions and function type passed to <<<...>>>.
//   queue         : queue the kernel is enqueued on.
//   batch         : number of matrices; the kernel maps taskId / 4 to a batch
//                   index, and existing callers launch with k_dim.x = 4 * batch.
//   m, n          : the kernel copies n * sizeof(float) bytes per segment and
//                   passes m - 1 as the segment argument of __memcpy.
//   dA, ldda, stride_a : source pointer, row pitch and per-batch offset,
//                        all counted in float elements.
//   dB, lddb, stride_b : destination pointer, row pitch and per-batch offset,
//                        all counted in float elements.
//   mode          : forwarded to the kernel, which currently does not read it.
// The declaration carries MLUOP_WIN_API so that it matches the definition,
// which is written as `mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy3D(...)`.
mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy3D(
    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
    int batch, int m, int n, float *dA, int ldda,
    int stride_a, float *dB, int lddb, int stride_b, int mode);

// Contiguous device-to-device copy built on the same kernel.
//   dA   : source, dB : destination.
//   n    : number of float elements (not bytes) to copy.
//   queue: queue the kernel is enqueued on.
//   mode : forwarded to the kernel, which currently does not read it.
// NOTE(review): n is an int; callers that compute the element count in
// 64-bit arithmetic should confirm the value fits before calling.
mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy1D(
    float *dA, float *dB, int n, cnrtQueue_t queue, int mode);

#endif
64 changes: 59 additions & 5 deletions kernels/cholesky/cholesky_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -1074,11 +1074,16 @@ mluOpStatus_t strsm(int batch, int stride, bool upper, bool trans, int m, int n,

cnrtQueueSync(queue);

for (int i = 0; i < batch; i++) {
CNRT_CHECK(cnrtMemcpy2D(d_b + i * stride, ldb * sizeof(float),
temp_result + i * m * n, m * sizeof(float),
m * sizeof(float), n, CNRT_MEM_TRANS_DIR_DEV2DEV));
}
cnrtDim3_t dim1;
dim1.y = 1;
dim1.z = 1;
dim1.x = 4 * batch;

cnrtFunctionType_t func_type1 = CNRT_FUNC_TYPE_UNION1;

KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, n, m,
temp_result, m, m * n, d_b, ldb, stride, 0);

cnrtQueueSync(queue);

return MLUOP_STATUS_SUCCESS;
Expand Down Expand Up @@ -1146,3 +1151,52 @@ mluOpStatus_t spotrf_recursion(int batch, int stride, bool trans, bool uplo,
}


// Device kernel: strided float copy from dA (source) to dB (destination).
//
//   batch              : number of matrices. When batch > 1, tasks are grouped
//                        in fours: taskId / 4 selects the matrix and only the
//                        task with taskId % 4 == 0 performs the copy.
//                        When batch <= 1, only task 0 performs the copy.
//   m, n               : each segment is n * sizeof(float) bytes; m - 1 is
//                        handed to __memcpy as its last argument.
//                        NOTE(review): presumably "segment count minus one",
//                        i.e. m segments are copied - confirm against the
//                        BANG C __memcpy reference.
//   ldda, lddb         : source / destination pitch between segments, in floats.
//   stride_a, stride_b : per-matrix offset applied to dA / dB, in floats.
//   mode               : not read by this kernel.
__mlu_entry__ void MLUKernelMyCnrtMemcpy3D(int batch, int m, int n, float *dA,
                                           int ldda, int stride_a, float *dB,
                                           int lddb, int stride_b, int mode) {
  const int task = taskId;
  int lane = task;

  if (batch > 1) {
    // Four consecutive tasks share one matrix of the batch.
    const int batch_idx = task / 4;
    if (batch_idx >= batch) {
      return;
    }
    lane = task % 4;
    dA += batch_idx * stride_a;
    dB += batch_idx * stride_b;
  }

  // Every lane other than the first one in its group is idle.
  if (lane != 0) {
    return;
  }

  __memcpy(dB, dA, n * sizeof(float), GDRAM2GDRAM, lddb * sizeof(float),
           ldda * sizeof(float), m - 1);
}

// Host-side launcher for MLUKernelMyCnrtMemcpy3D.
// dA: src, dB: dst
//
// Enqueues the kernel on `queue` with the caller-supplied launch dimensions
// (k_dim) and function type (k_type), forwarding every remaining argument
// unchanged. The kernel assigns four tasks to each batch entry (taskId / 4),
// so callers in this code base launch with k_dim.x = 4 * batch.
//
// ldda / lddb are the source / destination pitches and stride_a / stride_b the
// per-batch offsets; the kernel treats all of them as float-element counts.
// `mode` is passed through; the kernel does not currently read it.
//
// Returns MLUOP_STATUS_SUCCESS once the launch statement has been issued.
// NOTE(review): failure handling is whatever KERNEL_CHECK does - its
// definition is not visible here. The launch is asynchronous with respect to
// the host unless the caller synchronizes the queue afterwards.
mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy3D(
cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
int batch, int m, int n, float *dA, int ldda,
int stride_a, float *dB, int lddb, int stride_b, int mode) {
// Launch the strided-copy kernel; arguments keep the (src, dst) order above.
KERNEL_CHECK(MLUKernelMyCnrtMemcpy3D<<<k_dim, k_type, queue>>>(
batch, m, n, dA, ldda, stride_a, dB, lddb, stride_b, mode));

return MLUOP_STATUS_SUCCESS;
}

// Contiguous copy of n floats from dA (source) to dB (destination).
//
// Implemented as a degenerate call of MLUKernelMyCnrtMemcpy3D: one batch
// entry, one segment of n floats, both pitches equal to n and both batch
// strides zero. The kernel is launched as a UNION1 task with dim.x = 4; in
// the single-batch case only task 0 performs the copy.
//
// n is an element count, not a byte count: the kernel multiplies it by
// sizeof(float). `mode` is forwarded; the kernel does not currently read it.
// Returns MLUOP_STATUS_SUCCESS once the launch statement has been issued.
mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy1D(
    float *dA, float *dB, int n, cnrtQueue_t queue, int mode) {
  cnrtFunctionType_t launch_type = CNRT_FUNC_TYPE_UNION1;

  cnrtDim3_t launch_dim;
  launch_dim.x = 4;
  launch_dim.y = 1;
  launch_dim.z = 1;

  KERNEL_CHECK(MLUKernelMyCnrtMemcpy3D<<<launch_dim, launch_type, queue>>>(
      /*batch=*/1, /*m=*/1, /*n=*/n, dA, /*ldda=*/n, /*stride_a=*/0, dB,
      /*lddb=*/n, /*stride_b=*/0, mode));

  return MLUOP_STATUS_SUCCESS;
}
21 changes: 13 additions & 8 deletions kernels/cholesky/complex_cholesky_union1.mlu
Original file line number Diff line number Diff line change
Expand Up @@ -664,14 +664,19 @@ mluOpStatus_t cgemm_real(int batch, bool trans_a, bool trans_b, int m, int n,
int copy_lda = k;
int copy_stride_a = m * k;

for (int i = 0; i < batch; i++) {
CNRT_CHECK(cnrtMemcpy2D(copy_ra + i * m * k, k * sizeof(float),
d_ra + i * stride_a, lda * sizeof(float),
k * sizeof(float), m, CNRT_MEM_TRANS_DIR_DEV2DEV));
CNRT_CHECK(cnrtMemcpy2D(copy_ia + i * m * k, k * sizeof(float),
d_ia + i * stride_a, lda * sizeof(float),
k * sizeof(float), m, CNRT_MEM_TRANS_DIR_DEV2DEV));
}
cnrtDim3_t dim1;
dim1.y = 1;
dim1.z = 1;
dim1.x = 4 * batch;

cnrtFunctionType_t func_type1 = CNRT_FUNC_TYPE_UNION1;


KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, m, k,
d_ra, lda, stride_a, copy_ra, k, m * k, 0);

KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, m, k,
d_ia, lda, stride_a, copy_ia, k, m * k, 0);

float *r_c, *i_c;
r_c = d_rc;
Expand Down

0 comments on commit a9c9db6

Please sign in to comment.