diff --git a/docs/advanced/input_files/input-main.md b/docs/advanced/input_files/input-main.md index 2327a04bb1..6bc3589b10 100644 --- a/docs/advanced/input_files/input-main.md +++ b/docs/advanced/input_files/input-main.md @@ -569,6 +569,7 @@ These variables are used to control general system parameters. - atomic: the density is starting from the summation of the atomic density of single atoms. - file: the density will be read in from a binary file `charge-density.dat` first. If it does not exist, the charge density will be read in from cube files. Besides, when you do `nspin=1` calculation, you only need the density file SPIN1_CHG.cube. However, if you do `nspin=2` calculation, you also need the density file SPIN2_CHG.cube. The density file should be output with these names if you set out_chg = 1 in INPUT file. + - wfc: the density will be calculated by wavefunctions and occupations. Wavefunctions are read in from binary files `WAVEFUNC*.dat` while occupations are read in from file `istate.info`. - auto: Abacus first attempts to read the density from a file; if not found, it defaults to using atomic density. - **Default**: atomic diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 01d0665db2..78177fbd48 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -480,6 +480,7 @@ OBJS_IO=input_conv.o\ rhog_io.o\ read_exit_file.o\ read_wfc_pw.o\ + read_wfc_to_rho.o\ restart.o\ binstream.o\ to_wannier90.o\ diff --git a/source/module_cell/klist.cpp b/source/module_cell/klist.cpp index bfed79c027..cd214380ec 100644 --- a/source/module_cell/klist.cpp +++ b/source/module_cell/klist.cpp @@ -45,6 +45,20 @@ K_Vectors::~K_Vectors() #endif } +int K_Vectors::get_ik_global(const int& ik, const int& nkstot) +{ + int nkp = nkstot / PARAM.inp.kpar; + int rem = nkstot % PARAM.inp.kpar; + if (GlobalV::MY_POOL < rem) + { + return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik; + } + else + { + return GlobalV::MY_POOL * nkp + rem + ik; + } +} + void K_Vectors::set(const ModuleSymmetry::Symmetry& symm, const std::string& k_file_name, const int& nspin_in, diff --git a/source/module_cell/klist.h b/source/module_cell/klist.h index 3c03e72ec2..edc20446dc 100644 --- a/source/module_cell/klist.h +++ b/source/module_cell/klist.h @@ -108,6 +108,7 @@ class K_Vectors * This function gets the global index of a k-point based on its local index and the process pool ID. * The global index is used when the k-points are distributed among multiple process pools. * + * @param nkstot The total number of k-points. * @param ik The local index of the k-point. * * @return int Returns the global index of the k-point. @@ -116,7 +117,7 @@ class K_Vectors * process pools (KPAR), and adding the remainder if the process pool ID (MY_POOL) is less than the remainder. * @note The function is declared as inline for efficiency. */ - inline int getik_global(const int& ik) const; + static int get_ik_global(const int& ik, const int& nkstot); int get_nks() const { @@ -389,19 +390,4 @@ class K_Vectors */ void print_klists(std::ofstream& fn); }; - -inline int K_Vectors::getik_global(const int& ik) const -{ - int nkp = this->nkstot / GlobalV::KPAR; - int rem = this->nkstot % GlobalV::KPAR; - if (GlobalV::MY_POOL < rem) - { - return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik; - } - else - { - return GlobalV::MY_POOL * nkp + rem + ik; - } -} - #endif // KVECT_H \ No newline at end of file diff --git a/source/module_elecstate/elecstate.cpp b/source/module_elecstate/elecstate.cpp index 3248cad991..8077abd43a 100644 --- a/source/module_elecstate/elecstate.cpp +++ b/source/module_elecstate/elecstate.cpp @@ -206,7 +206,7 @@ void ElecState::calEBand() return; } -void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac) +void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, ModuleSymmetry::Symmetry& symm, const void* wfcpw) { //---------Charge part----------------- // core correction potential. @@ -225,7 +225,7 @@ void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& struc //-------------------------------------------------------------------- if (istep == 0) { - this->charge->init_rho(this->eferm, strucfac, this->bigpw->nbz, this->bigpw->bz); + this->charge->init_rho(this->eferm, strucfac, symm, (const void*)this->klist, wfcpw); this->charge->check_rho(); // check the rho } diff --git a/source/module_elecstate/elecstate.h b/source/module_elecstate/elecstate.h index 38fcf58e28..20a72faae2 100644 --- a/source/module_elecstate/elecstate.h +++ b/source/module_elecstate/elecstate.h @@ -94,7 +94,18 @@ class ElecState return; } - void init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac); + /** + * @brief Init rho_core, init rho, renormalize rho, init pot + * + * @param istep i-th step + * @param strucfac structure factor + * @param symm symmetry + * @param wfcpw PW basis for wave function if needed + */ + void init_scf(const int istep, + const ModuleBase::ComplexMatrix& strucfac, + ModuleSymmetry::Symmetry& symm, + const void* wfcpw = nullptr); std::string classname = "elecstate"; int iter = 0; ///< scf iteration diff --git a/source/module_elecstate/module_charge/charge.h b/source/module_elecstate/module_charge/charge.h index 71fdc25120..6ea9ac5a20 100644 --- a/source/module_elecstate/module_charge/charge.h +++ b/source/module_elecstate/module_charge/charge.h @@ -7,6 +7,7 @@ #include "module_base/parallel_global.h" #include "module_basis/module_pw/pw_basis.h" #include "module_elecstate/fp_energy.h" +#include "module_cell/module_symmetry/symmetry.h" //a forward declaration of UnitCell class UnitCell; @@ -62,14 +63,19 @@ class Charge /** * @brief Init charge density from file or atomic pseudo-wave-functions - * - * @param eferm_iout fermi energy to be initialized - * @param strucFac [in] structure factor - * @param nbz [in] number of big grids in z direction - * @param bz [in] number of small grids in big grids for z dirction + * + * @param eferm_iout [out] fermi energy to be initialized + * @param strucFac [in] structure factor + * @param symm [in] symmetry + * @param klist [in] k points list if needed + * @param wfcpw [in] PW basis for wave function if needed */ - void init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz); - + void init_rho(elecstate::efermi& eferm_iout, + const ModuleBase::ComplexMatrix& strucFac, + ModuleSymmetry::Symmetry& symm, + const void* klist = nullptr, + const void* wfcpw = nullptr); + void allocate(const int &nspin_in); void atomic_rho(const int spin_number_need, @@ -81,11 +87,11 @@ class Charge void set_rho_core(const ModuleBase::ComplexMatrix &structure_factor); void set_rho_core_paw(); - void renormalize_rho(void); + void renormalize_rho(); - double sum_rho(void) const; + double sum_rho() const; - void save_rho_before_sum_band(void); + void save_rho_before_sum_band(); // for non-linear core correction void non_linear_core_correction @@ -108,10 +114,8 @@ class Charge /** * @brief init some arrays for mpi_inter_pools, rho_mpi * - * @param nbz number of bigz in big grids - * @param bz number of z for each bigz */ - void init_chgmpi(const int& nbz, const int& bz); + void init_chgmpi(); /** * @brief Sum rho at different pools (k-point parallelism). diff --git a/source/module_elecstate/module_charge/charge_init.cpp b/source/module_elecstate/module_charge/charge_init.cpp index 8aca645cb1..df311a3e48 100644 --- a/source/module_elecstate/module_charge/charge_init.cpp +++ b/source/module_elecstate/module_charge/charge_init.cpp @@ -16,11 +16,16 @@ #include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h" #include "module_io/rho_io.h" #include "module_io/rhog_io.h" +#include "module_io/read_wfc_to_rho.h" #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" #endif -void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMatrix& strucFac, const int& nbz, const int& bz) +void Charge::init_rho(elecstate::efermi& eferm_iout, + const ModuleBase::ComplexMatrix& strucFac, + ModuleSymmetry::Symmetry& symm, + const void* klist, + const void* wfcpw) { ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "init_chg", PARAM.inp.init_chg); @@ -195,8 +200,20 @@ void Charge::init_rho(elecstate::efermi& eferm_iout, const ModuleBase::ComplexMa GlobalC::restart.info_load.load_charge_finish = true; } #ifdef __MPI - this->init_chgmpi(nbz, bz); + this->init_chgmpi(); #endif + if (PARAM.inp.init_chg == "wfc") + { + if (wfcpw == nullptr) + { + ModuleBase::WARNING_QUIT("Charge::init_rho", "wfc is only supported for PW-KSDFT."); + } + const ModulePW::PW_Basis_K* pw_wfc = reinterpret_cast(const_cast(wfcpw)); + const K_Vectors* kv = reinterpret_cast(klist); + const int nkstot = kv->get_nkstot(); + const std::vector& isk = kv->isk; + ModuleIO::read_wfc_to_rho(pw_wfc, symm, nkstot, isk, *this); + } } //========================================================== diff --git a/source/module_elecstate/module_charge/charge_mpi.cpp b/source/module_elecstate/module_charge/charge_mpi.cpp index ec926f0761..ee26d5e7b0 100644 --- a/source/module_elecstate/module_charge/charge_mpi.cpp +++ b/source/module_elecstate/module_charge/charge_mpi.cpp @@ -6,7 +6,7 @@ #include "module_elecstate/elecstate_getters.h" #include "module_parameter/parameter.h" #ifdef __MPI -void Charge::init_chgmpi(const int& nbz, const int& bz) +void Charge::init_chgmpi() { if (GlobalV::NPROC_IN_STOGROUP % GlobalV::KPAR == 0) { diff --git a/source/module_elecstate/module_charge/symmetry_rho.cpp b/source/module_elecstate/module_charge/symmetry_rho.cpp index 9979f5e528..d5afdbd4c3 100644 --- a/source/module_elecstate/module_charge/symmetry_rho.cpp +++ b/source/module_elecstate/module_charge/symmetry_rho.cpp @@ -13,7 +13,6 @@ Symmetry_rho::~Symmetry_rho() void Symmetry_rho::begin(const int& spin_now, const Charge& CHR, const ModulePW::PW_Basis* rho_basis, - Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const { assert(spin_now < 4); // added by zhengdy-soc @@ -31,7 +30,7 @@ void Symmetry_rho::begin(const int& spin_now, // else //space group, do rho_symm in reciprocal space { rho_basis->real2recip(CHR.rho[spin_now], CHR.rhog[spin_now]); - psymmg(CHR.rhog[spin_now], rho_basis, Pgrid, symm); // need to modify + psymmg(CHR.rhog[spin_now], rho_basis, symm); // need to modify rho_basis->recip2real(CHR.rhog[spin_now], CHR.rho[spin_now]); if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) @@ -39,7 +38,7 @@ void Symmetry_rho::begin(const int& spin_now, // Use std::vector to manage kin_g instead of raw pointer std::vector> kin_g(CHR.ngmc); rho_basis->real2recip(CHR.kin_r[spin_now], kin_g.data()); - psymmg(kin_g.data(), rho_basis, Pgrid, symm); + psymmg(kin_g.data(), rho_basis, symm); rho_basis->recip2real(kin_g.data(), CHR.kin_r[spin_now]); } } @@ -52,7 +51,6 @@ void Symmetry_rho::begin(const int& spin_now, int ngmc, double** kin_r, const ModulePW::PW_Basis* rho_basis, - Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const { assert(spin_now < 4); // added by zhengdy-soc @@ -71,7 +69,7 @@ void Symmetry_rho::begin(const int& spin_now, // else //space group, do rho_symm in reciprocal space { rho_basis->real2recip(rho[spin_now], rhog[spin_now]); - psymmg(rhog[spin_now], rho_basis, Pgrid, symm); + psymmg(rhog[spin_now], rho_basis, symm); rho_basis->recip2real(rhog[spin_now], rho[spin_now]); if (XC_Functional::get_func_type() == 3 || XC_Functional::get_func_type() == 5) @@ -79,7 +77,7 @@ void Symmetry_rho::begin(const int& spin_now, // Use std::vector to manage kin_g instead of raw pointer std::vector> kin_g(ngmc); rho_basis->real2recip(kin_r[spin_now], kin_g.data()); - psymmg(kin_g.data(), rho_basis, Pgrid, symm); + psymmg(kin_g.data(), rho_basis, symm); rho_basis->recip2real(kin_g.data(), kin_r[spin_now]); } } diff --git a/source/module_elecstate/module_charge/symmetry_rho.h b/source/module_elecstate/module_charge/symmetry_rho.h index 39ef442687..716cffca5f 100644 --- a/source/module_elecstate/module_charge/symmetry_rho.h +++ b/source/module_elecstate/module_charge/symmetry_rho.h @@ -14,7 +14,6 @@ class Symmetry_rho void begin(const int& spin_now, const Charge& CHR, const ModulePW::PW_Basis* pw, - Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const; void begin(const int& spin_now, @@ -23,7 +22,6 @@ class Symmetry_rho int ngmc, double** kin_r, const ModulePW::PW_Basis* pw, - Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const; private: @@ -35,7 +33,6 @@ class Symmetry_rho // in reciprocal space: void psymmg(std::complex* rhog_part, const ModulePW::PW_Basis* rho_basis, - Parallel_Grid& Pgrid, ModuleSymmetry::Symmetry& symm) const; #ifdef __MPI void reduce_to_fullrhog(const ModulePW::PW_Basis* rho_basis, diff --git a/source/module_elecstate/module_charge/symmetry_rhog.cpp b/source/module_elecstate/module_charge/symmetry_rhog.cpp index daf89c5212..0ed87bd572 100644 --- a/source/module_elecstate/module_charge/symmetry_rhog.cpp +++ b/source/module_elecstate/module_charge/symmetry_rhog.cpp @@ -4,7 +4,7 @@ #include "module_hamilt_general/module_xc/xc_functional.h" -void Symmetry_rho::psymmg(std::complex* rhog_part, const ModulePW::PW_Basis *rho_basis, Parallel_Grid &Pgrid, ModuleSymmetry::Symmetry &symm) const +void Symmetry_rho::psymmg(std::complex* rhog_part, const ModulePW::PW_Basis *rho_basis, ModuleSymmetry::Symmetry &symm) const { //(1) get fftixy2is and do Allreduce int * fftixy2is = new int [rho_basis->fftnxy]; diff --git a/source/module_elecstate/test/elecstate_base_test.cpp b/source/module_elecstate/test/elecstate_base_test.cpp index 70d62492f3..dd2dff8a0d 100644 --- a/source/module_elecstate/test/elecstate_base_test.cpp +++ b/source/module_elecstate/test/elecstate_base_test.cpp @@ -71,7 +71,11 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&) void Charge::set_rho_core_paw() { } -void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&) +void Charge::init_rho(elecstate::efermi&, + ModuleBase::ComplexMatrix const&, + ModuleSymmetry::Symmetry& symm, + const void*, + const void*) { } void Charge::set_rhopw(ModulePW::PW_Basis*) @@ -390,7 +394,8 @@ TEST_F(ElecStateTest, InitSCF) int istep = 0; ModuleBase::ComplexMatrix strucfac; elecstate->eferm = efermi; - EXPECT_NO_THROW(elecstate->init_scf(istep, strucfac)); + ModuleSymmetry::Symmetry symm; + EXPECT_NO_THROW(elecstate->init_scf(istep, strucfac, symm)); // delete elecstate->pot is done in the destructor of elecstate delete charge; } diff --git a/source/module_elecstate/test/elecstate_pw_test.cpp b/source/module_elecstate/test/elecstate_pw_test.cpp index 39d9f3ebb9..e31908f1b5 100644 --- a/source/module_elecstate/test/elecstate_pw_test.cpp +++ b/source/module_elecstate/test/elecstate_pw_test.cpp @@ -145,7 +145,11 @@ void Charge::set_rho_core(ModuleBase::ComplexMatrix const&) void Charge::set_rho_core_paw() { } -void Charge::init_rho(elecstate::efermi&, ModuleBase::ComplexMatrix const&, const int&, const int&) +void Charge::init_rho(elecstate::efermi&, + ModuleBase::ComplexMatrix const&, + ModuleSymmetry::Symmetry& symm, + const void*, + const void*) { } void Charge::set_rhopw(ModulePW::PW_Basis*) diff --git a/source/module_elecstate/test_mpi/charge_mpi_test.cpp b/source/module_elecstate/test_mpi/charge_mpi_test.cpp index bb4381e510..606446295f 100644 --- a/source/module_elecstate/test_mpi/charge_mpi_test.cpp +++ b/source/module_elecstate/test_mpi/charge_mpi_test.cpp @@ -97,7 +97,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools1) } double refsum = sum_array(array_rho, nrxx); - charge->init_chgmpi(nz, 1); + charge->init_chgmpi(); charge->reduce_diff_pools(array_rho); double sum = sum_array(array_rho, nrxx); EXPECT_EQ(sum, refsum * GlobalV::KPAR); @@ -150,7 +150,7 @@ TEST_F(ChargeMpiTest, reduce_diff_pools2) } } - charge->init_chgmpi(nz, 1); + charge->init_chgmpi(); charge->reduce_diff_pools(array_rho); double sum = sum_array(array_rho, nrxx); MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_DOUBLE, MPI_SUM, POOL_WORLD); @@ -194,7 +194,7 @@ TEST_F(ChargeMpiTest, rho_mpi) charge->nrxx = nrxx; charge->rho[0] = new double[nrxx]; charge->kin_r[0] = new double[nrxx]; - charge->init_chgmpi(nz, 1); + charge->init_chgmpi(); charge->rho_mpi(); delete[] charge->rho[0]; diff --git a/source/module_esolver/esolver_ks_lcao.cpp b/source/module_esolver/esolver_ks_lcao.cpp index b343b63793..9a8d1d634e 100644 --- a/source/module_esolver/esolver_ks_lcao.cpp +++ b/source/module_esolver/esolver_ks_lcao.cpp @@ -784,7 +784,7 @@ void ESolver_KS_LCAO::hamilt2density(int istep, int iter, double ethr) Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::ucell.symm); } // 11) compute magnetization, only for spin==2 diff --git a/source/module_esolver/esolver_ks_lcao_tddft.cpp b/source/module_esolver/esolver_ks_lcao_tddft.cpp index 21215b81e5..e312980bee 100644 --- a/source/module_esolver/esolver_ks_lcao_tddft.cpp +++ b/source/module_esolver/esolver_ks_lcao_tddft.cpp @@ -216,7 +216,7 @@ void ESolver_KS_LCAO_TDDFT::hamilt2density(const int istep, const int iter, cons Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(pelec->charge), pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(pelec->charge), pw_rho, GlobalC::ucell.symm); } } diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp b/source/module_esolver/esolver_ks_lcaopw.cpp index 02d47d3e4c..c334f15916 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -174,7 +174,7 @@ namespace ModuleESolver Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::ucell.symm); } // compute magnetization, only for LSDA(spin==2) diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 0d74ebeb20..22d466d93f 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -219,7 +219,7 @@ void ESolver_KS_PW::before_scf(const int istep) } //! calculate the total local pseudopotential in real space - this->pelec->init_scf(istep, this->sf.strucFac); + this->pelec->init_scf(istep, this->sf.strucFac, GlobalC::ucell.symm, (void*)this->pw_wfc); //! output the initial charge density if (PARAM.inp.out_chg[0] == 2) @@ -283,7 +283,7 @@ void ESolver_KS_PW::before_scf(const int istep) Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::ucell.symm); } // liuyu move here 2023-10-09 @@ -405,7 +405,7 @@ void ESolver_KS_PW::hamilt2density(const int istep, const int iter, c Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rhod, GlobalC::ucell.symm); } // compute magnetization, only for LSDA(spin==2) diff --git a/source/module_esolver/esolver_of.cpp b/source/module_esolver/esolver_of.cpp index 006241c7b6..22ece96650 100644 --- a/source/module_esolver/esolver_of.cpp +++ b/source/module_esolver/esolver_of.cpp @@ -120,7 +120,7 @@ void ESolver_OF::before_all_runners(const Input_para& inp, UnitCell& ucell) this->init_elecstate(ucell); // calculate the total local pseudopotential in real space - this->pelec->init_scf(0, sf.strucFac); // atomic_rho, v_of_rho, set_vrs + this->pelec->init_scf(0, sf.strucFac, GlobalC::ucell.symm); // atomic_rho, v_of_rho, set_vrs // liuyu move here 2023-10-09 // D in uspp need vloc, thus behind init_scf() @@ -279,7 +279,7 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell) GlobalV::ofs_warning); } - this->pelec->init_scf(istep, sf.strucFac); + this->pelec->init_scf(istep, sf.strucFac, GlobalC::ucell.symm); // calculate ewald energy this->pelec->f_en.ewald_energy = H_Ewald_pw::compute_ewald(ucell, this->pw_rho, sf.strucFac); @@ -287,7 +287,7 @@ void ESolver_OF::before_opt(const int istep, UnitCell& ucell) Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(pelec->charge), this->pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(pelec->charge), this->pw_rho, GlobalC::ucell.symm); } for (int is = 0; is < GlobalV::NSPIN; ++is) diff --git a/source/module_esolver/esolver_sdft_pw.cpp b/source/module_esolver/esolver_sdft_pw.cpp index df399358a9..89dfe97445 100644 --- a/source/module_esolver/esolver_sdft_pw.cpp +++ b/source/module_esolver/esolver_sdft_pw.cpp @@ -197,7 +197,7 @@ void ESolver_SDFT_PW::hamilt2density(int istep, int iter, double ethr) Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), pw_rho, GlobalC::ucell.symm); } this->pelec->f_en.deband = this->pelec->cal_delta_eband(); } diff --git a/source/module_esolver/lcao_before_scf.cpp b/source/module_esolver/lcao_before_scf.cpp index e23d848d78..39dff621a2 100644 --- a/source/module_esolver/lcao_before_scf.cpp +++ b/source/module_esolver/lcao_before_scf.cpp @@ -218,7 +218,7 @@ void ESolver_KS_LCAO::before_scf(const int istep) } #endif // __EXX - this->pelec->init_scf(istep, this->sf.strucFac); + this->pelec->init_scf(istep, this->sf.strucFac, GlobalC::ucell.symm); //! output the initial charge density if (PARAM.inp.out_chg[0] == 2) @@ -338,7 +338,7 @@ void ESolver_KS_LCAO::before_scf(const int istep) Symmetry_rho srho; for (int is = 0; is < GlobalV::NSPIN; is++) { - srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::Pgrid, GlobalC::ucell.symm); + srho.begin(is, *(this->pelec->charge), this->pw_rho, GlobalC::ucell.symm); } // 1. calculate ewald energy. diff --git a/source/module_esolver/lcao_others.cpp b/source/module_esolver/lcao_others.cpp index 77ce1c39ef..14bcd951e3 100644 --- a/source/module_esolver/lcao_others.cpp +++ b/source/module_esolver/lcao_others.cpp @@ -89,7 +89,7 @@ void ESolver_KS_LCAO::others(const int istep) this->beforesolver(istep); // pelec should be initialized before these calculations - this->pelec->init_scf(istep, this->sf.strucFac); + this->pelec->init_scf(istep, this->sf.strucFac, GlobalC::ucell.symm); // self consistent calculations for electronic ground state if (PARAM.inp.calculation == "nscf") { diff --git a/source/module_hamilt_pw/hamilt_pwdft/elecond.cpp b/source/module_hamilt_pw/hamilt_pwdft/elecond.cpp index ebb8769e66..2d6aa95f5e 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/elecond.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/elecond.cpp @@ -97,8 +97,9 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d hamilt::Velocity& velop, double* ct11, double* ct12, double* ct22) { const int nbands = GlobalV::NBANDS; - if (wg(ik, 0) - wg(ik, nbands - 1) < 1e-8 || nbands == 0) + if (wg(ik, 0) - wg(ik, nbands - 1) < 1e-8 || nbands == 0) { return; +} const char transn = 'N'; const char transc = 'C'; const int ndim = 3; @@ -121,7 +122,7 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d #ifdef __MPI MPI_Allreduce(MPI_IN_PLACE, pij.data(), nbands * nbands, MPI_DOUBLE_COMPLEX, MPI_SUM, POOL_WORLD); #endif - if (!gamma_only) + if (!gamma_only) { for (int ib = 0, ijb = 0; ib < nbands; ++ib) { for (int jb = ib + 1; jb < nbands; ++jb, ++ijb) @@ -129,12 +130,13 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d pij2[ijb] += norm(pij[ib * nbands + jb]); } } +} } if (GlobalV::RANK_IN_POOL == 0) { int nkstot = this->p_kv->get_nkstot(); - int ikglobal = this->p_kv->getik_global(ik); + int ikglobal = K_Vectors::get_ik_global(ik, nkstot); std::stringstream ss; ss << GlobalV::global_out_dir << "vmatrix" << ikglobal + 1 << ".dat"; Binstream binpij(ss.str(), "w"); @@ -168,8 +170,9 @@ void EleCond::jjresponse_ks(const int ik, const int nt, const double dt, const d for (int jb = ib + 1; jb < nbands; ++jb, ++ijb) { double ej = enb[jb]; - if (ej - ei > decut) + if (ej - ei > decut) { continue; +} double fj = wg(ik, jb); double tmct = sin((ej - ei) * (it)*dt) * (fi - fj) * pij2[ijb]; tmct11 += tmct; diff --git a/source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp b/source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp index a685d1e3a0..5d74857781 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/wavefunc.cpp @@ -1,6 +1,5 @@ #include "wavefunc.h" -#include "module_parameter/parameter.h" #include "module_base/memory.h" #include "module_base/timer.h" #include "module_hamilt_lcao/hamilt_lcaodft/wavefunc_in_pw.h" @@ -10,6 +9,7 @@ #include "module_hsolver/kernels/math_kernel_op.h" #include "module_io/read_wfc_pw.h" #include "module_io/winput.h" +#include "module_parameter/parameter.h" #include "module_psi/psi.h" wavefunc::wavefunc() @@ -18,22 +18,22 @@ wavefunc::wavefunc() wavefunc::~wavefunc() { - if(PARAM.inp.test_deconstructor) - { - std::cout << " ~wavefunc()" << std::endl; - } - if(this->irindex != nullptr) - { - delete[] this->irindex; - this->irindex=nullptr; - } + if (PARAM.inp.test_deconstructor) + { + std::cout << " ~wavefunc()" << std::endl; + } + if (this->irindex != nullptr) + { + delete[] this->irindex; + this->irindex = nullptr; + } } psi::Psi>* wavefunc::allocate(const int nkstot, const int nks, const int* ngk, const int npwx_in) { - ModuleBase::TITLE("wavefunc","allocate"); + ModuleBase::TITLE("wavefunc", "allocate"); - this->npwx = npwx_in; + this->npwx = npwx_in; this->nkstot = nkstot; ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "npwx", npwx); @@ -51,69 +51,69 @@ psi::Psi>* wavefunc::allocate(const int nkstot, const int n const int nks2 = nks; - psi::Psi>* psi_out = nullptr; + psi::Psi>* psi_out = nullptr; if (PARAM.inp.calculation == "nscf" && this->mem_saver == 1) { - //initial psi rather than evc - psi_out = new psi::Psi>(1, GlobalV::NBANDS, npwx * GlobalV::NPOL, ngk); - if(PARAM.inp.basis_type=="lcao_in_pw") - { - wanf2[0].create(GlobalV::NLOCAL, npwx * GlobalV::NPOL); - const size_t memory_cost = GlobalV::NLOCAL*(GlobalV::NPOL*npwx) * sizeof(std::complex); - std::cout << " Memory for wanf2 (MB): " << double(memory_cost)/1024.0/1024.0 << std::endl; - ModuleBase::Memory::record("WF::wanf2", memory_cost) ; - } - const size_t memory_cost = GlobalV::NBANDS*(GlobalV::NPOL*npwx) * sizeof(std::complex); - std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost)/1024.0/1024.0 << std::endl; - ModuleBase::Memory::record("Psi_PW", memory_cost); - } - else if(PARAM.inp.basis_type!="pw") + // initial psi rather than evc + psi_out = new psi::Psi>(1, GlobalV::NBANDS, npwx * GlobalV::NPOL, ngk); + if (PARAM.inp.basis_type == "lcao_in_pw") + { + wanf2[0].create(GlobalV::NLOCAL, npwx * GlobalV::NPOL); + const size_t memory_cost = GlobalV::NLOCAL * (GlobalV::NPOL * npwx) * sizeof(std::complex); + std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl; + ModuleBase::Memory::record("WF::wanf2", memory_cost); + } + const size_t memory_cost = GlobalV::NBANDS * (GlobalV::NPOL * npwx) * sizeof(std::complex); + std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl; + ModuleBase::Memory::record("Psi_PW", memory_cost); + } + else if (PARAM.inp.basis_type != "pw") { if ((PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw") || winput::out_spillage == 2) { // for lcao_in_pw - if(this->wanf2 != nullptr) - { - delete[] this->wanf2; - } - this->wanf2 = new ModuleBase::ComplexMatrix [nks2]; + if (this->wanf2 != nullptr) + { + delete[] this->wanf2; + } + this->wanf2 = new ModuleBase::ComplexMatrix[nks2]; - for (int ik = 0; ik < nks2; ik++) - { - this->wanf2[ik].create(GlobalV::NLOCAL, npwx * GlobalV::NPOL); - } + for (int ik = 0; ik < nks2; ik++) + { + this->wanf2[ik].create(GlobalV::NLOCAL, npwx * GlobalV::NPOL); + } - const size_t memory_cost = nks2 * GlobalV::NLOCAL*(npwx * GlobalV::NPOL) * sizeof(std::complex); - std::cout << " Memory for wanf2 (MB): " << double(memory_cost)/1024.0/1024.0 << std::endl; - ModuleBase::Memory::record("WF::wanf2", memory_cost) ; + const size_t memory_cost = nks2 * GlobalV::NLOCAL * (npwx * GlobalV::NPOL) * sizeof(std::complex); + std::cout << " Memory for wanf2 (MB): " << double(memory_cost) / 1024.0 / 1024.0 << std::endl; + ModuleBase::Memory::record("WF::wanf2", memory_cost); } } else { // initial psi rather than evc psi_out = new psi::Psi>(nks2, GlobalV::NBANDS, npwx * GlobalV::NPOL, ngk); - const size_t memory_cost = nks2 * GlobalV::NBANDS*(GlobalV::NPOL*npwx) * sizeof(std::complex); - std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost)/1024.0/1024.0 << std::endl; - ModuleBase::Memory::record("Psi_PW", memory_cost); + const size_t memory_cost = nks2 * GlobalV::NBANDS * (GlobalV::NPOL * npwx) * sizeof(std::complex); + std::cout << " MEMORY FOR PSI (MB) : " << double(memory_cost) / 1024.0 / 1024.0 << std::endl; + ModuleBase::Memory::record("Psi_PW", memory_cost); } return psi_out; - //showMemStats(); + // showMemStats(); } //=================================================================== // This routine computes an estimate of the start_ wavefunctions // from superposition of atomic wavefunctions or random wave functions. //=================================================================== -void wavefunc::wfcinit(psi::Psi> *psi_in, ModulePW::PW_Basis_K *wfc_basis) +void wavefunc::wfcinit(psi::Psi>* psi_in, ModulePW::PW_Basis_K* wfc_basis) { - ModuleBase::TITLE("wavefunc","wfcinit"); + ModuleBase::TITLE("wavefunc", "wfcinit"); ModuleBase::timer::tick("wavefunc", "wfcinit"); if (PARAM.inp.basis_type == "pw") { - if (this->irindex != nullptr) - { - delete[] this->irindex; - } + if (this->irindex != nullptr) + { + delete[] this->irindex; + } this->irindex = new int[wfc_basis->fftnxy]; wfc_basis->getfftixy2is(this->irindex); #if defined(__CUDA) || defined(__ROCM) @@ -123,59 +123,61 @@ void wavefunc::wfcinit(psi::Psi> *psi_in, ModulePW::PW_Basi } #endif } - ModuleBase::timer::tick("wavefunc","wfcinit"); + ModuleBase::timer::tick("wavefunc", "wfcinit"); return; } -int wavefunc::get_starting_nw()const +int wavefunc::get_starting_nw() const { if (init_wfc == "file") { return GlobalV::NBANDS; } - else if (init_wfc.substr(0,6) == "atomic") + else if (init_wfc.substr(0, 6) == "atomic") { if (GlobalC::ucell.natomwfc >= GlobalV::NBANDS) { - if(PARAM.inp.test_wf) - { - GlobalV::ofs_running << " Start wave functions are all pseudo atomic wave functions." << std::endl; - } + if (PARAM.inp.test_wf) + { + GlobalV::ofs_running << " Start wave functions are all pseudo atomic wave functions." << std::endl; + } } else { - if(PARAM.inp.test_wf) - { - GlobalV::ofs_running << " Start wave functions are atomic + " - << GlobalV::NBANDS - GlobalC::ucell.natomwfc - << " random wave functions." << std::endl; - } + if (PARAM.inp.test_wf) + { + GlobalV::ofs_running << " Start wave functions are atomic + " + << GlobalV::NBANDS - GlobalC::ucell.natomwfc << " random wave functions." + << std::endl; + } } - return std::max(GlobalC::ucell.natomwfc, GlobalV::NBANDS); + return std::max(GlobalC::ucell.natomwfc, GlobalV::NBANDS); } else if (init_wfc == "random") { - if(PARAM.inp.test_wf) - { - GlobalV::ofs_running << " Start wave functions are all random." << std::endl; - } + if (PARAM.inp.test_wf) + { + GlobalV::ofs_running << " Start wave functions are all random." << std::endl; + } return GlobalV::NBANDS; } else { - throw std::runtime_error("wavefunc::get_starting_nw. Don't know what to do! Please Check source code! "+ModuleBase::GlobalFunc::TO_STRING(__FILE__)+" line "+ModuleBase::GlobalFunc::TO_STRING(__LINE__)); // Peize Lin change 2019-05-01 - //ModuleBase::WARNING_QUIT("get_starting_nw","Don't know what to do! Please Check source code!"); + throw std::runtime_error("wavefunc::get_starting_nw. Don't know what to do! Please Check source code! " + + ModuleBase::GlobalFunc::TO_STRING(__FILE__) + " line " + + ModuleBase::GlobalFunc::TO_STRING(__LINE__)); // Peize Lin change 2019-05-01 + // ModuleBase::WARNING_QUIT("get_starting_nw","Don't know what to do! Please Check source code!"); } } namespace hamilt { -void diago_PAO_in_pw_k2(const int &ik, - psi::Psi> &wvf, - ModulePW::PW_Basis_K *wfc_basis, - wavefunc *p_wf, - hamilt::Hamilt> *phm_in) +void diago_PAO_in_pw_k2(const int& ik, + psi::Psi>& wvf, + ModulePW::PW_Basis_K* wfc_basis, + wavefunc* p_wf, + hamilt::Hamilt>* phm_in) { ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); @@ -187,59 +189,53 @@ void diago_PAO_in_pw_k2(const int &ik, { ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); std::stringstream filename; - filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik + 1 << ".dat"; - bool result = ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); + int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); + filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (result) - { - std::vector> s_wfcatom(nbands * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis); + std::vector> s_wfcatom(nbands * nbasis); + castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis); - if (GlobalV::KS_SOLVER == "cg") + if (GlobalV::KS_SOLVER == "cg") + { + std::vector etfile(nbands, 0.0); + if (phm_in != nullptr) { - std::vector etfile(nbands, 0.0); - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etfile.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); - } + hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, + s_wfcatom.data(), + wfcatom.nr, + wfcatom.nc, + wvf, + etfile.data()); + return; } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) + else { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; - } + ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); } - return; } - else + + assert(nbands <= wfcatom.nr); + for (int ib = 0; ib < nbands; ib++) { - p_wf->init_wfc = "atomic+random"; + for (int ig = 0; ig < nbasis; ig++) + { + wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; + } } + return; } const int starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - { - return; - } + if (starting_nw == 0) + { + return; + } assert(starting_nw > 0); std::vector etatom(starting_nw, 0.0); - //special case here! use Psi(k-1) for the initialization of Psi(k) - //this method should be tested. + // special case here! use Psi(k-1) for the initialization of Psi(k) + // this method should be tested. /*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0) { //this is memsaver case @@ -255,90 +251,91 @@ void diago_PAO_in_pw_k2(const int &ik, } */ - if( p_wf->init_wfc=="random" || ( p_wf->init_wfc.substr(0,6)=="atomic" && GlobalC::ucell.natomwfc == 0 )) - { - p_wf->random(wvf.get_pointer(),0,nbands,ik, wfc_basis); - - if(GlobalV::KS_SOLVER=="cg") //xiaohui add 2013-09-02 - { - if(phm_in!= nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc","Hamiltonian does not exist!"); - } - } - } - else if(p_wf->init_wfc.substr(0,6)=="atomic") - { - ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis);//added by zhengdy-soc - if(PARAM.inp.test_wf)ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - - p_wf->atomic_wfc( - ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - wfc_basis, - wfcatom, - GlobalC::ppcell.tab_at, - GlobalV::NQX, - GlobalV::DQ); - - if( p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc )//added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom,0,starting_nw,ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - - // (7) Diago with cg method. - std::vector> s_wfcatom(starting_nw * nbasis); - castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, starting_nw * nbasis); - - //if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 - if(GlobalV::KS_SOLVER=="cg") //xiaohui add 2013-09-02 - { - if(phm_in!= nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - s_wfcatom.data(), - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc","Psi does not exist!"); - //this diagonalization method is obsoleted now - //GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib=0; ibinit_wfc == "random" || (p_wf->init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0)) + { + p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis); + + if (GlobalV::KS_SOLVER == "cg") // xiaohui add 2013-09-02 + { + if (phm_in != nullptr) + { + hsolver::DiagoIterAssist>::diagH_subspace(phm_in, wvf, wvf, etatom.data()); + return; + } + else + { + ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); + } + } + } + else if (p_wf->init_wfc.substr(0, 6) == "atomic") + { + ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc + if (PARAM.inp.test_wf) { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); +} + + p_wf->atomic_wfc(ik, + current_nbasis, + GlobalC::ucell.lmax_ppwf, + wfc_basis, + wfcatom, + GlobalC::ppcell.tab_at, + GlobalV::NQX, + GlobalV::DQ); + + if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + { + p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); + } + + //==================================================== + // If not enough atomic wfc are available, complete + // with random wfcs + //==================================================== + p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); + + // (7) Diago with cg method. + std::vector> s_wfcatom(starting_nw * nbasis); + castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, starting_nw * nbasis); + + // if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 + if (GlobalV::KS_SOLVER == "cg") // xiaohui add 2013-09-02 + { + if (phm_in != nullptr) + { + hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, + s_wfcatom.data(), + wfcatom.nr, + wfcatom.nc, + wvf, + etatom.data()); + return; + } + else + { + ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!"); + // this diagonalization method is obsoleted now + // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); + } + } + + assert(nbands <= wfcatom.nr); + for (int ib = 0; ib < nbands; ib++) + { + for (int ig = 0; ig < nbasis; ig++) + { + wvf(ib, ig) = s_wfcatom[ib * nbasis + ig]; + } + } + } } -void diago_PAO_in_pw_k2(const int &ik, - psi::Psi> &wvf, - ModulePW::PW_Basis_K *wfc_basis, - wavefunc *p_wf, - hamilt::Hamilt> *phm_in) +void diago_PAO_in_pw_k2(const int& ik, + psi::Psi>& wvf, + ModulePW::PW_Basis_K* wfc_basis, + wavefunc* p_wf, + hamilt::Hamilt>* phm_in) { ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2"); @@ -350,44 +347,38 @@ void diago_PAO_in_pw_k2(const int &ik, { ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); std::stringstream filename; - filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik + 1 << ".dat"; - bool result = ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); + int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); + filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (result) + if (GlobalV::KS_SOLVER == "cg") { - if (GlobalV::KS_SOLVER == "cg") + std::vector etfile(nbands, 0.0); + if (phm_in != nullptr) { - std::vector etfile(nbands, 0.0); - if (phm_in != nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - wfcatom.c, - wfcatom.nr, - wfcatom.nc, - wvf, - etfile.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); - } + hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, + wfcatom.c, + wfcatom.nr, + wfcatom.nc, + wvf, + etfile.data()); + return; } - - assert(nbands <= wfcatom.nr); - for (int ib = 0; ib < nbands; ib++) + else { - for (int ig = 0; ig < nbasis; ig++) - { - wvf(ib, ig) = wfcatom(ib, ig); - } + ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); } - return; } - else + + assert(nbands <= wfcatom.nr); + for (int ib = 0; ib < nbands; ib++) { - p_wf->init_wfc = "atomic+random"; + for (int ig = 0; ig < nbasis; ig++) + { + wvf(ib, ig) = wfcatom(ib, ig); + } } + return; } // special case here! use Psi(k-1) for the initialization of Psi(k) @@ -408,10 +399,10 @@ void diago_PAO_in_pw_k2(const int &ik, */ const int starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - { - return; - } + if (starting_nw == 0) + { + return; + } assert(starting_nw > 0); std::vector etatom(starting_nw, 0.0); @@ -435,22 +426,21 @@ void diago_PAO_in_pw_k2(const int &ik, else if (p_wf->init_wfc.substr(0, 6) == "atomic") { ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) - { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - } - - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - wfc_basis, - wfcatom, - GlobalC::ppcell.tab_at, - GlobalV::NQX, - GlobalV::DQ); - - if (p_wf->init_wfc == "atomic+random" - && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 + if (PARAM.inp.test_wf) + { + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); + } + + p_wf->atomic_wfc(ik, + current_nbasis, + GlobalC::ucell.lmax_ppwf, + wfc_basis, + wfcatom, + GlobalC::ppcell.tab_at, + GlobalV::NQX, + GlobalV::DQ); + + if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } @@ -462,33 +452,33 @@ void diago_PAO_in_pw_k2(const int &ik, p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); // (7) Diago with cg method. - //if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02 - if(GlobalV::KS_SOLVER=="cg") //xiaohui add 2013-09-02 - { - if(phm_in!= nullptr) - { - hsolver::DiagoIterAssist>::diagH_subspace_init(phm_in, - wfcatom.c, - wfcatom.nr, - wfcatom.nc, - wvf, - etatom.data()); - return; - } - else - { - ModuleBase::WARNING_QUIT("wavefunc","Hamiltonian does not exist!"); - } - } - - assert(nbands <= wfcatom.nr); - for (int ib=0; ib>::diagH_subspace_init(phm_in, + wfcatom.c, + wfcatom.nr, + wfcatom.nc, + wvf, + etatom.data()); + return; + } + else + { + ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!"); + } + } + + assert(nbands <= wfcatom.nr); + for (int ib = 0; ib < nbands; ib++) + { + for (int ig = 0; ig < nbasis; ig++) + { + wvf(ib, ig) = wfcatom(ib, ig); + } + } } } @@ -503,7 +493,6 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, phm_in); } - template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, const int& ik, @@ -515,7 +504,6 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx, diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, phm_in); } - #if ((defined __CUDA) || (defined __ROCM)) template <> void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, @@ -532,67 +520,60 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int current_nbasis = wfc_basis->npwk[ik]; int starting_nw = nbands; - bool result = false; ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); if (p_wf->init_wfc == "file") { std::stringstream filename; - filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik + 1 << ".dat"; - result = ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (!result) - { - p_wf->init_wfc = "atomic+random"; - } + int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); + filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); } - if (!result) - { - starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - return; - assert(starting_nw > 0); - wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - - if (p_wf->init_wfc.substr(0, 6) == "atomic") - { - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - wfc_basis, - wfcatom, - GlobalC::ppcell.tab_at, - GlobalV::NQX, - GlobalV::DQ); - if (p_wf->init_wfc == "atomic+random" - && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } + starting_nw = p_wf->get_starting_nw(); + if (starting_nw == 0) + return; + assert(starting_nw > 0); + wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc + if (PARAM.inp.test_wf) + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - } - else if (p_wf->init_wfc == "random") + if (p_wf->init_wfc.substr(0, 6) == "atomic") + { + p_wf->atomic_wfc(ik, + current_nbasis, + GlobalC::ucell.lmax_ppwf, + wfc_basis, + wfcatom, + GlobalC::ppcell.tab_at, + GlobalV::NQX, + GlobalV::DQ); + if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { - p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); + p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } + + //==================================================== + // If not enough atomic wfc are available, complete + // with random wfcs + //==================================================== + p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); + } + else if (p_wf->init_wfc == "random") + { + p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); } - std::complex *c_wfcatom = nullptr; - if (GlobalV::KS_SOLVER!="bpcg") { + std::complex* c_wfcatom = nullptr; + if (GlobalV::KS_SOLVER != "bpcg") + { // store wfcatom on the GPU resmem_cd_op()(gpu_ctx, c_wfcatom, wfcatom.nr * wfcatom.nc); castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, c_wfcatom, wfcatom.c, wfcatom.nr * wfcatom.nc); } - if(GlobalV::KS_SOLVER=="cg") //xiaohui add 2013-09-02 + if (GlobalV::KS_SOLVER == "cg") // xiaohui add 2013-09-02 { // (7) Diago with cg method. - if(phm_in!= nullptr) + if (phm_in != nullptr) { std::vector etatom(starting_nw, 0.0); hsolver::DiagoIterAssist, base_device::DEVICE_GPU>::diagH_subspace_init(phm_in, @@ -604,8 +585,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, } else { - //this diagonalization method is obsoleted now - //GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); + // this diagonalization method is obsoleted now + // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); } } else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace") @@ -619,10 +600,12 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, &wvf(0, 0), nbasis); } - else if (GlobalV::KS_SOLVER=="bpcg") { - castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0,0), wfcatom.c, wfcatom.nr * wfcatom.nc); + else if (GlobalV::KS_SOLVER == "bpcg") + { + castmem_z2c_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0, 0), wfcatom.c, wfcatom.nr * wfcatom.nc); } - if (GlobalV::KS_SOLVER!="bpcg") { + if (GlobalV::KS_SOLVER != "bpcg") + { delmem_cd_op()(gpu_ctx, c_wfcatom); } } @@ -641,68 +624,61 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, const int current_nbasis = wfc_basis->npwk[ik]; int starting_nw = nbands; - bool result = false; ModuleBase::ComplexMatrix wfcatom(nbands, nbasis); if (p_wf->init_wfc == "file") { std::stringstream filename; - filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik + 1 << ".dat"; - result = ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); - if (!result) - { - p_wf->init_wfc = "atomic+random"; - } + int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot); + filename << GlobalV::global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat"; + ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom); } - if (!result) + starting_nw = p_wf->get_starting_nw(); + if (starting_nw == 0) + return; + assert(starting_nw > 0); + wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc + if (PARAM.inp.test_wf) + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); + if (p_wf->init_wfc.substr(0, 6) == "atomic") { - starting_nw = p_wf->get_starting_nw(); - if (starting_nw == 0) - return; - assert(starting_nw > 0); - wfcatom.create(starting_nw, nbasis); // added by zhengdy-soc - if (PARAM.inp.test_wf) - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw); - if (p_wf->init_wfc.substr(0, 6) == "atomic") - { - p_wf->atomic_wfc(ik, - current_nbasis, - GlobalC::ucell.lmax_ppwf, - wfc_basis, - wfcatom, - GlobalC::ppcell.tab_at, - GlobalV::NQX, - GlobalV::DQ); - if (p_wf->init_wfc == "atomic+random" - && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 - { - p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); - } - - //==================================================== - // If not enough atomic wfc are available, complete - // with random wfcs - //==================================================== - p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); - } - else if (p_wf->init_wfc == "random") + p_wf->atomic_wfc(ik, + current_nbasis, + GlobalC::ucell.lmax_ppwf, + wfc_basis, + wfcatom, + GlobalC::ppcell.tab_at, + GlobalV::NQX, + GlobalV::DQ); + if (p_wf->init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16 { - p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); + p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis); } + + //==================================================== + // If not enough atomic wfc are available, complete + // with random wfcs + //==================================================== + p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis); + } + else if (p_wf->init_wfc == "random") + { + p_wf->random(wfcatom.c, 0, nbands, ik, wfc_basis); } - std::complex *z_wfcatom = nullptr; - if (GlobalV::KS_SOLVER!="bpcg") { + std::complex* z_wfcatom = nullptr; + if (GlobalV::KS_SOLVER != "bpcg") + { // store wfcatom on the GPU resmem_zd_op()(gpu_ctx, z_wfcatom, wfcatom.nr * wfcatom.nc); syncmem_z2z_h2d_op()(gpu_ctx, cpu_ctx, z_wfcatom, wfcatom.c, wfcatom.nr * wfcatom.nc); } - if(GlobalV::KS_SOLVER=="cg") //xiaohui add 2013-09-02 - { - // (7) Diago with cg method. - if(phm_in!= nullptr) - { - std::vector etatom(starting_nw, 0.0); + if (GlobalV::KS_SOLVER == "cg") // xiaohui add 2013-09-02 + { + // (7) Diago with cg method. + if (phm_in != nullptr) + { + std::vector etatom(starting_nw, 0.0); hsolver::DiagoIterAssist, base_device::DEVICE_GPU>::diagH_subspace_init(phm_in, z_wfcatom, wfcatom.nr, @@ -710,16 +686,16 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, wvf, etatom.data()); } - else - { - //this diagonalization method is obsoleted now - //GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); - } - } + else + { + // this diagonalization method is obsoleted now + // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data()); + } + } else if (GlobalV::KS_SOLVER == "dav" || GlobalV::KS_SOLVER == "dav_subspace") { - assert(nbands <= wfcatom.nr); - // replace by haozhihan 2022-11-23 + assert(nbands <= wfcatom.nr); + // replace by haozhihan 2022-11-23 hsolver::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, nbands, z_wfcatom, @@ -727,50 +703,53 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx, &wvf(0, 0), nbasis); } - else if(GlobalV::KS_SOLVER=="bpcg") { - syncmem_z2z_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0,0), wfcatom.c, wfcatom.nr * wfcatom.nc); + else if (GlobalV::KS_SOLVER == "bpcg") + { + syncmem_z2z_h2d_op()(gpu_ctx, cpu_ctx, &wvf(0, 0), wfcatom.c, wfcatom.nr * wfcatom.nc); } - if (GlobalV::KS_SOLVER!="bpcg") { + if (GlobalV::KS_SOLVER != "bpcg") + { delmem_zd_op()(gpu_ctx, z_wfcatom); } } #endif -}//namespace hamilt - +} // namespace hamilt -//LiuXh add a new function here, -//20180515 +// LiuXh add a new function here, +// 20180515 void wavefunc::init_after_vc(const int nks) { - ModuleBase::TITLE("wavefunc","init"); - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"npwx",this->npwx); + ModuleBase::TITLE("wavefunc", "init"); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "npwx", this->npwx); assert(this->npwx > 0); assert(nks > 0); assert(GlobalV::NBANDS > 0); const int nks2 = nks; - const int nbasis = this->npwx * GlobalV::NPOL; + const int nbasis = this->npwx * GlobalV::NPOL; - if((PARAM.inp.basis_type=="lcao" || PARAM.inp.basis_type=="lcao_in_pw") || winput::out_spillage==2) - { - if(wanf2 != nullptr) {delete[] wanf2; -} - this->wanf2 = new ModuleBase::ComplexMatrix [nks2]; - for (int ik = 0; ik < nks2; ik++) - { - this->wanf2[ik].create(GlobalV::NLOCAL, nbasis); - } - } + if ((PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw") || winput::out_spillage == 2) + { + if (wanf2 != nullptr) + { + delete[] wanf2; + } + this->wanf2 = new ModuleBase::ComplexMatrix[nks2]; + for (int ik = 0; ik < nks2; ik++) + { + this->wanf2[ik].create(GlobalV::NLOCAL, nbasis); + } + } - if(PARAM.inp.test_wf) + if (PARAM.inp.test_wf) { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"psi allocation","Done"); - if(PARAM.inp.basis_type=="lcao" || PARAM.inp.basis_type=="lcao_in_pw") + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "psi allocation", "Done"); + if (PARAM.inp.basis_type == "lcao" || PARAM.inp.basis_type == "lcao_in_pw") { - ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running,"wanf2 allocation","Done"); + ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "wanf2 allocation", "Done"); } } diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp index b1e7b4fbaf..39f4697b56 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_wf.cpp @@ -4,7 +4,7 @@ #include #include "module_base/memory.h" -#include "time.h" +#include //---------Temporary------------------------------------ #include "module_base/complexmatrix.h" @@ -277,7 +277,7 @@ void Init_Sto_Orbitals_Ecut(Stochastic_WF& stowf, for (int ik = 0; ik < nks; ++ik) { - const int iktot = kv.getik_global(ik); + const int iktot = K_Vectors::get_ik_global(ik, nkstot); const int npw = wfcpw.npwk[ik]; int* ig2ixyz = new int[npw]; diff --git a/source/module_hsolver/hsolver.cpp b/source/module_hsolver/hsolver.cpp index d5d1289836..a0b54ed9d9 100644 --- a/source/module_hsolver/hsolver.cpp +++ b/source/module_hsolver/hsolver.cpp @@ -24,7 +24,7 @@ double set_diagethr_ks(const std::string basis_type, { if (std::abs(res_diag_ethr - 1.0e-2) < 1.0e-6) { - if (init_chg_in == "file") + if (init_chg_in == "file" || init_chg_in == "wfc") { //====================================================== // if you think that the starting potential is good diff --git a/source/module_hsolver/test/hsolver_supplementary_mock.h b/source/module_hsolver/test/hsolver_supplementary_mock.h index 8b5b1156dc..8cee0698af 100644 --- a/source/module_hsolver/test/hsolver_supplementary_mock.h +++ b/source/module_hsolver/test/hsolver_supplementary_mock.h @@ -43,7 +43,7 @@ void ElecState::print_eigenvalue(std::ofstream& ofs) return; } -void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac) +void ElecState::init_scf(const int istep, const ModuleBase::ComplexMatrix& strucfac, ModuleSymmetry::Symmetry&, const void*) { return; } diff --git a/source/module_io/CMakeLists.txt b/source/module_io/CMakeLists.txt index 42f1cefde6..bb91e72db0 100644 --- a/source/module_io/CMakeLists.txt +++ b/source/module_io/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND objects rhog_io.cpp read_exit_file.cpp read_wfc_pw.cpp + read_wfc_to_rho.cpp restart.cpp binstream.cpp write_wfc_pw.cpp diff --git a/source/module_io/get_pchg_lcao.cpp b/source/module_io/get_pchg_lcao.cpp index 1147d4b0dd..3e8de3d5cf 100644 --- a/source/module_io/get_pchg_lcao.cpp +++ b/source/module_io/get_pchg_lcao.cpp @@ -328,7 +328,6 @@ void IState_Charge::begin(Gint_k& gk, ngmc, nullptr, rho_pw, - *Pgrid, ucell_in->symm); } diff --git a/source/module_io/get_pchg_pw.h b/source/module_io/get_pchg_pw.h index 8dad37a95e..1dfc797b32 100644 --- a/source/module_io/get_pchg_pw.h +++ b/source/module_io/get_pchg_pw.h @@ -185,7 +185,6 @@ void get_pchg_pw(const std::vector& bands_to_print, ngmc, nullptr, pw_rhod, - Pgrid, ucell->symm); } diff --git a/source/module_io/read_input.cpp b/source/module_io/read_input.cpp index 6ffd98648a..c73aef3a8b 100644 --- a/source/module_io/read_input.cpp +++ b/source/module_io/read_input.cpp @@ -481,9 +481,35 @@ void ReadInput::add_item(const Input_Item& item) this->input_lists.push_back(make_pair(item.label, item)); } } + bool find_str(const std::vector& strings, const std::string& strToFind) { auto it = std::find(strings.begin(), strings.end(), strToFind); return it != strings.end(); -}; +} + +std::string nofound_str(std::vector init_chgs, const std::string& str) +{ + std::string warningstr = "The parameter "; + warningstr.append(str); + warningstr.append(" must be "); + for(int i = 0; i < init_chgs.size(); i++) + { + warningstr.append("'"); + warningstr.append(init_chgs[i]); + warningstr.append("'"); + if(i < init_chgs.size() - 2) + { + warningstr.append(", "); + } + else if(i == init_chgs.size() - 2) + { + warningstr.append(" or "); + } + } + warningstr.append("!"); + + return warningstr; +} + } // namespace ModuleIO \ No newline at end of file diff --git a/source/module_io/read_input.h b/source/module_io/read_input.h index 71e8a553ce..b42efb2a95 100644 --- a/source/module_io/read_input.h +++ b/source/module_io/read_input.h @@ -140,6 +140,8 @@ bool convert_bool(std::string str); bool find_str(const std::vector& strings, const std::string& strToFind); // convert to directory format std::string to_dir(const std::string& str); +// return a warning string if the string is not found in the vector +std::string nofound_str(std::vector init_chgs, const std::string& str); } // namespace ModuleIO diff --git a/source/module_io/read_input_item_elec_stru.cpp b/source/module_io/read_input_item_elec_stru.cpp index cf98201b86..b7f48bc735 100644 --- a/source/module_io/read_input_item_elec_stru.cpp +++ b/source/module_io/read_input_item_elec_stru.cpp @@ -77,19 +77,16 @@ void ReadInput::item_elec_stru() { if (!find_str(pw_solvers, ks_solver)) { - ModuleBase::WARNING_QUIT("ReadInput", - "ks_solver must be cg, dav, bpcg " - "or dav_subspace for pw basis."); + const std::string warningstr = "For PW basis: " + nofound_str(pw_solvers, "ks_solver"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } } else if (para.input.basis_type == "lcao") { if (!find_str(lcao_solvers, ks_solver)) { - ModuleBase::WARNING_QUIT("ReadInput", - "ks_solver must be genelpa, lapack, scalapack_gvx, " - "cusolver, pexsi or " - "cg_in_lcao for lcao basis."); + const std::string warningstr = "For LCAO basis: " + nofound_str(lcao_solvers, "ks_solver"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } if (ks_solver == "cg_in_lcao") { @@ -167,7 +164,8 @@ void ReadInput::item_elec_stru() const std::vector basis_types = {"pw", "lcao_in_pw", "lcao"}; if (!find_str(basis_types, para.input.basis_type)) { - ModuleBase::WARNING_QUIT("ReadInput", "basis_type should be pw, lcao_in_pw, or lcao"); + const std::string warningstr = nofound_str(basis_types, "basis_type"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } }; this->add_item(item); diff --git a/source/module_io/read_input_item_relax.cpp b/source/module_io/read_input_item_relax.cpp index 1dc98c270c..c9545576c2 100644 --- a/source/module_io/read_input_item_relax.cpp +++ b/source/module_io/read_input_item_relax.cpp @@ -15,7 +15,8 @@ void ReadInput::item_relax() const std::vector relax_methods = {"cg", "bfgs", "sd", "cg_bfgs"}; if (!find_str(relax_methods, para.input.relax_method)) { - ModuleBase::WARNING_QUIT("ReadInput", "relax_method must be cg, bfgs, sd or cg_bfgs."); + const std::string warningstr = nofound_str(relax_methods, "relax_method"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } }; this->add_item(item); diff --git a/source/module_io/read_input_item_system.cpp b/source/module_io/read_input_item_system.cpp index 9fda7bc0da..055ff1f0b9 100644 --- a/source/module_io/read_input_item_system.cpp +++ b/source/module_io/read_input_item_system.cpp @@ -81,7 +81,8 @@ void ReadInput::item_system() "gen_bessel"}; if (!find_str(callist, calculation)) { - ModuleBase::WARNING_QUIT("ReadInput", "check 'calculation' !"); + const std::string warningstr = nofound_str(callist, "calculation"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } if (calculation == "get_pchg" || calculation == "get_wf") { @@ -111,9 +112,8 @@ void ReadInput::item_system() const std::vector esolver_types = { "ksdft", "sdft", "ofdft", "tddft", "lj", "dp", "lr", "ks-lr" }; if (!find_str(esolver_types, para.input.esolver_type)) { - ModuleBase::WARNING_QUIT("ReadInput", - "esolver_type should be ksdft, sdft, " - "ofdft, tddft, lr, ks-lr, lj or dp."); + const std::string warningstr = nofound_str(esolver_types, "esolver_type"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } if (para.input.esolver_type == "dp") { @@ -516,9 +516,11 @@ void ReadInput::item_system() } }; item.check_value = [](const Input_Item& item, const Parameter& para) { - if (para.input.init_chg != "atomic" && para.input.init_chg != "file" && para.input.init_chg != "auto") + const std::vector init_chgs = {"atomic", "file", "wfc", "auto"}; + if (!find_str(init_chgs, para.input.init_chg)) { - ModuleBase::WARNING_QUIT("ReadInput", "init_chg should be 'atomic', 'file' or 'auto'"); + const std::string warningstr = nofound_str(init_chgs, "init_chg"); + ModuleBase::WARNING_QUIT("ReadInput", warningstr); } }; this->add_item(item); diff --git a/source/module_io/read_wfc_pw.cpp b/source/module_io/read_wfc_pw.cpp index 43706a087c..020f00f11e 100644 --- a/source/module_io/read_wfc_pw.cpp +++ b/source/module_io/read_wfc_pw.cpp @@ -8,7 +8,7 @@ #include "module_base/timer.h" #include "module_base/vector3.h" -bool ModuleIO::read_wfc_pw(const std::string& filename, +void ModuleIO::read_wfc_pw(const std::string& filename, const ModulePW::PW_Basis_K* pw_wfc, const int& ik, const int& nkstot, @@ -17,6 +17,43 @@ bool ModuleIO::read_wfc_pw(const std::string& filename, ModuleBase::TITLE("ModuleIO", "read_wfc_pw"); ModuleBase::timer::tick("ModuleIO", "read_wfc_pw"); + Binstream rfs; + std::ifstream ifs; + bool error = false; + int size = 0; + std::string msg = ""; + std::string filetype = filename.substr(filename.length() - 3, 3); + + // whether can open the file + if (filetype == "txt") + { + ifs.open(filename); + if (!ifs) + { + error = true; + msg = "Can't open file " + filename; + } + } + else if (filetype == "dat") + { + rfs.open(filename, "r"); + if (!rfs) + { + error = true; + msg = "Can't open file " + filename; + } + } + else + { + error = true; + msg = "Unknown file type " + filetype; + } + + if (error) + { + ModuleBase::WARNING_QUIT("ModuleIO::read_wfc_pw", msg); + } + const int nx = pw_wfc->nx; const int ny = pw_wfc->ny; const int nz = pw_wfc->nz; @@ -46,68 +83,7 @@ bool ModuleIO::read_wfc_pw(const std::string& filename, npwtot *= GlobalV::NPOL; - Binstream rfs; - std::ifstream ifs; - bool error = false; - int size = 0; - std::string msg = ""; - std::string filetype = filename.substr(filename.length() - 3, 3); - - // whether can open the file - if (GlobalV::RANK_IN_POOL == 0) - { - if (filetype == "txt") - { - ifs.open(filename); - if (!ifs) - { - error = true; - msg = "Can't open file " + filename; - } - } - else if (filetype == "dat") - { - rfs.open(filename, "r"); - if (!rfs) - { - error = true; - msg = "Can't open file " + filename; - } - } - else - { - error = true; - msg = "Unknown file type " + filetype; - } - } - -#ifdef __MPI - // bcast msg to root process - int ip = 0; - if (error) - { - ip = GlobalV::MY_RANK; - size = msg.size(); - } - MPI_Allreduce(MPI_IN_PLACE, &ip, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD); - MPI_Bcast(&size, 1, MPI_INT, ip, MPI_COMM_WORLD); - std::vector swap(size + 1); - if (error) - { - strcpy(swap.data(), msg.c_str()); - } - MPI_Bcast(swap.data(), size + 1, MPI_CHAR, ip, MPI_COMM_WORLD); - msg = static_cast(swap.data()); - - - MPI_Bcast(&error, 1, MPI_C_BOOL, ip, MPI_COMM_WORLD); -#endif - - if (error) - { - ModuleBase::WARNING("ModuleIO::read_wfc_pw", msg); - return false; - } + // read in some information int ikstot_in, nkstot_in, npwtot_in, nbands_in; @@ -361,5 +337,5 @@ bool ModuleIO::read_wfc_pw(const std::string& filename, } ModuleBase::timer::tick("ModuleIO", "read_wfc_pw"); - return true; -} + return; +} \ No newline at end of file diff --git a/source/module_io/read_wfc_pw.h b/source/module_io/read_wfc_pw.h index 0835429d89..e1630277bf 100644 --- a/source/module_io/read_wfc_pw.h +++ b/source/module_io/read_wfc_pw.h @@ -1,9 +1,10 @@ #ifndef READ_WFC_PW_H #define READ_WFC_PW_H -#include - #include "module_basis/module_pw/pw_basis_k.h" +#include "module_elecstate/module_charge/charge.h" + +#include namespace ModuleIO { @@ -16,15 +17,12 @@ namespace ModuleIO * @param ik k index * @param nkstot total number of k points * @param wfc wave functions - * @return true if read successfully - * @return false if read failed */ -bool read_wfc_pw(const std::string& filedir, +void read_wfc_pw(const std::string& filedir, const ModulePW::PW_Basis_K* pw_wfc, const int& ik, const int& nkstot, ModuleBase::ComplexMatrix& wfc); - } // namespace ModuleIO #endif diff --git a/source/module_io/read_wfc_to_rho.cpp b/source/module_io/read_wfc_to_rho.cpp new file mode 100644 index 0000000000..f3fa367f78 --- /dev/null +++ b/source/module_io/read_wfc_to_rho.cpp @@ -0,0 +1,116 @@ +#include "read_wfc_to_rho.h" + +#include "read_wfc_pw.h" +#include "module_base/timer.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_elecstate/module_charge/symmetry_rho.h" + +void ModuleIO::read_wfc_to_rho(const ModulePW::PW_Basis_K* pw_wfc, + ModuleSymmetry::Symmetry& symm, + const int nkstot, + const std::vector& isk, + Charge& chg) +{ + ModuleBase::TITLE("ModuleIO", "read_wfc_pw_to_rho"); + ModuleBase::timer::tick("ModuleIO", "read_wfc_pw_to_rho"); + + const int kpar = GlobalV::KPAR; + const int my_pool = GlobalV::MY_POOL; + const int my_rank = GlobalV::MY_RANK; + const int nbands = GlobalV::NBANDS; + const int nspin = GlobalV::NSPIN; + + const int npwk_max = pw_wfc->npwk_max; + const int nrxx = pw_wfc->nrxx; + for (int is = 0; is < nspin; ++is) + { + ModuleBase::GlobalFunc::ZEROS(chg.rho[is], nrxx); + } + + ModuleBase::ComplexMatrix wfc_tmp(nbands, npwk_max); + std::vector> rho_tmp(nrxx); + + // read occupation numbers + ModuleBase::matrix wg_tmp(nkstot, nbands); + if (my_rank == 0) + { + std::string filename = GlobalV::global_readin_dir + "istate.info"; + std::ifstream ifs(filename); + std::string useless; + for (int ik_tot = 0; ik_tot < nkstot; ++ik_tot) + { + ifs >> useless; + getline(ifs, useless); + for(int ib = 0; ib < nbands; ++ib) + { + ifs >> useless >> useless >> wg_tmp(ik_tot, ib); + } + } + } + +#ifdef __MPI + MPI_Bcast(wg_tmp.c, nkstot * nbands, MPI_DOUBLE, 0, MPI_COMM_WORLD); +#endif + + auto get_ikstot = [&](int ik) { + int nkp = nkstot / kpar; + int rem = nkstot % kpar; + int ikstot; + if (my_pool < rem) + { + ikstot = my_pool * nkp + my_pool + ik; + } + else + { + ikstot = my_pool * nkp + rem + ik; + } + return ikstot; + }; + for (int ik = 0; ik < pw_wfc->nks; ++ik) + { + int is = 0; + if (nspin == 2) + { + is = isk[ik]; + } + const int ikstot = get_ikstot(ik); + std::stringstream filename; + filename << GlobalV::global_readin_dir << "WAVEFUNC" << ikstot + 1 << ".dat"; + ModuleIO::read_wfc_pw(filename.str(), pw_wfc, ik, nkstot, wfc_tmp); + for (int ib = 0; ib < nbands; ++ib) + { + const std::complex* wfc_ib = wfc_tmp.c + ib * npwk_max; + pw_wfc->recip2real(wfc_ib, rho_tmp.data(), ik); + + const double w1 = wg_tmp(ikstot, ib) / pw_wfc->omega; + + if (w1 != 0.0) + { +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int ir = 0; ir < nrxx; ir++) + { + chg.rho[is][ir] += w1 * std::norm(rho_tmp[ir]); + } + } + } + } + +#ifdef __MPI + chg.init_chgmpi(); + for (int is = 0; is < nspin; ++is) + { + chg.reduce_diff_pools(chg.rho[is]); + } +#endif + + // Since rho is calculated by psi^2, it is not symmetric. We need to rearrange it. + Symmetry_rho srho; + for (int is = 0; is < nspin; is++) + { + srho.begin(is, chg, chg.rhopw, GlobalC::ucell.symm); + } + + ModuleBase::timer::tick("ModuleIO", "read_wfc_pw_to_rho"); +} diff --git a/source/module_io/read_wfc_to_rho.h b/source/module_io/read_wfc_to_rho.h new file mode 100644 index 0000000000..3e8492d679 --- /dev/null +++ b/source/module_io/read_wfc_to_rho.h @@ -0,0 +1,28 @@ +#ifndef READ_WFC_TO_RHO_H +#define READ_WFC_TO_RHO_H + +#include "module_basis/module_pw/pw_basis_k.h" +#include "module_elecstate/module_charge/charge.h" + +#include + +namespace ModuleIO +{ +/** + * @brief read wave functions and occupation numbers to charge density + * + * @param pw_wfc pw basis for wave functions + * @param symm symmetry + * @param nkstot total number of k points + * @param isk k index to spin index + * @param chg charge density + */ +void read_wfc_to_rho(const ModulePW::PW_Basis_K* pw_wfc, + ModuleSymmetry::Symmetry& symm, + const int nkstot, + const std::vector& isk, + Charge& chg); + +} // namespace ModuleIO + +#endif diff --git a/source/module_io/test/CMakeLists.txt b/source/module_io/test/CMakeLists.txt index e687c99b38..155d7e96e4 100644 --- a/source/module_io/test/CMakeLists.txt +++ b/source/module_io/test/CMakeLists.txt @@ -169,6 +169,19 @@ add_test(NAME read_wfc_pw_test_parallel WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) +AddTest( + TARGET read_wfc_to_rho_test + LIBS parameter base ${math_libs} device planewave psi + SOURCES read_wfc_to_rho_test.cpp ../read_wfc_pw.cpp ../read_wfc_to_rho.cpp ../binstream.cpp ../../module_basis/module_pw/test/test_tool.cpp + ../../module_elecstate/module_charge/charge_mpi.cpp ../write_wfc_pw.cpp +) + +add_test(NAME read_wfc_to_rho_parallel + COMMAND mpirun -np 4 ./read_wfc_to_rho_test + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + + AddTest( TARGET numerical_basis_test LIBS parameter base ${math_libs} device numerical_atomic_orbitals container orb diff --git a/source/module_io/test/read_wfc_pw_test.cpp b/source/module_io/test/read_wfc_pw_test.cpp index 833c3e89e7..78867db78b 100644 --- a/source/module_io/test/read_wfc_pw_test.cpp +++ b/source/module_io/test/read_wfc_pw_test.cpp @@ -27,10 +27,12 @@ class ReadWfcPwTest : public ::testing::Test } virtual void TearDown() { - if (wfcpw != nullptr) + if (wfcpw != nullptr) { delete wfcpw; - if (kvec_d != nullptr) +} + if (kvec_d != nullptr) { delete[] kvec_d; +} } }; @@ -50,9 +52,8 @@ TEST_F(ReadWfcPwTest, ReadWfcPw) GlobalV::NBANDS = 8; const int nbasis = wfcpw->npwk[0]; ModuleBase::ComplexMatrix wfcatom(GlobalV::NBANDS, nbasis); - bool result = ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom); + ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom); - EXPECT_TRUE(result); if (GlobalV::NPROC_IN_POOL == 1) { EXPECT_DOUBLE_EQ(wfcatom(0, 0).real(), -0.017953720885562179); @@ -117,71 +118,30 @@ TEST_F(ReadWfcPwTest, NotFoundFile) ModuleBase::ComplexMatrix wfcatom(GlobalV::NBANDS, wfcpw->npwk[0]); + if(GlobalV::RANK_IN_POOL == 0) + { // dat file std::string filename = "notfound.dat"; - GlobalV::ofs_warning.open("test_read_wfc_pw.txt"); - bool result = ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom); - GlobalV::ofs_warning.close(); + testing::internal::CaptureStdout(); + EXPECT_EXIT(ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom), ::testing::ExitedWithCode(0), ""); + std::string output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("Can't open file notfound.dat")); - std::ifstream ifs_running("test_read_wfc_pw.txt"); - std::stringstream ss; - ss << ifs_running.rdbuf(); - std::string file_content = ss.str(); - ifs_running.close(); - - std::string expected_content = " ModuleIO::read_wfc_pw warning : Can't open file notfound.dat\n"; - - EXPECT_FALSE(result); - if (GlobalV::RANK_IN_POOL == 0) - EXPECT_EQ(file_content, expected_content); -#ifdef __MPI - MPI_Barrier(MPI_COMM_WORLD); -#endif - std::remove("test_read_wfc_pw.txt"); // txt file filename = "notfound.txt"; - GlobalV::ofs_warning.open("test_read_wfc_pw.txt"); - result = ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom); - GlobalV::ofs_warning.close(); - - ifs_running.open("test_read_wfc_pw.txt"); - std::stringstream ss2; - ss2 << ifs_running.rdbuf(); - file_content = ss2.str(); - ifs_running.close(); - - expected_content = " ModuleIO::read_wfc_pw warning : Can't open file notfound.txt\n"; - - EXPECT_FALSE(result); - if (GlobalV::RANK_IN_POOL == 0) - EXPECT_EQ(file_content, expected_content); -#ifdef __MPI - MPI_Barrier(MPI_COMM_WORLD); -#endif - std::remove("test_read_wfc_pw.txt"); + testing::internal::CaptureStdout(); + EXPECT_EXIT(ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("Can't open file notfound.txt")); // other file filename = "notfound"; - GlobalV::ofs_warning.open("test_read_wfc_pw.txt"); - result = ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom); - GlobalV::ofs_warning.close(); - - ifs_running.open("test_read_wfc_pw.txt"); - std::stringstream ss3; - ss3 << ifs_running.rdbuf(); - file_content = ss3.str(); - ifs_running.close(); - - expected_content = " ModuleIO::read_wfc_pw warning : Unknown file type und\n"; - - EXPECT_FALSE(result); - if (GlobalV::RANK_IN_POOL == 0) - EXPECT_EQ(file_content, expected_content); -#ifdef __MPI - MPI_Barrier(MPI_COMM_WORLD); -#endif - std::remove("test_read_wfc_pw.txt"); + testing::internal::CaptureStdout(); + EXPECT_EXIT(ModuleIO::read_wfc_pw(filename, wfcpw, 0, nkstot, wfcatom), ::testing::ExitedWithCode(0), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output,testing::HasSubstr("Unknown file type und")); + } } // Test the read_wfc_pw function when nbands is inconsistent diff --git a/source/module_io/test/read_wfc_to_rho_test.cpp b/source/module_io/test/read_wfc_to_rho_test.cpp new file mode 100644 index 0000000000..7433e11500 --- /dev/null +++ b/source/module_io/test/read_wfc_to_rho_test.cpp @@ -0,0 +1,304 @@ +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#undef __LCAO + +#define private public +#include "module_elecstate/module_charge/charge.h" +#include "module_parameter/parameter.h" +#undef private + +#ifdef __MPI +#include "module_base/parallel_global.h" +#include "module_basis/module_pw/test/test_tool.h" +#include "mpi.h" +#endif +#include "module_cell/klist.h" +#include "module_cell/unitcell.h" +#include "module_elecstate/elecstate_getters.h" +#include "module_elecstate/module_charge/symmetry_rho.h" +#include "module_hamilt_general/module_xc/xc_functional.h" +#include "module_hamilt_pw/hamilt_pwdft/parallel_grid.h" +#include "module_io/read_wfc_to_rho.h" +#include "module_io/write_wfc_pw.h" +#include "module_psi/psi.h" + +Parallel_Grid::Parallel_Grid() +{ +} +Parallel_Grid::~Parallel_Grid() +{ +} +Charge::Charge() +{ +} +Charge::~Charge() +{ +} +UnitCell::UnitCell() +{ +} +UnitCell::~UnitCell() +{ +} +Magnetism::Magnetism() +{ +} +Magnetism::~Magnetism() +{ +} +int elecstate::get_xc_func_type() +{ + return 0; +} +int XC_Functional::get_func_type() +{ + return 0; +} +K_Vectors::K_Vectors() +{ +} +K_Vectors::~K_Vectors() +{ +} +Symmetry_rho::Symmetry_rho() +{ +} +Symmetry_rho::~Symmetry_rho() +{ +} +void Symmetry_rho::begin(const int& spin_now, + const Charge& CHR, + const ModulePW::PW_Basis* rho_basis, + ModuleSymmetry::Symmetry& symm) const +{ + return; +} + +int K_Vectors::get_ik_global(const int& ik, const int& nkstot) +{ + int nkp = nkstot / PARAM.inp.kpar; + int rem = nkstot % PARAM.inp.kpar; + if (GlobalV::MY_POOL < rem) + { + return GlobalV::MY_POOL * nkp + GlobalV::MY_POOL + ik; + } + else + { + return GlobalV::MY_POOL * nkp + rem + ik; + } +} + +namespace GlobalC +{ +UnitCell ucell; +Parallel_Grid Pgrid; +} // namespace GlobalC + +/** + * - Tested Functions: + * - write_wfc_pw() + * - read_wfc_to_rho() + */ + +class ReadWfcRhoTest : public ::testing::Test +{ + protected: + ModulePW::PW_Basis_K* wfcpw = nullptr; + ModulePW::PW_Basis* rhopw = nullptr; + K_Vectors* kv = nullptr; + psi::Psi>* psi = nullptr; + Charge chg; + + virtual void SetUp() + { + wfcpw = new ModulePW::PW_Basis_K; + rhopw = new ModulePW::PW_Basis; + kv = new K_Vectors; + GlobalV::NBANDS = 4; + GlobalV::NSPIN = 1; + PARAM.input.out_wfc_pw = 2; + } + virtual void TearDown() + { + delete wfcpw; + delete rhopw; + delete kv; + } +}; + +// Test the read_wfc_pw function +TEST_F(ReadWfcRhoTest, ReadWfcRho) +{ + // Init K_Vectors + const int my_pool = GlobalV::MY_POOL; + const int nbands = GlobalV::NBANDS; + const int nks = 2; + const int nkstot = GlobalV::KPAR * nks; + kv->set_nkstot(nkstot); + kv->set_nks(nks); + kv->isk = {0, 0}; + const double shift = my_pool * 0.1; + kv->kvec_d = {ModuleBase::Vector3(shift, shift, shift), + ModuleBase::Vector3(0.5 + shift, 0.5 + shift, 0.5 + shift)}; + + // Init the pw basis +#ifdef __MPI + wfcpw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD); + rhopw->initmpi(GlobalV::NPROC_IN_POOL, GlobalV::RANK_IN_POOL, POOL_WORLD); +#endif + rhopw->initgrids(5.3233, ModuleBase::Matrix3(-0.5, 0.0, 0.5, 0.0, 0.5, 0.5, -0.5, 0.5, 0.0), 80); + rhopw->initparameters(false, 80); + rhopw->setuptransform(); + rhopw->collect_local_pw(); + + wfcpw->initgrids(5.3233, ModuleBase::Matrix3(-0.5, 0.0, 0.5, 0.0, 0.5, 0.5, -0.5, 0.5, 0.0), 80); + wfcpw->initparameters(false, 20, nks, kv->kvec_d.data()); + wfcpw->setuptransform(); + wfcpw->collect_local_pw(); + kv->kvec_c.clear(); + for (int ik = 0; ik < nks; ++ik) + { + kv->kvec_c.push_back(wfcpw->kvec_c[ik]); + } + kv->ngk = {wfcpw->npwk[0], wfcpw->npwk[1]}; + kv->wk = {1.0, 1.0}; + + // Init wg + ModuleBase::matrix wg(nkstot, nbands); + wg.fill_out(1.0); + if (GlobalV::MY_RANK == 0) + { + std::ofstream ofs("istate.info"); + for (int ik = 0; ik < nkstot; ++ik) + { + ofs << "BAND Energy(ev) Occupation Kpoint" << std::endl; + for (int ib = 0; ib < nbands; ++ib) + { + ofs << " " << ib + 1 << " 0.0000000 " << 1.0 << std::endl; + } + ofs << std::endl; + } + ofs.close(); + } + + // Init Psi + psi = new psi::Psi>(nks, nbands, wfcpw->npwk_max, wfcpw->npwk); + std::complex* ptr = psi->get_pointer(); + for (int i = 0; i < nks * nbands * wfcpw->npwk_max; i++) + { + ptr[i] = std::complex((i + my_pool * 100) / 100.0, (i + my_pool) / 100.0); + } + + // Init charge + chg.rho = new double*[1]; + chg._space_rho = new double[rhopw->nrxx]; + chg.rho[0] = chg._space_rho; + ModuleBase::GlobalFunc::ZEROS(chg.rho[0], rhopw->nrxx); + chg.rhopw = rhopw; + chg.nrxx = rhopw->nrxx; + // set charge_ref + Charge chg_ref; + chg_ref.rho = new double*[1]; + chg_ref._space_rho = new double[rhopw->nrxx]; + chg_ref.rho[0] = chg_ref._space_rho; + ModuleBase::GlobalFunc::ZEROS(chg_ref.rho[0], rhopw->nrxx); + std::vector> rho_tmp(rhopw->nrxx); + chg_ref.nrxx = rhopw->nrxx; + + for (int ik = 0; ik < nks; ++ik) + { + for (int ib = 0; ib < nbands; ++ib) + { + const std::complex* wfc_ib = ptr + ik * nbands * wfcpw->npwk_max + ib * wfcpw->npwk_max; + wfcpw->recip2real(wfc_ib, rho_tmp.data(), ik); + + const double w1 = wg(ik, ib) / wfcpw->omega; + + for (int ir = 0; ir < rhopw->nrxx; ir++) + { + chg_ref.rho[0][ir] += w1 * std::norm(rho_tmp[ir]); + } + } + } + +#ifdef __MPI + chg_ref.init_chgmpi(); + chg_ref.reduce_diff_pools(chg_ref.rho[0]); +#endif + + // Write the wave functions to file + ModuleIO::write_wfc_pw("WAVEFUNC", *psi, *kv, wfcpw); + + // Read the wave functions to charge density + ModuleIO::read_wfc_to_rho(wfcpw, GlobalC::ucell.symm, nkstot, kv->isk, chg); + + // compare the charge density + for (int ir = 0; ir < rhopw->nrxx; ++ir) + { + EXPECT_NEAR(chg.rho[0][ir], chg_ref.rho[0][ir], 1e-8); + } + // std::cout.precision(16); + // std::cout< ref = {8207.849135313403, 35.34776105132742, 8207.849135313403, 35.34776105132742}; + EXPECT_NEAR(chg.rho[0][0], ref[GlobalV::MY_RANK], 1e-8); + // for (int ip = 0; ip < GlobalV::NPROC; ++ip) + // { + // if (GlobalV::MY_RANK == ip) + // { + // std::cout.precision(16); + // std::cout << GlobalV::MY_RANK << " " << chg.rho[0][0] << std::endl; + // } + // MPI_Barrier(MPI_COMM_WORLD); + // } + } + + delete[] chg.rho; + delete[] chg._space_rho; + delete[] chg_ref.rho; + delete[] chg_ref._space_rho; + delete psi; + if (GlobalV::MY_RANK == 0) + { + remove("istate.info"); + remove("WAVEFUNC1.dat"); + remove("WAVEFUNC2.dat"); + if (GlobalV::KPAR > 1) + { + remove("WAVEFUNC3.dat"); + remove("WAVEFUNC4.dat"); + } + } +} + +int main(int argc, char** argv) +{ +#ifdef __MPI + setupmpi(argc, argv, GlobalV::NPROC, GlobalV::MY_RANK); + PARAM.input.kpar = (GlobalV::NPROC > 1) ? 2 : 1; + GlobalV::KPAR = PARAM.input.kpar; + PARAM.input.bndpar = 1; + Parallel_Global::divide_pools(GlobalV::NPROC, + GlobalV::MY_RANK, + PARAM.inp.bndpar, + GlobalV::KPAR, + GlobalV::NPROC_IN_STOGROUP, + GlobalV::RANK_IN_STOGROUP, + GlobalV::MY_STOGROUP, + GlobalV::NPROC_IN_POOL, + GlobalV::RANK_IN_POOL, + GlobalV::MY_POOL); +#endif + + testing::InitGoogleTest(&argc, argv); + int result = RUN_ALL_TESTS(); + +#ifdef __MPI + finishmpi(); +#endif + return result; +} \ No newline at end of file diff --git a/source/module_io/test/write_istate_info_test.cpp b/source/module_io/test/write_istate_info_test.cpp index c8aab2885c..f09c945053 100644 --- a/source/module_io/test/write_istate_info_test.cpp +++ b/source/module_io/test/write_istate_info_test.cpp @@ -90,7 +90,7 @@ TEST_F(IstateInfoTest, OutIstateInfoS1) str, testing::HasSubstr("BAND Energy(ev) Occupation Kpoint = 100")); EXPECT_THAT(str, testing::HasSubstr("(0.99 0.99 0.99)")); - EXPECT_THAT(str, testing::HasSubstr("4 2.04085 0")); + EXPECT_THAT(str, testing::HasSubstr("4 2.0408547 0")); ifs.close(); remove("istate.info"); } diff --git a/source/module_io/write_istate_info.cpp b/source/module_io/write_istate_info.cpp index 7df99f9d2b..93ac500a29 100644 --- a/source/module_io/write_istate_info.cpp +++ b/source/module_io/write_istate_info.cpp @@ -41,6 +41,7 @@ void ModuleIO::write_istate_info(const ModuleBase::matrix &ekb,const ModuleBase: << " " << kv.kvec_d[ik].z << ")" << std::endl; for (int ib = 0; ib < GlobalV::NBANDS; ib++) { + ofsi2.precision(16); ofsi2 << std::setw(6) << ib + 1 << std::setw(25) << ekb(ik, ib) * ModuleBase::Ry_to_eV << std::setw(25) << wg(ik, ib) << std::endl; diff --git a/source/module_io/write_wfc_pw.cpp b/source/module_io/write_wfc_pw.cpp index 59a3d94bad..b33559af05 100644 --- a/source/module_io/write_wfc_pw.cpp +++ b/source/module_io/write_wfc_pw.cpp @@ -24,9 +24,9 @@ void ModuleIO::write_wfc_pw(const std::string& fn, for(int ik = 0; ik < nkstot ; ++ik) { std::stringstream wfss; - if(PARAM.inp.out_wfc_pw==1) + if(PARAM.inp.out_wfc_pw==1) { wfss<" <" <" <" <