Skip to content

Commit

Permalink
Feature: add new init_chg method with wavefunctions (#5082)
Browse files Browse the repository at this point in the history
* Feature: add init_chg: wfc

* add symmetry for init_chg: wfc

* [pre-commit.ci lite] apply automatic fixes

* fix compile failure

* [pre-commit.ci lite] apply automatic fixes

* fix bugs

* add tests for read_wfc_to_rho

* [pre-commit.ci lite] apply automatic fixes

* optimize codes

* [pre-commit.ci lite] apply automatic fixes

* fix bugs

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
Qianruipku and pre-commit-ci-lite[bot] authored Sep 14, 2024
1 parent 548ac3c commit c020e20
Show file tree
Hide file tree
Showing 46 changed files with 1,080 additions and 631 deletions.
1 change: 1 addition & 0 deletions docs/advanced/input_files/input-main.md
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ These variables are used to control general system parameters.

- atomic: the density is starting from the summation of the atomic density of single atoms.
- file: the density will be read in from a binary file `charge-density.dat` first. If it does not exist, the charge density will be read in from cube files. Besides, when you do `nspin=1` calculation, you only need the density file SPIN1_CHG.cube. However, if you do `nspin=2` calculation, you also need the density file SPIN2_CHG.cube. The density file should be output with these names if you set out_chg = 1 in INPUT file.
- wfc: the density will be calculated by wavefunctions and occupations. Wavefunctions are read in from binary files `WAVEFUNC*.dat` while occupations are read in from file `istate.info`.
- auto: Abacus first attempts to read the density from a file; if not found, it defaults to using atomic density.
- **Default**: atomic

Expand Down
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ OBJS_IO=input_conv.o\
rhog_io.o\
read_exit_file.o\
read_wfc_pw.o\
read_wfc_to_rho.o\
restart.o\
binstream.o\
to_wannier90.o\
Expand Down
14 changes: 14 additions & 0 deletions source/module_cell/klist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ K_Vectors::~K_Vectors()
#endif
}

int K_Vectors::get_ik_global(const int& ik, const int& nkstot)
{
int nkp = nkstot / PARAM.inp.kpar;
int rem = nkstot % PARAM.inp.kpar;
if (GlobalV::MY_POOL < rem)
{
return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik;
}
else
{
return GlobalV::MY_POOL * nkp + rem + ik;
}
}

void K_Vectors::set(const ModuleSymmetry::Symmetry& symm,
const std::string& k_file_name,
const int& nspin_in,
Expand Down
18 changes: 2 additions & 16 deletions source/module_cell/klist.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class K_Vectors
* This function gets the global index of a k-point based on its local index and the process pool ID.
* The global index is used when the k-points are distributed among multiple process pools.
*
* @param nkstot The total number of k-points.
* @param ik The local index of the k-point.
*
* @return int Returns the global index of the k-point.
Expand All @@ -116,7 +117,7 @@ class K_Vectors
* process pools (KPAR), and adding the remainder if the process pool ID (MY_POOL) is less than the remainder.
* @note The function is declared as inline for efficiency.
*/
inline int getik_global(const int& ik) const;
static int get_ik_global(const int& ik, const int& nkstot);

int get_nks() const
{
Expand Down Expand Up @@ -389,19 +390,4 @@ class K_Vectors
*/
void print_klists(std::ofstream& fn);
};

inline int K_Vectors::getik_global(const int& ik) const
{
int nkp = this->nkstot / GlobalV::KPAR;
int rem = this->nkstot % GlobalV::KPAR;
if (GlobalV::MY_POOL < rem)
{
return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik;
}
else
{
return GlobalV::MY_POOL * nkp + rem + ik;
}
}

#endif // KVECT_H
4 changes: 2 additions & 2 deletions source/module_elecstate/elecstate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ void ElecState::calEBand()
return;
}

void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac)
void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, ModuleSymmetry::Symmetry& symm, const void* wfcpw)
{
//---------Charge part-----------------
// core correction potential.
Expand All @@ -225,7 +225,7 @@ void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& struc
//--------------------------------------------------------------------
if (istep == 0)
{
this->charge->init_rho(this->eferm, strucfac, this->bigpw->nbz, this->bigpw->bz);
this->charge->init_rho(this->eferm, strucfac, symm, (const void*)this->klist, wfcpw);
this->charge->check_rho(); // check the rho
}

Expand Down
13 changes: 12 additions & 1 deletion source/module_elecstate/elecstate.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,18 @@ class ElecState
return;
}

void init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac);
/**
* @brief Init rho_core, init rho, renormalize rho, init pot
*
* @param istep i-th step
* @param strucfac structure factor
* @param symm symmetry
* @param wfcpw PW basis for wave function if needed
*/
void init_scf(const int istep,
const ModuleBase::ComplexMatrix& strucfac,
ModuleSymmetry::Symmetry& symm,
const void* wfcpw = nullptr);
std::string classname = "elecstate";

int iter = 0; ///< scf iteration
Expand Down
30 changes: 17 additions & 13 deletions source/module_elecstate/module_charge/charge.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "module_base/parallel_global.h"
#include "module_basis/module_pw/pw_basis.h"
#include "module_elecstate/fp_energy.h"
#include "module_cell/module_symmetry/symmetry.h"

//a forward declaration of UnitCell
class UnitCell;
Expand Down Expand Up @@ -62,14 +63,19 @@ class Charge

/**
* @brief Init charge density from file or atomic pseudo-wave-functions
*
* @param eferm_iout fermi energy to be initialized
* @param strucFac [in] structure factor
* @param nbz [in] number of big grids in z direction
* @param bz [in] number of small grids in big grids for z dirction
*
* @param eferm_iout [out] fermi energy to be initialized
* @param strucFac [in] structure factor
* @param symm [in] symmetry
* @param klist [in] k points list if needed
* @param wfcpw [in] PW basis for wave function if needed
*/
void init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz);

void init_rho(elecstate::efermi& eferm_iout,
const ModuleBase::ComplexMatrix& strucFac,
ModuleSymmetry::Symmetry& symm,
const void* klist = nullptr,
const void* wfcpw = nullptr);

void allocate(const int &nspin_in);

void atomic_rho(const int spin_number_need,
Expand All @@ -81,11 +87,11 @@ class Charge
void set_rho_core(const ModuleBase::ComplexMatrix &structure_factor);
void set_rho_core_paw();

void renormalize_rho(void);
void renormalize_rho();

double sum_rho(void) const;
double sum_rho() const;

void save_rho_before_sum_band(void);
void save_rho_before_sum_band();

// for non-linear core correction
void non_linear_core_correction
Expand All @@ -108,10 +114,8 @@ class Charge
/**
* @brief init some arrays for mpi_inter_pools, rho_mpi
*
* @param nbz number of bigz in big grids
* @param bz number of z for each bigz
*/
void init_chgmpi(const int& nbz, const int& bz);
void init_chgmpi();

/**
* @brief Sum rho at different pools (k-point parallelism).
Expand Down
21 changes: 19 additions & 2 deletions source/module_elecstate/module_charge/charge_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h"
#include "module_io/rho_io.h"
#include "module_io/rhog_io.h"
#include "module_io/read_wfc_to_rho.h"
#ifdef USE_PAW
#include "module_cell/module_paw/paw_cell.h"
#endif

void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz)
void Charge::init_rho(elecstate::efermi& eferm_iout,
const ModuleBase::ComplexMatrix& strucFac,
ModuleSymmetry::Symmetry& symm,
const void* klist,
const void* wfcpw)
{
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "init_chg", PARAM.inp.init_chg);

Expand Down Expand Up @@ -195,8 +200,20 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMa
GlobalC::restart.info_load.load_charge_finish = true;
}
#ifdef __MPI
this->init_chgmpi(nbz, bz);
this->init_chgmpi();
#endif
if (PARAM.inp.init_chg == "wfc")
{
if (wfcpw == nullptr)
{
ModuleBase::WARNING_QUIT("Charge::init_rho", "wfc is only supported for PW-KSDFT.");
}
const ModulePW::PW_Basis_K* pw_wfc = reinterpret_cast<ModulePW::PW_Basis_K*>(const_cast<void*>(wfcpw));
const K_Vectors* kv = reinterpret_cast<const K_Vectors*>(klist);
const int nkstot = kv->get_nkstot();
const std::vector<int>& isk = kv->isk;
ModuleIO::read_wfc_to_rho(pw_wfc, symm, nkstot, isk, *this);
}
}

//==========================================================
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/module_charge/charge_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "module_elecstate/elecstate_getters.h"
#include "module_parameter/parameter.h"
#ifdef __MPI
void Charge::init_chgmpi(const int& nbz, const int& bz)
void Charge::init_chgmpi()
{
if (GlobalV::NPROC_IN_STOGROUP % GlobalV::KPAR == 0)
{
Expand Down
10 changes: 4 additions & 6 deletions source/module_elecstate/module_charge/symmetry_rho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Symmetry_rho::~Symmetry_rho()
void Symmetry_rho::begin(const int& spin_now,
const Charge& CHR,
const ModulePW::PW_Basis* rho_basis,
Parallel_Grid& Pgrid,
ModuleSymmetry::Symmetry& symm) const
{
assert(spin_now < 4); // added by zhengdy-soc
Expand All @@ -31,15 +30,15 @@ void Symmetry_rho::begin(const int& spin_now,
// else //space group, do rho_symm in reciprocal space
{
rho_basis->real2recip(CHR.rho[spin_now], CHR.rhog[spin_now]);
psymmg(CHR.rhog[spin_now], rho_basis, Pgrid, symm); // need to modify
psymmg(CHR.rhog[spin_now], rho_basis, symm); // need to modify
rho_basis->recip2real(CHR.rhog[spin_now], CHR.rho[spin_now]);

if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
{
// Use std::vector to manage kin_g instead of raw pointer
std::vector<std::complex<double>> kin_g(CHR.ngmc);
rho_basis->real2recip(CHR.kin_r[spin_now], kin_g.data());
psymmg(kin_g.data(), rho_basis, Pgrid, symm);
psymmg(kin_g.data(), rho_basis, symm);
rho_basis->recip2real(kin_g.data(), CHR.kin_r[spin_now]);
}
}
Expand All @@ -52,7 +51,6 @@ void Symmetry_rho::begin(const int& spin_now,
int ngmc,
double** kin_r,
const ModulePW::PW_Basis* rho_basis,
Parallel_Grid& Pgrid,
ModuleSymmetry::Symmetry& symm) const
{
assert(spin_now < 4); // added by zhengdy-soc
Expand All @@ -71,15 +69,15 @@ void Symmetry_rho::begin(const int& spin_now,
// else //space group, do rho_symm in reciprocal space
{
rho_basis->real2recip(rho[spin_now], rhog[spin_now]);
psymmg(rhog[spin_now], rho_basis, Pgrid, symm);
psymmg(rhog[spin_now], rho_basis, symm);
rho_basis->recip2real(rhog[spin_now], rho[spin_now]);

if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5)
{
// Use std::vector to manage kin_g instead of raw pointer
std::vector<std::complex<double>> kin_g(ngmc);
rho_basis->real2recip(kin_r[spin_now], kin_g.data());
psymmg(kin_g.data(), rho_basis, Pgrid, symm);
psymmg(kin_g.data(), rho_basis, symm);
rho_basis->recip2real(kin_g.data(), kin_r[spin_now]);
}
}
Expand Down
3 changes: 0 additions & 3 deletions source/module_elecstate/module_charge/symmetry_rho.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class Symmetry_rho
void begin(const int& spin_now,
const Charge& CHR,
const ModulePW::PW_Basis* pw,
Parallel_Grid& Pgrid,
ModuleSymmetry::Symmetry& symm) const;

void begin(const int& spin_now,
Expand All @@ -23,7 +22,6 @@ class Symmetry_rho
int ngmc,
double** kin_r,
const ModulePW::PW_Basis* pw,
Parallel_Grid& Pgrid,
ModuleSymmetry::Symmetry& symm) const;

private:
Expand All @@ -35,7 +33,6 @@ class Symmetry_rho
// in reciprocal space:
void psymmg(std::complex<double>* rhog_part,
const ModulePW::PW_Basis* rho_basis,
Parallel_Grid& Pgrid,
ModuleSymmetry::Symmetry& symm) const;
#ifdef __MPI
void reduce_to_fullrhog(const ModulePW::PW_Basis* rho_basis,
Expand Down
2 changes: 1 addition & 1 deletion source/module_elecstate/module_charge/symmetry_rhog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "module_hamilt_general/module_xc/xc_functional.h"


void Symmetry_rho::psymmg(std::complex<double>* rhog_part, const ModulePW::PW_Basis *rho_basis, Parallel_Grid &Pgrid, ModuleSymmetry::Symmetry &symm) const
void Symmetry_rho::psymmg(std::complex<double>* rhog_part, const ModulePW::PW_Basis *rho_basis, ModuleSymmetry::Symmetry &symm) const
{
//(1) get fftixy2is and do Allreduce
int * fftixy2is = new int [rho_basis->fftnxy];
Expand Down
9 changes: 7 additions & 2 deletions source/module_elecstate/test/elecstate_base_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&)
void Charge::set_rho_core_paw()
{
}
void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&)
void Charge::init_rho(elecstate::efermi&,
ModuleBase::ComplexMatrix const&,
ModuleSymmetry::Symmetry& symm,
const void*,
const void*)
{
}
void Charge::set_rhopw(ModulePW::PW_Basis*)
Expand Down Expand Up @@ -390,7 +394,8 @@ TEST_F(ElecStateTest, InitSCF)
int istep = 0;
ModuleBase::ComplexMatrix strucfac;
elecstate->eferm = efermi;
EXPECT_NO_THROW(elecstate->init_scf(istep, strucfac));
ModuleSymmetry::Symmetry symm;
EXPECT_NO_THROW(elecstate->init_scf(istep, strucfac, symm));
// delete elecstate->pot is done in the destructor of elecstate
delete charge;
}
Expand Down
6 changes: 5 additions & 1 deletion source/module_elecstate/test/elecstate_pw_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&)
void Charge::set_rho_core_paw()
{
}
void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&)
void Charge::init_rho(elecstate::efermi&,
ModuleBase::ComplexMatrix const&,
ModuleSymmetry::Symmetry& symm,
const void*,
const void*)
{
}
void Charge::set_rhopw(ModulePW::PW_Basis*)
Expand Down
6 changes: 3 additions & 3 deletions source/module_elecstate/test_mpi/charge_mpi_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1)
}
double refsum = sum_array(array_rho, nrxx);

charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->reduce_diff_pools(array_rho);
double sum = sum_array(array_rho, nrxx);
EXPECT_EQ(sum, refsum * GlobalV::KPAR);
Expand Down Expand Up @@ -150,7 +150,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2)
}
}

charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->reduce_diff_pools(array_rho);
double sum = sum_array(array_rho, nrxx);
MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD);
Expand Down Expand Up @@ -194,7 +194,7 @@ TEST_F(ChargeMpiTest, rho_mpi)
charge->nrxx = nrxx;
charge->rho[0] = new double[nrxx];
charge->kin_r[0] = new double[nrxx];
charge->init_chgmpi(nz, 1);
charge->init_chgmpi();
charge->rho_mpi();

delete[] charge->rho[0];
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ void ESolver_KS_LCAO<TK, TR>::hamilt2density(int istep, int iter, double ethr)
Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm);
srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::ucell.symm);
}

// 11) compute magnetization, only for spin==2
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcao_tddft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons
Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
srho.begin(is, *(pelec->charge), pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm);
srho.begin(is, *(pelec->charge), pw_rho, GlobalC::ucell.symm);
}
}

Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ namespace ModuleESolver
Symmetry_rho srho;
for (int is = 0; is < GlobalV::NSPIN; is++)
{
srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::Pgrid, GlobalC::ucell.symm);
srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::ucell.symm);
}

// compute magnetization, only for LSDA(spin==2)
Expand Down
Loading

0 comments on commit c020e20

Please sign in to comment.