From f03925010d11f038d0ee45710f8165632502dc3b Mon Sep 17 00:00:00 2001 From: Critsium Date: Wed, 23 Oct 2024 13:40:08 -0400 Subject: [PATCH] Feature: Porting abacus to DSP hardware (mtblas part) (#5301) * Link mtblas library * Add mtblas gemm kernel usage * Finish memory_op on dsp * Update CMakeLists * Add compilation script * Fix warnings * Fix install script * Initialize DSP hardware * Replace gemm in math_kernel * Fix CMakeLists Bug * Fix bugs #1 * Fix bug 2 * Fix link to shared library error * Stop use gemm_mt globally * Modify op usage * Fix bug * Fix template usage * Fix compilation * Replace all dav_subspace gemm kernels --------- Co-authored-by: Mohan Chen --- CMakeLists.txt | 12 ++++ install_dsp.sh | 10 +++ source/module_base/blas_connector.cpp | 45 +++++++++++-- .../module_base/kernels/dsp/dsp_connector.h | 66 +++++++++++++++++++ .../module_base/module_device/memory_op.cpp | 15 +++++ source/module_base/module_device/types.h | 1 + source/module_esolver/esolver_ks_pw.cpp | 13 +++- source/module_hsolver/diago_dav_subspace.cpp | 42 ++++++++++-- .../module_hsolver/kernels/math_kernel_op.cpp | 28 ++++++++ .../module_hsolver/kernels/math_kernel_op.h | 30 +++++++++ 10 files changed, 251 insertions(+), 11 deletions(-) create mode 100644 install_dsp.sh create mode 100644 source/module_base/kernels/dsp/dsp_connector.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 637aa95d3e..62dfd41073 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF) option(ENABLE_CNPY "Enable cnpy usage." OFF) option(ENABLE_PEXSI "Enable support for PEXSI." OFF) option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF) +option(USE_DSP "Enable DSP usage." OFF) # enable json support if(ENABLE_RAPIDJSON) @@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI) set(ABACUS_BIN_NAME abacus_serial) endif() +if (USE_DSP) + set(USE_ELPA OFF) + set(ENABLE_LCAO OFF) + set(ABACUS_BIN_NAME abacus_dsp) +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) if(ENABLE_COVERAGE) @@ -240,6 +247,11 @@ if(ENABLE_MPI) list(APPEND math_libs MPI::MPI_CXX) endif() +if (USE_DSP) + target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY}) + add_compile_definitions(__DSP) +endif() + find_package(Threads REQUIRED) target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads) diff --git a/install_dsp.sh b/install_dsp.sh new file mode 100644 index 0000000000..7ae2f48ffa --- /dev/null +++ b/install_dsp.sh @@ -0,0 +1,10 @@ +CXX=mpicxx \ + cmake -B build \ + -DUSE_DSP=ON \ + -DENABLE_LCAO=OFF \ + -DFFTW3_DIR=/vol8/appsoftware/fftw/ \ + -DFFTW3_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3.so \ + -DFFTW3_OMP_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3_omp.so \ + -DFFTW3_FLOAT_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3f.so \ + -DLAPACK_DIR=/vol8/appsoftware/openblas/0.3.21/lib \ + -DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus.so \ No newline at end of file diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 8da2b802fa..075e4df297 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -1,5 +1,9 @@ #include "blas_connector.h" +#ifdef __DSP +#include "module_base/kernels/dsp/dsp_connector.h" +#endif + void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) { if (device_type == base_device::AbacusDevice_t::CpuDevice) { @@ -64,6 +68,7 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return sdot_(&n, X, &incX, Y, &incY); + return sdot_(&n, X, &incX, Y, &incY); } } @@ -71,6 +76,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return ddot_(&n, X, &incX, Y, &incY); + return ddot_(&n, X, &incX, Y, &incY); } } @@ -83,7 +89,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons sgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + sgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -94,7 +107,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons dgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice){ + dgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -105,7 +125,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons cgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + cgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k, @@ -116,7 +143,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons zgemm_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc); -} + } + #ifdef __DSP + else if (device_type == base_device::AbacusDevice_t::DspDevice) { + zgemm_mt_(&transb, &transa, &n, &m, &k, + &alpha, b, &ldb, a, &lda, + &beta, c, &ldc); + } + #endif } void BlasConnector::gemv(const char trans, const int m, const int n, @@ -152,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return snrm2_( &n, X, &incX ); + return snrm2_( &n, X, &incX ); } } @@ -160,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dnrm2_( &n, X, &incX ); + return dnrm2_( &n, X, &incX ); } } @@ -168,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex *X, const in { if (device_type == base_device::AbacusDevice_t::CpuDevice) { return dznrm2_( &n, X, &incX ); + return dznrm2_( &n, X, &incX ); } } diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h new file mode 100644 index 0000000000..2d3075fcd1 --- /dev/null +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -0,0 +1,66 @@ +#ifndef DSP_CONNECTOR_H +#define DSP_CONNECTOR_H +#ifdef __DSP + +// Base dsp functions +void dspInitHandle(int id); +void dspDestoryHandle(); +void *malloc_ht(size_t bytes); +void free_ht(void* ptr); + + +// mtblas functions + +void sgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const float *beta, + float *c, const int *ldc); + +void dgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha,const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + double *c, const int *ldc); + +void zgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +void cgemm_mt_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + + +void sgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const float *alpha, const float *a, const int *lda, + const float *b, const int *ldb, const float *beta, + float *c, const int *ldc); + +void dgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const double *alpha,const double *a, const int *lda, + const double *b, const int *ldb, const double *beta, + double *c, const int *ldc); + +void zgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +void cgemm_mth_(const char *transa, const char *transb, + const int *m, const int *n, const int *k, + const std::complex *alpha, const std::complex *a, const int *lda, + const std::complex *b, const int *ldb, const std::complex *beta, + std::complex *c, const int *ldc); + +//#define zgemm_ zgemm_mt + +#endif +#endif \ No newline at end of file diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 1edc05b8fd..625b535051 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -2,6 +2,9 @@ #include "module_base/memory.h" #include "module_base/tool_threading.h" +#ifdef __DSP +#include "module_base/kernels/dsp/dsp_connector.h" +#endif #include #include @@ -18,9 +21,17 @@ struct resize_memory_op { if (arr != nullptr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } +#ifdef __DSP + arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size); +#else arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size); +#endif std::string record_string; if (record_in != nullptr) { @@ -92,7 +103,11 @@ struct delete_memory_op { void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr) { +#ifdef __DSP + free_ht(arr); +#else free(arr); +#endif } }; diff --git a/source/module_base/module_device/types.h b/source/module_base/module_device/types.h index dfa960a1e3..153b6ab8ca 100644 --- a/source/module_base/module_device/types.h +++ b/source/module_base/module_device/types.h @@ -12,6 +12,7 @@ enum AbacusDevice_t UnKnown, CpuDevice, GpuDevice, + DspDevice }; } // namespace base_device diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index cd9dd4ce66..bf6c0bc450 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -49,6 +49,10 @@ #include #include +#ifdef __DSP +#include "module_base/kernels/dsp/dsp_connector.h" +#endif + namespace ModuleESolver { @@ -67,6 +71,10 @@ ESolver_KS_PW::ESolver_KS_PW() container::kernels::createGpuSolverHandle(); } #endif +#ifdef __DSP + std::cout << " ** Initializing DSP Hardware..." << std::endl; + dspInitHandle(GlobalV::MY_RANK % 4); +#endif } template @@ -92,7 +100,10 @@ ESolver_KS_PW::~ESolver_KS_PW() #endif delete reinterpret_cast*>(this->kspw_psi); } - +#ifdef __DSP + std::cout << " ** Closing DSP Hardware..." << std::endl; + dspDestoryHandle(); +#endif if (PARAM.inp.precision == "single") { delete reinterpret_cast, Device>*>(this->__kspw_psi); diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index 1bfd0a73a1..7d298be7ac 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -181,7 +181,12 @@ int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func, // updata eigenvectors of Hamiltonian setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, @@ -262,7 +267,12 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } } - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, @@ -302,7 +312,12 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, delmem_real_op()(this->ctx, e_temp_hd); } - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, @@ -386,7 +401,12 @@ void Diago_DavSubspace::cal_elem(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem"); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'C', 'N', nbase + notconv, @@ -401,7 +421,12 @@ void Diago_DavSubspace::cal_elem(const int& dim, &hcc[nbase * this->nbase_x], this->nbase_x); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'C', 'N', nbase + notconv, @@ -603,7 +628,12 @@ void Diago_DavSubspace::refresh(const int& dim, { ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); - gemm_op()(this->ctx, +#ifdef __DSP + gemm_op_mt() +#else + gemm_op() +#endif + (this->ctx, 'N', 'N', this->dim, diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_hsolver/kernels/math_kernel_op.cpp index 3ad19bd4cc..02deb41696 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_hsolver/kernels/math_kernel_op.cpp @@ -281,6 +281,30 @@ struct gemm_op } }; +#ifdef __DSP +template +struct gemm_op_mt +{ + void operator()(const base_device::DEVICE_CPU* /*ctx*/, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const T* alpha, + const T* a, + const int& lda, + const T* b, + const int& ldb, + const T* beta, + T* c, + const int& ldc) + { + BlasConnector::gemm(transb, transa, n, m, k, *alpha, b, ldb, a, lda, *beta, c, ldc, base_device::AbacusDevice_t::DspDevice); + } +}; +#endif + template struct matrixTranspose_op { @@ -372,4 +396,8 @@ template struct matrixTranspose_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif +#ifdef __DSP +template struct gemm_op_mt, base_device::DEVICE_CPU>; +template struct gemm_op_mt, base_device::DEVICE_CPU>; +#endif } // namespace hsolver \ No newline at end of file diff --git a/source/module_hsolver/kernels/math_kernel_op.h b/source/module_hsolver/kernels/math_kernel_op.h index a23c9c329f..0daf0e5718 100644 --- a/source/module_hsolver/kernels/math_kernel_op.h +++ b/source/module_hsolver/kernels/math_kernel_op.h @@ -264,6 +264,36 @@ template struct gemm_op { const T *beta, T *c, const int &ldc); }; +#ifdef __DSP +// compute C = alpha * op(A) * op(B) + beta * C on DSP Hardware +template struct gemm_op_mt { + /// @brief C = alpha * op(A) * op(B) + beta * C + /// + /// Input Parameters + /// \param d : the type of computing device + /// \param transa : whether to transpose matrix A + /// \param transb : whether to transpose matrix B + /// \param m : first dimension of matrix mulplication + /// \param n : second dimension of matrix mulplication + /// \param k : third dimension of matrix mulplication + /// \param alpha : input constant alpha + /// \param a : input matrix A + /// \param lda : leading dimention of A + /// \param b : input matrix B + /// \param ldb : leading dimention of A + /// \param beta : input constant beta + /// \param c : input matrix C + /// \param ldc : leading dimention of C + /// + /// Output Parameters + /// \param c : output matrix C + void operator()(const Device *d, const char &transa, const char &transb, + const int &m, const int &n, const int &k, const T *alpha, + const T *a, const int &lda, const T *b, const int &ldb, + const T *beta, T *c, const int &ldc); +}; +#endif + template struct matrixTranspose_op { /// @brief transpose the input matrix ///