Skip to content

Commit

Permalink
Refactor: remove GlobalV in some parallel classes (#4445)
Browse files Browse the repository at this point in the history
* remove globalV in divide_pools

* remove globalv in split_grid(diag)_world

* refactor init_pool

* remove globalv in finalize_mpi

* remove globalv in read_mpi_parameters

* remove globalv in parallel_common

* fix parallel_common_test

* remove globalv in reduce_double_allpool

* remove globalv in parallel_reduce

* remove globalv in parallel_reduce_test

* refactor read_mpi_parameters

* remove globalv in parallel_global_test

* remove globalv in parallel_kpoints

* remove globalv in parallel_kpoints

* update parallel_kpoints_test.cpp

* [pre-commit.ci lite] apply automatic fixes

* fix err

* [pre-commit.ci lite] apply automatic fixes

* restore bcast_string and bcast_bool

* [pre-commit.ci lite] apply automatic fixes

* [pre-commit.ci lite] apply automatic fixes

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
hongriTianqi and pre-commit-ci-lite[bot] authored Jun 22, 2024
1 parent e6a5c35 commit bdb2178
Show file tree
Hide file tree
Showing 48 changed files with 6,846 additions and 6,179 deletions.
24 changes: 21 additions & 3 deletions source/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,17 @@ void Driver::reading()
Input_Conv::Convert();

// (3) define the 'DIAGONALIZATION' world in MPI
Parallel_Global::split_diag_world(GlobalV::DIAGO_PROC);
Parallel_Global::split_grid_world(GlobalV::DIAGO_PROC);
Parallel_Global::split_diag_world(GlobalV::DIAGO_PROC,
GlobalV::NPROC,
GlobalV::MY_RANK,
GlobalV::DRANK,
GlobalV::DSIZE,
GlobalV::DCOLOR);
Parallel_Global::split_grid_world(GlobalV::DIAGO_PROC,
GlobalV::NPROC,
GlobalV::MY_RANK,
GlobalV::GRANK,
GlobalV::GSIZE);
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "DRANK", GlobalV::DRANK + 1);
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "DSIZE", GlobalV::DSIZE);
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "DCOLOR", GlobalV::DCOLOR + 1);
Expand All @@ -69,7 +78,16 @@ void Driver::reading()

#ifdef __MPI
// (4) divide the GlobalV::NPROC processors into GlobalV::KPAR for k-points parallelization.
Parallel_Global::init_pools();
Parallel_Global::init_pools(GlobalV::NPROC,
GlobalV::MY_RANK,
GlobalV::NSTOGROUP,
GlobalV::KPAR,
GlobalV::NPROC_IN_STOGROUP,
GlobalV::RANK_IN_STOGROUP,
GlobalV::MY_STOGROUP,
GlobalV::NPROC_IN_POOL,
GlobalV::RANK_IN_POOL,
GlobalV::MY_POOL);
#endif

// (5) Read in parameters about wannier functions.
Expand Down
11 changes: 6 additions & 5 deletions source/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
//==========================================================

#include "driver.h"
#include "fftw3.h"
#include "module_base/global_variable.h"
#include "module_base/parallel_global.h"
#include "module_io/parse_args.h"
#include "fftw3.h"
#ifdef _OPENMP
#include <omp.h>
#endif
Expand All @@ -24,11 +25,11 @@ int main(int argc, char** argv)
read the mpi parameters in the command-line,
initialize the mpi environment.
*/
Parallel_Global::read_mpi_parameters(argc, argv);
Parallel_Global::read_mpi_parameters(argc, argv, GlobalV::NPROC, GlobalV::MY_RANK);
#ifdef _OPENMP
// ref: https://www.fftw.org/fftw3_doc/Usage-of-Multi_002dthreaded-FFTW.html
fftw_init_threads();
fftw_plan_with_nthreads(omp_get_max_threads());
fftw_init_threads();
fftw_plan_with_nthreads(omp_get_max_threads());
#endif

/*
Expand All @@ -44,7 +45,7 @@ int main(int argc, char** argv)
Parallel_Global::finalize_mpi();
#endif
#ifdef _OPENMP
fftw_cleanup_threads();
fftw_cleanup_threads();
#endif

return 0;
Expand Down
36 changes: 19 additions & 17 deletions source/module_base/parallel_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,73 @@

#include <cstring>

#include "module_base/global_variable.h"

#ifdef __MPI
void Parallel_Common::bcast_string(std::string &object) // Peize Lin fix bug 2019-03-18
void Parallel_Common::bcast_string(std::string& object) // Peize Lin fix bug 2019-03-18
{
int size = object.size();
MPI_Bcast(&size, 1, MPI_INT, 0, MPI_COMM_WORLD);
char *swap = new char[size + 1];
if (0 == GlobalV::MY_RANK)
char* swap = new char[size + 1];
int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
if (0 == my_rank)
strcpy(swap, object.c_str());
MPI_Bcast(swap, size + 1, MPI_CHAR, 0, MPI_COMM_WORLD);
if (0 != GlobalV::MY_RANK)
if (0 != my_rank)
object = static_cast<std::string>(swap);
delete[] swap;
return;
}

void Parallel_Common::bcast_string(std::string *object, const int n) // Peize Lin fix bug 2019-03-18
void Parallel_Common::bcast_string(std::string* object, const int n) // Peize Lin fix bug 2019-03-18
{
for (int i = 0; i < n; i++)
bcast_string(object[i]);
return;
}

void Parallel_Common::bcast_complex_double(std::complex<double> &object)
void Parallel_Common::bcast_complex_double(std::complex<double>& object)
{
MPI_Bcast(&object, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_complex_double(std::complex<double> *object, const int n)
void Parallel_Common::bcast_complex_double(std::complex<double>* object, const int n)
{
MPI_Bcast(object, n, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_double(double &object)
void Parallel_Common::bcast_double(double& object)
{
MPI_Bcast(&object, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_double(double *object, const int n)
void Parallel_Common::bcast_double(double* object, const int n)
{
MPI_Bcast(object, n, MPI_DOUBLE, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_int(int &object)
void Parallel_Common::bcast_int(int& object)
{
MPI_Bcast(&object, 1, MPI_INT, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_int(int *object, const int n)
void Parallel_Common::bcast_int(int* object, const int n)
{
MPI_Bcast(object, n, MPI_INT, 0, MPI_COMM_WORLD);
}

void Parallel_Common::bcast_bool(bool &object)
void Parallel_Common::bcast_bool(bool& object)
{
int swap = object;
if (GlobalV::MY_RANK == 0)
int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
if (my_rank == 0)
swap = object;
MPI_Bcast(&swap, 1, MPI_INT, 0, MPI_COMM_WORLD);
if (GlobalV::MY_RANK != 0)
if (my_rank != 0)
object = static_cast<bool>(swap);
}

void Parallel_Common::bcast_char(char *object, const int n)
void Parallel_Common::bcast_char(char* object, const int n)
{
MPI_Bcast(object, n, MPI_CHAR, 0, MPI_COMM_WORLD);
}
Expand Down
32 changes: 16 additions & 16 deletions source/module_base/parallel_common.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef PARALLEL_COMMON_H
#define PARALLEL_COMMON_H
#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_BASE_PARALLEL_COMMON_H
#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_MODULE_BASE_PARALLEL_COMMON_H

#ifdef __MPI
#include "mpi.h"
Expand All @@ -9,20 +9,20 @@

namespace Parallel_Common
{
//(1) bcast array
void bcast_complex_double( std::complex<double> *object, const int n);
void bcast_string(std::string *object,const int n);
void bcast_double(double *object,const int n);
void bcast_int(int *object,const int n);
void bcast_char(char *object,const int n);

//(2) bcast single
void bcast_complex_double( std::complex<double> &object);
void bcast_string(std::string &object);
void bcast_double(double &object);
void bcast_int(int &object);
void bcast_bool(bool &object);
//(1) bcast array
void bcast_complex_double(std::complex<double>* object, const int n);
void bcast_string(std::string* object, const int n);
void bcast_double(double* object, const int n);
void bcast_int(int* object, const int n);
void bcast_char(char* object, const int n);

}
//(2) bcast single
void bcast_complex_double(std::complex<double>& object);
void bcast_string(std::string& object);
void bcast_double(double& object);
void bcast_int(int& object);
void bcast_bool(bool& object);

} // namespace Parallel_Common

#endif
Loading

0 comments on commit bdb2178

Please sign in to comment.