From a9c9db62eb6644de8095383ea05ef57f4d804164 Mon Sep 17 00:00:00 2001 From: dglr <2398621969@qq.com> Date: Thu, 5 Dec 2024 03:23:09 +0800 Subject: [PATCH] [Fix](mluOpCholesky): add new memcpy --- kernels/cholesky/cholesky.cpp | 28 ++++----- kernels/cholesky/cholesky.h | 9 +++ kernels/cholesky/cholesky_union1.mlu | 64 ++++++++++++++++++-- kernels/cholesky/complex_cholesky_union1.mlu | 21 ++++--- 4 files changed, 92 insertions(+), 30 deletions(-) diff --git a/kernels/cholesky/cholesky.cpp b/kernels/cholesky/cholesky.cpp index 754aa00a0..f39929f41 100644 --- a/kernels/cholesky/cholesky.cpp +++ b/kernels/cholesky/cholesky.cpp @@ -110,9 +110,8 @@ calculate_body(mluOpHandle_t handle, int batch_size, transpose(batch_size, size_a, size_a, d_input, d_output, handle, dtype, workspace)); } else { - CNRT_CHECK(cnrtMemcpy(d_output, d_input, - type_size * size_a * lda * ((uint64_t)batch_size), - CNRT_MEM_TRANS_DIR_DEV2DEV)); + KernelMyCnrtMemcpy1D(d_input, d_output, + size_a * lda * ((uint64_t)batch_size), queue, 0); } } else { CHECK_RETURN("mluOpCholesky", @@ -165,9 +164,8 @@ calculate_body(mluOpHandle_t handle, int batch_size, transpose(batch_size, size_a, size_a, d_output, workspace, handle, dtype, workspace)); cnrtQueueSync(queue); - CNRT_CHECK(cnrtMemcpy(d_output, workspace, - type_size * size_a * lda * ((uint64_t)batch_size), - CNRT_MEM_TRANS_DIR_DEV2DEV)); + KernelMyCnrtMemcpy1D(workspace, d_output, + size_a * lda * ((uint64_t)batch_size), queue, 0); } } else { recnb = CRECNB; @@ -232,18 +230,14 @@ calculate_body(mluOpHandle_t handle, int batch_size, cnrtQueueSync(queue); } else { if (batch_size > 16) { - CNRT_CHECK(cnrtMemcpy(d_output, workspace, - type_size * size_a * lda * 16, - CNRT_MEM_TRANS_DIR_DEV2DEV)); - CNRT_CHECK( - cnrtMemcpy(d_output + type_size / 4 * size_a * lda * 16, - workspace + type_size / 4 * size_a * lda * 16, - type_size * size_a * lda * ((uint64_t)batch_size - 16), - CNRT_MEM_TRANS_DIR_DEV2DEV)); + KernelMyCnrtMemcpy1D(workspace, + d_output, size_a * lda * 16, queue, 0); + KernelMyCnrtMemcpy1D(workspace + type_size / 4 * size_a * lda * 16, + d_output + type_size / 4 * size_a * lda * 16, + size_a * lda * ((uint64_t)batch_size - 16), queue, 0); } else { - CNRT_CHECK(cnrtMemcpy(d_output, workspace, - type_size * size_a * lda * ((uint64_t)batch_size), - CNRT_MEM_TRANS_DIR_DEV2DEV)); + KernelMyCnrtMemcpy1D(workspace, + d_output, size_a * lda * ((uint64_t)batch_size), queue, 0); } } } diff --git a/kernels/cholesky/cholesky.h b/kernels/cholesky/cholesky.h index b0fa28529..05e8b92d4 100644 --- a/kernels/cholesky/cholesky.h +++ b/kernels/cholesky/cholesky.h @@ -111,4 +111,13 @@ mluOpStatus_t cherk(int batch, int stride, int n, int k, float* rd_a, float* id_a, int lda, float* rd_c, float* id_c, int ldc, mluOpHandle_t handle, float* workspace); + +mluOpStatus_t KernelMyCnrtMemcpy3D( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + int batch, int m, int n, float *dA, int ldda, + int stride_a, float *dB, int lddb, int stride_b, int mode); + +mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy1D( + float *dA, float *dB, int n, cnrtQueue_t queue, int mode); + #endif diff --git a/kernels/cholesky/cholesky_union1.mlu b/kernels/cholesky/cholesky_union1.mlu index af4459ea0..e2ec17a7f 100644 --- a/kernels/cholesky/cholesky_union1.mlu +++ b/kernels/cholesky/cholesky_union1.mlu @@ -1074,11 +1074,16 @@ mluOpStatus_t strsm(int batch, int stride, bool upper, bool trans, int m, int n, cnrtQueueSync(queue); - for (int i = 0; i < batch; i++) { - CNRT_CHECK(cnrtMemcpy2D(d_b + i * stride, ldb * sizeof(float), - temp_result + i * m * n, m * sizeof(float), - m * sizeof(float), n, CNRT_MEM_TRANS_DIR_DEV2DEV)); - } + cnrtDim3_t dim1; + dim1.y = 1; + dim1.z = 1; + dim1.x = 4 * batch; + + cnrtFunctionType_t func_type1 = CNRT_FUNC_TYPE_UNION1; + + KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, n, m, + temp_result, m, m * n, d_b, ldb, stride, 0); + cnrtQueueSync(queue); return MLUOP_STATUS_SUCCESS; @@ -1146,3 +1151,52 @@ mluOpStatus_t spotrf_recursion(int batch, int stride, bool trans, bool uplo, } +__mlu_entry__ void MLUKernelMyCnrtMemcpy3D(int batch, int m, int n, float *dA, + int ldda, int stride_a, float *dB, + int lddb, int stride_b, int mode) { + int id, batch_id, tx; + + if (batch > 1) { + id = taskId; + batch_id = id / 4; + if (batch_id >= batch) return; + tx = taskId % 4; + dA += batch_id * stride_a; + dB += batch_id * stride_b; + // taskdim = TaskUnion1; + } else { + id = taskId; + batch_id = 0; + // taskdim = taskDim; + tx = taskId; + } + if (tx == 0) + __memcpy(dB, dA, n * sizeof(float), GDRAM2GDRAM, lddb * sizeof(float), + ldda * sizeof(float), m - 1); +} + +// dA: src, dB: dst +mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy3D( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + int batch, int m, int n, float *dA, int ldda, + int stride_a, float *dB, int lddb, int stride_b, int mode) { + KERNEL_CHECK(MLUKernelMyCnrtMemcpy3D<<>>( + batch, m, n, dA, ldda, stride_a, dB, lddb, stride_b, mode)); + + return MLUOP_STATUS_SUCCESS; +} + +// dA: src, dB: dst +mluOpStatus_t MLUOP_WIN_API KernelMyCnrtMemcpy1D( + float *dA, float *dB, int n, cnrtQueue_t queue, int mode) { + cnrtDim3_t dim1; + dim1.y = 1; + dim1.z = 1; + dim1.x = 4; + + cnrtFunctionType_t func_type1 = CNRT_FUNC_TYPE_UNION1; + KERNEL_CHECK(MLUKernelMyCnrtMemcpy3D<<>>( + 1, 1, n, dA, n, 0, dB, n, 0, mode)); + + return MLUOP_STATUS_SUCCESS; +} diff --git a/kernels/cholesky/complex_cholesky_union1.mlu b/kernels/cholesky/complex_cholesky_union1.mlu index e5e8b0dfd..462ace56e 100644 --- a/kernels/cholesky/complex_cholesky_union1.mlu +++ b/kernels/cholesky/complex_cholesky_union1.mlu @@ -664,14 +664,19 @@ mluOpStatus_t cgemm_real(int batch, bool trans_a, bool trans_b, int m, int n, int copy_lda = k; int copy_stride_a = m * k; - for (int i = 0; i < batch; i++) { - CNRT_CHECK(cnrtMemcpy2D(copy_ra + i * m * k, k * sizeof(float), - d_ra + i * stride_a, lda * sizeof(float), - k * sizeof(float), m, CNRT_MEM_TRANS_DIR_DEV2DEV)); - CNRT_CHECK(cnrtMemcpy2D(copy_ia + i * m * k, k * sizeof(float), - d_ia + i * stride_a, lda * sizeof(float), - k * sizeof(float), m, CNRT_MEM_TRANS_DIR_DEV2DEV)); - } + cnrtDim3_t dim1; + dim1.y = 1; + dim1.z = 1; + dim1.x = 4 * batch; + + cnrtFunctionType_t func_type1 = CNRT_FUNC_TYPE_UNION1; + + + KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, m, k, + d_ra, lda, stride_a, copy_ra, k, m * k, 0); + + KernelMyCnrtMemcpy3D(dim1, func_type1, queue, batch, m, k, + d_ia, lda, stride_a, copy_ia, k, m * k, 0); float *r_c, *i_c; r_c = d_rc;