From 833cd8435a98b660be3aab563c2c24a2f7d0405c Mon Sep 17 00:00:00 2001 From: "Vinciguerra, Armando" Date: Mon, 14 Oct 2024 14:52:48 -0400 Subject: [PATCH 1/3] First pass at scan, implemented linear and ring algorithms --- mpp/shmemx.h4 | 14 +++ mpp/shmemx_c_func.h4 | 10 ++ src/collectives.c | 208 ++++++++++++++++++++++++++++++++++++++++ src/collectives_c.c4 | 61 ++++++++++++ src/shmem_collectives.h | 75 +++++++++++++++ src/shmem_env_defs.h | 2 + src/shmem_team.h | 3 +- 7 files changed, 372 insertions(+), 1 deletion(-) diff --git a/mpp/shmemx.h4 b/mpp/shmemx.h4 index debb5f430..92b385197 100644 --- a/mpp/shmemx.h4 +++ b/mpp/shmemx.h4 @@ -68,6 +68,20 @@ static inline void shmemx_ibget(shmem_ctx_t ctx, $2 *target, const $2 *source, }')dnl SHMEM_CXX_DEFINE_FOR_RMA(`SHMEM_CXX_IBGET') +define(`SHMEM_CXX_SUM_EXSCAN', +`static inline int shmemx_sum_exscan(shmem_team_t team, $2* dest, const $2* source, + size_t nelems) { + return shmemx_$1_sum_exscan(team, dest, source, nelems); +}')dnl +SHMEM_CXX_DEFINE_FOR_COLL_SUM_PROD(`SHMEM_CXX_SUM_EXSCAN') + +define(`SHMEM_CXX_SUM_INSCAN', +`static inline int shmemx_sum_inscan(shmem_team_t team, $2* dest, const $2* source, + size_t nelems) { + return shmemx_$1_sum_inscan(team, dest, source, nelems); +}')dnl +SHMEM_CXX_DEFINE_FOR_COLL_SUM_PROD(`SHMEM_CXX_SUM_INSCAN') + /* C11 Generic Macros */ #elif (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(SHMEM_INTERNAL_INCLUDE)) diff --git a/mpp/shmemx_c_func.h4 b/mpp/shmemx_c_func.h4 index 7eb969a9c..948146156 100644 --- a/mpp/shmemx_c_func.h4 +++ b/mpp/shmemx_c_func.h4 @@ -76,6 +76,16 @@ SH_PAD(`$1') ptrdiff_t tst, ptrdiff_t sst, SH_PAD(`$1') size_t bsize, size_t nblocks, int pe)')dnl SHMEM_DECLARE_FOR_SIZES(`SHMEM_C_CTX_IBGET_N') +define(`SHMEM_C_EXSCAN', +`SHMEM_FUNCTION_ATTRIBUTES int SHPRE()shmemx_$1_$4_exscan(shmem_team_t team, $2 *dest, const $2 *source, size_t nelems);')dnl + +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_C_EXSCAN', `sum') + +define(`SHMEM_C_INSCAN', 
+`SHMEM_FUNCTION_ATTRIBUTES int SHPRE()shmemx_$1_$4_inscan(shmem_team_t team, $2 *dest, const $2 *source, size_t nelems);')dnl + +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_C_INSCAN', `sum') + /* Performance Counter Query Routines */ SHMEM_FUNCTION_ATTRIBUTES void SHPRE()shmemx_pcntr_get_issued_write(shmem_ctx_t ctx, uint64_t *cntr_value); SHMEM_FUNCTION_ATTRIBUTES void SHPRE()shmemx_pcntr_get_issued_read(shmem_ctx_t ctx, uint64_t *cntr_value); diff --git a/src/collectives.c b/src/collectives.c index ee51f869e..d69c0d80e 100644 --- a/src/collectives.c +++ b/src/collectives.c @@ -18,6 +18,7 @@ #define SHMEM_INTERNAL_INCLUDE #include "shmem.h" +#include "shmemx.h" #include "shmem_internal.h" #include "shmem_collectives.h" #include "shmem_internal_op.h" @@ -25,6 +26,7 @@ coll_type_t shmem_internal_barrier_type = AUTO; coll_type_t shmem_internal_bcast_type = AUTO; coll_type_t shmem_internal_reduce_type = AUTO; +coll_type_t shmem_internal_scan_type = AUTO; coll_type_t shmem_internal_collect_type = AUTO; coll_type_t shmem_internal_fcollect_type = AUTO; long *shmem_internal_barrier_all_psync; @@ -206,6 +208,16 @@ shmem_internal_collectives_init(void) } else { RAISE_WARN_MSG("Ignoring bad reduction algorithm '%s'\n", type); } + } + if (shmem_internal_params.SCAN_ALGORITHM_provided) { + type = shmem_internal_params.SCAN_ALGORITHM; + if (0 == strcmp(type, "auto")) { + shmem_internal_scan_type = AUTO; + } else if (0 == strcmp(type, "linear")) { + shmem_internal_scan_type = LINEAR; + } else { + RAISE_WARN_MSG("Ignoring bad scan algorithm '%s'\n", type); + } } if (shmem_internal_params.COLLECT_ALGORITHM_provided) { type = shmem_internal_params.COLLECT_ALGORITHM; @@ -971,6 +983,202 @@ shmem_internal_op_to_all_recdbl_sw(void *target, const void *source, size_t coun } +/***************************************** + * + * SCAN + * + *****************************************/ +void +shmem_internal_scan_linear(void *target, const void *source, size_t count, size_t type_size, + int PE_start, int 
PE_stride, int PE_size, + void *pWrk, long *pSync, + shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype) +{ + + /* scantype is 0 for inscan and 1 for exscan */ + + long zero = 0, one = 1; + long completion = 0; + + + if (count == 0) return; + + int pe, i; + + if (PE_start == shmem_internal_my_pe) { + + + /* initialize target buffer. The put + will flush any atomic cache value that may currently + exist. */ + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + shmem_internal_my_pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + free(zeroes); + } + + + /* Send contribution to all */ + for (pe = PE_start + PE_stride*scantype, i = scantype ; + i < PE_size ; + i++, pe += PE_stride) { + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + + } + + for (pe = PE_start + PE_stride, i = 1 ; + i < PE_size ; + i++, pe += PE_stride) { + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } + + /* Wait for others to acknowledge initialization */ + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); + + /* reset pSync */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); + + + /* Let everyone know sending can start */ + for (pe = PE_start + PE_stride, i = 1 ; + i < PE_size ; + i++, pe += PE_stride) { + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } + + + } else { + + /* wait for clear to intialization */ + SHMEM_WAIT(pSync, 0); + + /* reset pSync */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, 
&zero, sizeof(zero), shmem_internal_my_pe); + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); + + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + i < PE_size; + i++, pe += PE_stride) { + + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, op, datatype, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + + } + + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), + PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); + + SHMEM_WAIT(pSync, 0); + + /* reset pSync */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); + + } + +} + + +void +shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t type_size, + int PE_start, int PE_stride, int PE_size, + void *pWrk, long *pSync, + shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype) +{ + + /* scantype is 0 for inscan and 1 for exscan */ + + long zero = 0, one = 1; + long completion = 0; + + + if (count == 0) return; + + int pe, i; + + if (PE_start == shmem_internal_my_pe) { + + + /* initialize target buffer. The put + will flush any atomic cache value that may currently + exist. 
*/ + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + shmem_internal_my_pe, &completion); + free(zeroes); + } + + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + + /* Send contribution to all */ + for (pe = PE_start + PE_stride*scantype, i = scantype ; + i < PE_size ; + i++, pe += PE_stride) { + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); + + /* Wait for others to acknowledge sending data */ + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); + + /* reset pSync */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); + + } else { + /* wait for clear to send */ + SHMEM_WAIT(pSync, 0); + + /* reset pSync */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); + SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); + + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + i < PE_size; + i++, pe += PE_stride) { + + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, op, datatype, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + 
shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); + + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), + PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); + } + +} /***************************************** * * COLLECT (variable size) diff --git a/src/collectives_c.c4 b/src/collectives_c.c4 index 70c8876b5..eb6d20d21 100644 --- a/src/collectives_c.c4 +++ b/src/collectives_c.c4 @@ -28,6 +28,7 @@ include(shmem_bind_c.m4)dnl #define SHMEM_INTERNAL_INCLUDE #include "shmem.h" +#include "shmemx.h" #include "shmem_internal.h" #include "shmem_comm.h" #include "shmem_collectives.h" @@ -81,6 +82,18 @@ SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_REDUCE', `prod', `SHM_INTERNAL_PROD') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `min', `SHM_INTERNAL_MIN') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `max', `SHM_INTERNAL_MAX') +define(`SHMEM_PROF_DEF_EXSCAN', +`#pragma weak shmem_$1_$4_exscan = pshmem_$1_$4_exscan +#define shmem_$1_$4_exscan pshmem_$1_$4_exscan')dnl +dnl +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_EXSCAN', `sum', `SHM_INTERNAL_SUM') + +define(`SHMEM_PROF_DEF_INSCAN', +`#pragma weak shmem_$1_$4_inscan = pshmem_$1_$4_inscan +#define shmem_$1_$4_inscan pshmem_$1_$4_inscan')dnl +dnl +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_INSCAN', `sum', `SHM_INTERNAL_SUM') + define(`SHMEM_PROF_DEF_BCAST', `#pragma weak shmem_$1_broadcast = pshmem_$1_broadcast #define shmem_$1_broadcast pshmem_$1_broadcast')dnl @@ -279,6 +292,54 @@ SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_DEF_REDUCE', `prod', `SHM_INTERNAL_PROD') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_DEF_REDUCE', `min', `SHM_INTERNAL_MIN') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_DEF_REDUCE', `max', `SHM_INTERNAL_MAX') +#define SHMEM_DEF_EXSCAN(STYPE,TYPE,ITYPE,SOP,IOP) \ + int SHMEM_FUNCTION_ATTRIBUTES \ + shmemx_##STYPE##_##SOP##_exscan(shmem_team_t team, TYPE *dest, \ + const TYPE *source, \ + size_t nelems) \ + { \ + SHMEM_ERR_CHECK_INITIALIZED(); \ + 
SHMEM_ERR_CHECK_TEAM_VALID(team); \ + SHMEM_ERR_CHECK_SYMMETRIC(dest, sizeof(TYPE)*nelems); \ + SHMEM_ERR_CHECK_SYMMETRIC(source, sizeof(TYPE)*nelems); \ + SHMEM_ERR_CHECK_OVERLAP(dest, source, sizeof(TYPE)*nelems, \ + sizeof(TYPE)*nelems, 1, 1); \ + TYPE *pWrk = NULL; \ + \ + shmem_internal_team_t *myteam = (shmem_internal_team_t *)team; \ + long *psync = shmem_internal_team_choose_psync(myteam, SCAN); \ + shmem_internal_exscan(dest, source, nelems, sizeof(TYPE), \ + myteam->start, myteam->stride, myteam->size, pWrk, \ + psync, IOP, ITYPE); \ + shmem_internal_team_release_psyncs(myteam, SCAN); \ + return 0; \ + } +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_DEF_EXSCAN', `sum', `SHM_INTERNAL_SUM') + +#define SHMEM_DEF_INSCAN(STYPE,TYPE,ITYPE,SOP,IOP) \ + int SHMEM_FUNCTION_ATTRIBUTES \ + shmemx_##STYPE##_##SOP##_inscan(shmem_team_t team, TYPE *dest, \ + const TYPE *source, \ + size_t nelems) \ + { \ + SHMEM_ERR_CHECK_INITIALIZED(); \ + SHMEM_ERR_CHECK_TEAM_VALID(team); \ + SHMEM_ERR_CHECK_SYMMETRIC(dest, sizeof(TYPE)*nelems); \ + SHMEM_ERR_CHECK_SYMMETRIC(source, sizeof(TYPE)*nelems); \ + SHMEM_ERR_CHECK_OVERLAP(dest, source, sizeof(TYPE)*nelems, \ + sizeof(TYPE)*nelems, 1, 1); \ + TYPE *pWrk = NULL; \ + \ + shmem_internal_team_t *myteam = (shmem_internal_team_t *)team; \ + long *psync = shmem_internal_team_choose_psync(myteam, SCAN); \ + shmem_internal_inscan(dest, source, nelems, sizeof(TYPE), \ + myteam->start, myteam->stride, myteam->size, pWrk, \ + psync, IOP, ITYPE); \ + shmem_internal_team_release_psyncs(myteam, SCAN); \ + return 0; \ + } +SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_DEF_INSCAN', `sum', `SHM_INTERNAL_SUM') + void SHMEM_FUNCTION_ATTRIBUTES shmem_broadcast32(void *target, const void *source, size_t nlong, int PE_root, int PE_start, int logPE_stride, int PE_size, diff --git a/src/shmem_collectives.h b/src/shmem_collectives.h index 6409c5178..42ae8d3eb 100644 --- a/src/shmem_collectives.h +++ b/src/shmem_collectives.h @@ -37,6 +37,7 @@ extern long 
*shmem_internal_sync_all_psync; extern coll_type_t shmem_internal_barrier_type; extern coll_type_t shmem_internal_bcast_type; extern coll_type_t shmem_internal_reduce_type; +extern coll_type_t shmem_internal_scan_type; extern coll_type_t shmem_internal_collect_type; extern coll_type_t shmem_internal_fcollect_type; @@ -237,6 +238,80 @@ shmem_internal_op_to_all(void *target, const void *source, size_t count, } } +void shmem_internal_scan_linear(void *target, const void *source, size_t count, size_t type_size, + int PE_start, int PE_stride, int PE_size, + void *pWrk, long *pSync, + shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); + +void shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t type_size, + int PE_start, int PE_stride, int PE_size, + void *pWrk, long *pSync, + shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); + +static inline +void +shmem_internal_exscan(void *target, const void *source, size_t count, + size_t type_size, int PE_start, int PE_stride, + int PE_size, void *pWrk, long *pSync, + shm_internal_op_t op, + shm_internal_datatype_t datatype) +{ + shmem_internal_assert(type_size > 0); + + switch (shmem_internal_scan_type) { + case AUTO: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); + break; + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); + break; + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); + break; + default: + RAISE_ERROR_MSG("Illegal exscan type (%d)\n", + shmem_internal_scan_type); + } +} + + +static inline +void +shmem_internal_inscan(void *target, const void *source, size_t count, + size_t type_size, int PE_start, int PE_stride, + int PE_size, void *pWrk, long *pSync, + shm_internal_op_t op, + 
shm_internal_datatype_t datatype) +{ + shmem_internal_assert(type_size > 0); + + switch (shmem_internal_scan_type) { + case AUTO: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); + break; + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); + break; + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); + break; + default: + RAISE_ERROR_MSG("Illegal exscan type (%d)\n", + shmem_internal_scan_type); + } +} void shmem_internal_collect_linear(void *target, const void *source, size_t len, int PE_start, int PE_stride, int PE_size, long *pSync); diff --git a/src/shmem_env_defs.h b/src/shmem_env_defs.h index 6118509d0..b786aaa24 100644 --- a/src/shmem_env_defs.h +++ b/src/shmem_env_defs.h @@ -63,6 +63,8 @@ SHMEM_INTERNAL_ENV_DEF(BCAST_ALGORITHM, string, "auto", SHMEM_INTERNAL_ENV_CAT_C "Algorithm for broadcast. Options are auto, linear, tree") SHMEM_INTERNAL_ENV_DEF(REDUCE_ALGORITHM, string, "auto", SHMEM_INTERNAL_ENV_CAT_COLLECTIVES, "Algorithm for reductions. Options are auto, linear, tree, recdbl") +SHMEM_INTERNAL_ENV_DEF(SCAN_ALGORITHM, string, "auto", SHMEM_INTERNAL_ENV_CAT_COLLECTIVES, + "Algorithm for scan. Options are linear, ring") SHMEM_INTERNAL_ENV_DEF(COLLECT_ALGORITHM, string, "auto", SHMEM_INTERNAL_ENV_CAT_COLLECTIVES, "Algorithm for collect. 
Options are auto, linear") SHMEM_INTERNAL_ENV_DEF(FCOLLECT_ALGORITHM, string, "auto", SHMEM_INTERNAL_ENV_CAT_COLLECTIVES, diff --git a/src/shmem_team.h b/src/shmem_team.h index 195730864..6ea197c3d 100644 --- a/src/shmem_team.h +++ b/src/shmem_team.h @@ -38,7 +38,8 @@ enum shmem_internal_team_op_t { BCAST, REDUCE, COLLECT, - ALLTOALL + ALLTOALL, + SCAN }; typedef enum shmem_internal_team_op_t shmem_internal_team_op_t; From c301f999ddf06ff1a46ae55a7c8c08e65557905b Mon Sep 17 00:00:00 2001 From: "Vinciguerra, Armando" Date: Mon, 14 Oct 2024 14:58:22 -0400 Subject: [PATCH 2/3] Adding ring value to scan type --- src/collectives.c | 2 ++ src/shmem_team.h | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/collectives.c b/src/collectives.c index d69c0d80e..4a6a11774 100644 --- a/src/collectives.c +++ b/src/collectives.c @@ -215,6 +215,8 @@ shmem_internal_collectives_init(void) shmem_internal_scan_type = AUTO; } else if (0 == strcmp(type, "linear")) { shmem_internal_scan_type = LINEAR; + } else if (0 == strcmp(type, "ring")) { + shmem_internal_scan_type = RING; } else { RAISE_WARN_MSG("Ignoring bad scan algorithm '%s'\n", type); } diff --git a/src/shmem_team.h b/src/shmem_team.h index 6ea197c3d..5f3144132 100644 --- a/src/shmem_team.h +++ b/src/shmem_team.h @@ -39,7 +39,7 @@ enum shmem_internal_team_op_t { REDUCE, COLLECT, ALLTOALL, - SCAN + SCAN }; typedef enum shmem_internal_team_op_t shmem_internal_team_op_t; From 9dc152c91fe3ed24b78f9b50f7f4dadd952bb465 Mon Sep 17 00:00:00 2001 From: "Vinciguerra, Armando" Date: Wed, 16 Oct 2024 15:34:07 -0400 Subject: [PATCH 3/3] Added code for special case of in place scan --- mpp/shmemx.h4 | 15 ++- src/collectives.c | 214 +++++++++++++++++++++++----------------- src/collectives_c.c4 | 8 +- src/shmem_collectives.h | 48 ++++----- 4 files changed, 165 insertions(+), 120 deletions(-) diff --git a/mpp/shmemx.h4 b/mpp/shmemx.h4 index 92b385197..6729ddf5e 100644 --- a/mpp/shmemx.h4 +++ b/mpp/shmemx.h4 @@ -38,7 +38,7 @@ 
include(shmemx_c_func.h4)dnl /* SHMEMX constant(s) are included in MAX_HINTS value in shmem-def.h */ #define SHMEMX_MALLOC_NO_BARRIER (1l<<2) -/* C++ overloaded declarations */ +/* C overloaded declarations */ #ifdef __cplusplus } /* extern "C" */ @@ -119,6 +119,19 @@ SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_IBGET', `, \') \ uint64_t*: shmemx_signal_add \ )(__VA_ARGS__) +define(`SHMEM_C11_GEN_EXSCAN', ` $2*: shmemx_$1_sum_exscan')dnl +#define shmemx_sum_exscan(...) \ + _Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \ +SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_EXSCAN', `, \') \ + )(__VA_ARGS__) + +define(`SHMEM_C11_GEN_INSCAN', ` $2*: shmemx_$1_sum_inscan')dnl +#define shmemx_sum_inscan(...) \ + _Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \ +SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_INSCAN', `, \') \ + )(__VA_ARGS__) + + #endif /* C11 */ #endif /* SHMEMX_H */ diff --git a/src/collectives.c b/src/collectives.c index 4a6a11774..5461aa540 100644 --- a/src/collectives.c +++ b/src/collectives.c @@ -998,99 +998,115 @@ shmem_internal_scan_linear(void *target, const void *source, size_t count, size_ { /* scantype is 0 for inscan and 1 for exscan */ - - long zero = 0, one = 1; + long zero = 0, one = 1; long completion = 0; + int free_source = 0; if (count == 0) return; - - int pe, i; + + int pe, i; + + /* In-place scan: copy source data to a temporary buffer so we can use + * the symmetric buffer to accumulate scan data. */ + if (target == source) { + void *tmp = malloc(count * type_size); + + if (NULL == tmp) + RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size); + + shmem_internal_copy_self(tmp, target, count * type_size); + free_source = 1; + source = tmp; + + shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2); + } if (PE_start == shmem_internal_my_pe) { - - /* initialize target buffer. The put + + /* initialize target buffer. The put will flush any atomic cache value that may currently exist. 
*/ - if(scantype) - { - /* Exclude own value for EXSCAN */ - //Create an array of size (count * type_size) of zeroes - uint8_t *zeroes = (uint8_t *) calloc(count, type_size); - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, shmem_internal_my_pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_quiet(SHMEM_CTX_DEFAULT); - free(zeroes); - } - - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + free(zeroes); + } + + /* Send contribution to all */ for (pe = PE_start + PE_stride*scantype, i = scantype ; i < PE_size ; i++, pe += PE_stride) { - - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, - pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); - + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } - - for (pe = PE_start + PE_stride, i = 1 ; + + for (pe = PE_start + PE_stride, i = 1 ; i < PE_size ; i++, pe += PE_stride) { - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); - } - - /* Wait for others to acknowledge initialization */ + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } + + /* Wait for others to acknowledge initialization */ SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); - - /* reset pSync */ + + /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - - - /* Let everyone know sending 
can start */ - for (pe = PE_start + PE_stride, i = 1 ; + + + /* Let everyone know sending can start */ + for (pe = PE_start + PE_stride, i = 1 ; i < PE_size ; i++, pe += PE_stride) { - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); - } - - + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe); + } } else { - - /* wait for clear to intialization */ + + /* wait for clear to intialization */ SHMEM_WAIT(pSync, 0); /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - /* Send contribution to all pes larger than itself */ - for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; i < PE_size; i++, pe += PE_stride) { - shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, pe, op, datatype, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); + } - - shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), + + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); - - SHMEM_WAIT(pSync, 0); - - /* reset pSync */ + + SHMEM_WAIT(pSync, 0); + + /* reset pSync */ shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); } + + if (free_source) + free((void *)source); } @@ -1103,48 +1119,61 @@ shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t { /* scantype is 0 for inscan and 1 for exscan */ - - long zero = 
0, one = 1; + long zero = 0, one = 1; long completion = 0; + int free_source = 0; + + /* In-place scan: copy source data to a temporary buffer so we can use + * the symmetric buffer to accumulate scan data. */ + if (target == source) { + void *tmp = malloc(count * type_size); + + if (NULL == tmp) + RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size); + + shmem_internal_copy_self(tmp, target, count * type_size); + free_source = 1; + source = tmp; + + shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2); + } if (count == 0) return; - - int pe, i; + + int pe, i; if (PE_start == shmem_internal_my_pe) { - - /* initialize target buffer. The put + /* initialize target buffer. The put will flush any atomic cache value that may currently exist. */ - if(scantype) - { - /* Exclude own value for EXSCAN */ - //Create an array of size (count * type_size) of zeroes - uint8_t *zeroes = (uint8_t *) calloc(count, type_size); - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, + if(scantype) + { + /* Exclude own value for EXSCAN */ + //Create an array of size (count * type_size) of zeroes + uint8_t *zeroes = (uint8_t *) calloc(count, type_size); + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size, shmem_internal_my_pe, &completion); - free(zeroes); - } - - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_quiet(SHMEM_CTX_DEFAULT); - + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_quiet(SHMEM_CTX_DEFAULT); + free(zeroes); + } + /* Send contribution to all */ for (pe = PE_start + PE_stride*scantype, i = scantype ; i < PE_size ; i++, pe += PE_stride) { - - shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, - pe, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); + + shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size, + pe, &completion); + 
shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); } - - /* Let next pe know that it's safe to send to us */ - if(shmem_internal_my_pe + PE_stride < PE_size) - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); /* Wait for others to acknowledge sending data */ SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1); @@ -1161,24 +1190,27 @@ shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe); SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0); - /* Send contribution to all pes larger than itself */ - for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; + /* Send contribution to all pes larger than itself */ + for (pe = shmem_internal_my_pe + PE_stride*scantype, i = shmem_internal_my_pe + scantype ; i < PE_size; i++, pe += PE_stride) { - shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, + shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size, pe, op, datatype, &completion); - shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); - shmem_internal_fence(SHMEM_CTX_DEFAULT); + shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion); + shmem_internal_fence(SHMEM_CTX_DEFAULT); } - - /* Let next pe know that it's safe to send to us */ - if(shmem_internal_my_pe + PE_stride < PE_size) - shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride); - + + /* Let next pe know that it's safe to send to us */ + if(shmem_internal_my_pe + PE_stride < PE_size) + shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), 
shmem_internal_my_pe + PE_stride); + shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG); } + + if (free_source) + free((void *)source); } /***************************************** diff --git a/src/collectives_c.c4 b/src/collectives_c.c4 index eb6d20d21..2287fff70 100644 --- a/src/collectives_c.c4 +++ b/src/collectives_c.c4 @@ -83,14 +83,14 @@ SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `min', `SHM_INTERNAL_MIN') SHMEM_BIND_C_COLL_MIN_MAX(`SHMEM_PROF_DEF_REDUCE', `max', `SHM_INTERNAL_MAX') define(`SHMEM_PROF_DEF_EXSCAN', -`#pragma weak shmem_$1_$4_exscan = pshmem_$1_$4_exscan -#define shmem_$1_$4_exscan pshmem_$1_$4_exscan')dnl +`#pragma weak shmemx_$1_$4_exscan = pshmemx_$1_$4_exscan +#define shmemx_$1_$4_exscan pshmemx_$1_$4_exscan')dnl dnl SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_EXSCAN', `sum', `SHM_INTERNAL_SUM') define(`SHMEM_PROF_DEF_INSCAN', -`#pragma weak shmem_$1_$4_inscan = pshmem_$1_$4_inscan -#define shmem_$1_$4_inscan pshmem_$1_$4_inscan')dnl +`#pragma weak shmemx_$1_$4_inscan = pshmemx_$1_$4_inscan +#define shmemx_$1_$4_inscan pshmemx_$1_$4_inscan')dnl dnl SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_PROF_DEF_INSCAN', `sum', `SHM_INTERNAL_SUM') diff --git a/src/shmem_collectives.h b/src/shmem_collectives.h index 42ae8d3eb..3bbe36108 100644 --- a/src/shmem_collectives.h +++ b/src/shmem_collectives.h @@ -242,12 +242,12 @@ void shmem_internal_scan_linear(void *target, const void *source, size_t count, int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync, shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); - + void shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t type_size, int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync, shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype); - + static inline void shmem_internal_exscan(void *target, const void *source, size_t count, @@ -260,19 +260,19 @@ 
shmem_internal_exscan(void *target, const void *source, size_t count, switch (shmem_internal_scan_type) { case AUTO: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; - case LINEAR: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; - case RING: - shmem_internal_scan_ring(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 1); + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 1); break; default: RAISE_ERROR_MSG("Illegal exscan type (%d)\n", @@ -293,19 +293,19 @@ shmem_internal_inscan(void *target, const void *source, size_t count, switch (shmem_internal_scan_type) { case AUTO: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; - case LINEAR: - shmem_internal_scan_linear(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + case LINEAR: + shmem_internal_scan_linear(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; - case RING: - shmem_internal_scan_ring(target, source, count, type_size, - PE_start, PE_stride, PE_size, - pWrk, pSync, op, datatype, 0); + case RING: + shmem_internal_scan_ring(target, source, count, type_size, + PE_start, PE_stride, PE_size, + pWrk, pSync, op, datatype, 0); break; 
default: RAISE_ERROR_MSG("Illegal exscan type (%d)\n",