Skip to content

Commit

Permalink
add function get_global_row/col_size() in parallel_2d (#4440)
Browse files Browse the repository at this point in the history
* add function get_global_row/col_size() in parallel_2d

* [pre-commit.ci lite] apply automatic fixes

* add const

---------

Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
  • Loading branch information
pxlxingliang and pre-commit-ci-lite[bot] authored Jun 22, 2024
1 parent f2274d2 commit e6a5c35
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 106 deletions.
54 changes: 31 additions & 23 deletions source/module_basis/module_ao/parallel_2d.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
#include "parallel_2d.h"

#include <numeric>
#include <cassert>

#include "module_base/blacs_connector.h"
#include "module_base/scalapack_connector.h"

#include <cassert>
#include <numeric>

bool Parallel_2D::in_this_processor(const int iw1_all, const int iw2_all) const
{
return global2local_row(iw1_all) != -1 && global2local_col(iw2_all) != -1;
}

int Parallel_2D::get_global_row_size() const
{
#ifdef __MPI
return desc[2];
#else
return nrow;
#endif
}

int Parallel_2D::get_global_col_size() const
{
#ifdef __MPI
return desc[3];
#else
return ncol;
#endif
}

#ifdef __MPI
void Parallel_2D::_init_proc_grid(const MPI_Comm comm, const bool mode)
{
Expand All @@ -19,9 +37,15 @@ void Parallel_2D::_init_proc_grid(const MPI_Comm comm, const bool mode)
int num_proc = 0;
MPI_Comm_size(comm, &num_proc);
dim0 = static_cast<int>(std::sqrt(num_proc + 0.5));
while (dim1 = num_proc / dim0, dim0 * dim1 != num_proc) { --dim0; }
while (dim1 = num_proc / dim0, dim0 * dim1 != num_proc)
{
--dim0;
}

if (mode) { std::swap(dim0, dim1); }
if (mode)
{
std::swap(dim0, dim1);
}

// create a 2D Cartesian MPI communicator (row-major by default)
int period[2] = {1, 1};
Expand All @@ -42,7 +66,6 @@ void Parallel_2D::_init_proc_grid(const MPI_Comm comm, const bool mode)
// ***This needs to be verified***
}


void Parallel_2D::_set_dist_info(const int mg, const int ng, const int nb)
{
this->nb = nb;
Expand Down Expand Up @@ -75,28 +98,14 @@ void Parallel_2D::_set_dist_info(const int mg, const int ng, const int nb)
}
}


int Parallel_2D::init(
const int mg,
const int ng,
const int nb,
const MPI_Comm comm,
const bool mode
)
int Parallel_2D::init(const int mg, const int ng, const int nb, const MPI_Comm comm, const bool mode)
{
_init_proc_grid(comm, mode);
_set_dist_info(mg, ng, nb);
return nrow == 0 || ncol == 0;
}


int Parallel_2D::set(
const int mg,
const int ng,
const int nb,
const MPI_Comm comm_2D,
const int blacs_ctxt
)
int Parallel_2D::set(const int mg, const int ng, const int nb, const MPI_Comm comm_2D, const int blacs_ctxt)
{
this->comm_2D = comm_2D;
this->blacs_ctxt = blacs_ctxt;
Expand Down Expand Up @@ -127,4 +136,3 @@ void Parallel_2D::set_serial(const int mg, const int ng)
blacs_ctxt = -1;
#endif
}

78 changes: 51 additions & 27 deletions source/module_basis/module_ao/parallel_2d.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef _PARALLEL_2D_H_
#define _PARALLEL_2D_H_

#include <vector>
#include <cstdint>
#include <vector>

#ifdef __MPI
#include <mpi.h>
Expand All @@ -12,63 +12,89 @@
/// 2D-block-cyclic parallel distribution of an arbitrary matrix.
class Parallel_2D
{
public:
public:
Parallel_2D() = default;
~Parallel_2D() = default;

/// number of local rows
int get_row_size() const { return nrow; };
int get_row_size() const
{
return nrow;
};

/// number of local columns
int get_col_size() const { return ncol; };
int get_col_size() const
{
return ncol;
};

/// number of global rows
int get_global_row_size() const;

/// number of global columns
int get_global_col_size() const;

/// number of local matrix elements
int64_t get_local_size() const { return nloc; };
int64_t get_local_size() const
{
return nloc;
};

/// get the local index of a global index (row)
int global2local_row(const int igr) const { return global2local_row_[igr]; }
int global2local_row(const int igr) const
{
return global2local_row_[igr];
}

/// get the local index of a global index (col)
int global2local_col(const int igc) const { return global2local_col_[igc]; }
int global2local_col(const int igc) const
{
return global2local_col_[igc];
}

/// get the global index of a local index (row)
int local2global_row(const int ilr) const { return local2global_row_[ilr]; }
int local2global_row(const int ilr) const
{
return local2global_row_[ilr];
}

/// get the global index of a local index (col)
int local2global_col(const int ilc) const { return local2global_col_[ilc]; }
int local2global_col(const int ilc) const
{
return local2global_col_[ilc];
}

/// check whether a global index is in this process
bool in_this_processor(const int iw1_all, const int iw2_all) const;

/// side length of 2d square block
int get_block_size() const { return nb; };
int get_block_size() const
{
return nb;
};

#ifdef __MPI
/**
* @brief Initialize a BLACS grid with the given MPI communicator
* and set up the info of a block-cyclic distribution.
*
*/
int init(
const int mg,
const int ng,
const int nb, // square block is assumed
const MPI_Comm comm,
bool mode = 0
);
int init(const int mg,
const int ng,
const int nb, // square block is assumed
const MPI_Comm comm,
bool mode = 0);

/**
* @brief Set up the info of a block-cyclic distribution using given
* MPI communicator and BLACS context.
*
*/
int set(
const int mg,
const int ng,
const int nb, // square block is assumed
const MPI_Comm comm_2D,
const int blacs_ctxt
);
int set(const int mg,
const int ng,
const int nb, // square block is assumed
const MPI_Comm comm_2D,
const int blacs_ctxt);

/// BLACS context
int blacs_ctxt = -1;
Expand Down Expand Up @@ -104,9 +130,7 @@ class Parallel_2D
/// test parameter
int testpb = 0;


protected:

protected:
/// map from global index to local index
std::vector<int> global2local_row_;
std::vector<int> global2local_col_;
Expand Down
Loading

0 comments on commit e6a5c35

Please sign in to comment.