Skip to content

Commit

Permalink
Refactor: remove the inheritance in Gint (#4462)
Browse files Browse the repository at this point in the history
* remove inheritance in Gint

* fix cuda

* fix format

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
maki49 and pre-commit-ci-lite[bot] authored Jun 22, 2024
1 parent 8680d17 commit f89e17d
Show file tree
Hide file tree
Showing 9 changed files with 1,513 additions and 1,306 deletions.
976 changes: 563 additions & 413 deletions source/module_hamilt_lcao/module_gint/gint.cpp

Large diffs are not rendered by default.

356 changes: 171 additions & 185 deletions source/module_hamilt_lcao/module_gint/gint.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/module_gint/gint_force_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void gint_fvl_gamma_gpu(hamilt::HContainer<double>* dm,
double* force_in,
double* stress_in,
double dr,
double* rcut,
const double* rcut,
const int isforce,
const int isstress,
const Grid_Technique& gridt,
Expand Down
22 changes: 11 additions & 11 deletions source/module_hamilt_lcao/module_gint/gint_force_gpu.h
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#ifndef GINT_FORCE_GPU_H
#define GINT_FORCE_GPU_H
#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_MODULE_GINT_GINT_FORCE_GPU_H
#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_HAMILT_LCAO_MODULE_GINT_GINT_FORCE_GPU_H

#include "module_hamilt_lcao/module_gint/gint.h"
#include "module_hamilt_lcao/module_gint/grid_technique.h"
namespace GintKernel
{
void gint_fvl_gamma_gpu(hamilt::HContainer<double>* dm,
const double* vlocal,
double* force_in,
double* stress_in,
double dr,
double* rcut,
const int isforce,
const int isstress,
const Grid_Technique& gridt,
const UnitCell& ucell);
const double* vlocal,
double* force_in,
double* stress_in,
double dr,
const double* rcut,
const int isforce,
const int isstress,
const Grid_Technique& gridt,
const UnitCell& ucell);

/**
* @brief GPU task generator for forces.
Expand Down
149 changes: 82 additions & 67 deletions source/module_hamilt_lcao/module_gint/gint_gamma_env.cpp
Original file line number Diff line number Diff line change
@@ -1,87 +1,102 @@
#include "gint_gamma.h"
#include "grid_technique.h"
#include "module_base/timer.h"
#include "module_base/ylm.h"
#include "module_basis/module_ao/ORB_read.h"
#include "module_hamilt_pw/hamilt_pwdft/global.h"
#include "module_base/ylm.h"
#include "module_base/timer.h"

void Gint_Gamma::cal_env(const double* wfc, double* rho,UnitCell &ucell)
void Gint_Gamma::cal_env(const double* wfc, double* rho, UnitCell& ucell)
{
ModuleBase::TITLE("Grid_Integral","cal_env");
ModuleBase::TITLE("Grid_Integral", "cal_env");

// it's a uniform grid to save orbital values, so the delta_r is a constant.
const double delta_r = this->dr_uniform;
const int max_size = this->gridt->max_atom;
const int LD_pool = max_size*ucell.nwmax;
const double delta_r = this->gridt->dr_uniform;
const int max_size = this->gridt->max_atom;
const int LD_pool = max_size * ucell.nwmax;

if(max_size!=0)
{
const int nbx = this->gridt->nbx;
const int nby = this->gridt->nby;
const int nbz_start = this->gridt->nbzp_start;
const int nbz = this->gridt->nbzp;
const int ncyz = this->ny*this->nplane; // mohan add 2012-03-25
if (max_size != 0)
{
const int nbx = this->gridt->nbx;
const int nby = this->gridt->nby;
const int nbz_start = this->gridt->nbzp_start;
const int nbz = this->gridt->nbzp;
const int ncyz = this->ny * this->nplane; // mohan add 2012-03-25

for(int grid_index = 0; grid_index < this->nbxx; grid_index++)
for (int grid_index = 0; grid_index < this->nbxx; grid_index++)
{

// get the value: how many atoms has orbital value on this grid.
const int size = this->gridt->how_many_atoms[ grid_index ];
if(size==0) continue;
// get the value: how many atoms has orbital value on this grid.
const int size = this->gridt->how_many_atoms[grid_index];
if (size == 0)
continue;

int * block_iw, * block_index, * block_size;
bool** cal_flag;
Gint_Tools::get_block_info(*this->gridt, this->bxyz, size, grid_index, block_iw, block_index, block_size, cal_flag);
int *block_iw, *block_index, *block_size;
bool** cal_flag;
Gint_Tools::get_block_info(*this->gridt,
this->bxyz,
size,
grid_index,
block_iw,
block_index,
block_size,
cal_flag);

//evaluate psi on grids
Gint_Tools::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
Gint_Tools::cal_psir_ylm(*this->gridt,
this->bxyz,
size, grid_index, delta_r,
block_index, block_size,
cal_flag,
psir_ylm.ptr_2D);
// evaluate psi on grids
Gint_Tools::Array_Pool<double> psir_ylm(this->bxyz, LD_pool);
Gint_Tools::cal_psir_ylm(*this->gridt,
this->bxyz,
size,
grid_index,
delta_r,
block_index,
block_size,
cal_flag,
psir_ylm.ptr_2D);

int* vindex = Gint_Tools::get_vindex(this->bxyz, this->bx, this->by, this->bz,
this->nplane, this->gridt->start_ind[grid_index], ncyz);
int* vindex = Gint_Tools::get_vindex(this->bxyz,
this->bx,
this->by,
this->bz,
this->nplane,
this->gridt->start_ind[grid_index],
ncyz);

for (int ia1=0; ia1<size; ia1++)
{
const int mcell_index1 = this->gridt->bcell_start[grid_index] + ia1;
const int iat = this->gridt->which_atom[mcell_index1];
const int T1 = ucell.iat2it[iat];
Atom *atom1 = &ucell.atoms[T1];
const int I1 = ucell.iat2ia[iat];
// get the start index of local orbitals.
const int start1 = ucell.itiaiw2iwt(T1, I1, 0);
for (int ib=0; ib<this->bxyz; ib++)
{
if(cal_flag[ib][ia1])
{
int iw1_lo = this->gridt->trace_lo[start1];
double* psi1 = &psir_ylm.ptr_2D[ib][block_index[ia1]];
double tmp = 0.0;
for (int iw=0; iw< atom1->nw; ++iw, ++iw1_lo)
{
tmp += psi1[iw] * wfc[iw1_lo];
}//iw
rho[ vindex[ib] ] += tmp;
}// cal_flag
}//ib
}// ia1
for (int ia1 = 0; ia1 < size; ia1++)
{
const int mcell_index1 = this->gridt->bcell_start[grid_index] + ia1;
const int iat = this->gridt->which_atom[mcell_index1];
const int T1 = ucell.iat2it[iat];
Atom* atom1 = &ucell.atoms[T1];
const int I1 = ucell.iat2ia[iat];
// get the start index of local orbitals.
const int start1 = ucell.itiaiw2iwt(T1, I1, 0);
for (int ib = 0; ib < this->bxyz; ib++)
{
if (cal_flag[ib][ia1])
{
int iw1_lo = this->gridt->trace_lo[start1];
double* psi1 = &psir_ylm.ptr_2D[ib][block_index[ia1]];
double tmp = 0.0;
for (int iw = 0; iw < atom1->nw; ++iw, ++iw1_lo)
{
tmp += psi1[iw] * wfc[iw1_lo];
} // iw
rho[vindex[ib]] += tmp;
} // cal_flag
} // ib
} // ia1

delete[] vindex;
delete[] block_iw;
delete[] block_index;
delete[] block_size;
for(int ib=0; ib<this->bxyz; ++ib)
{
delete[] cal_flag[ib];
}
delete[] cal_flag;
}
}
delete[] vindex;
delete[] block_iw;
delete[] block_index;
delete[] block_size;
for (int ib = 0; ib < this->bxyz; ++ib)
{
delete[] cal_flag[ib];
}
delete[] cal_flag;
}
}

return;
}

Loading

0 comments on commit f89e17d

Please sign in to comment.