Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MR descriptor list #1117

Merged
merged 7 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ jobs:
steps:
- name: Checking OS version
run: |
echo "OS_NAME=$(lsb_release -si)-$(ls_release -sr)" >> $GITHUB_ENV
echo "OS_NAME=$(lsb_release -si)-$(lsb_release -sr)" >> $GITHUB_ENV
- uses: actions/checkout@v4
with:
fetch-depth: 0
Expand Down Expand Up @@ -229,7 +229,7 @@ jobs:
cd build
make check TESTS= -j
${{ matrix.env_setup }}
SHMEM_DEBUG=1 SHMEM_INFO=1 make VERBOSE=1 TEST_RUNNER="${SOS_PM} -np 2" check
SHMEM_DEBUG=1 SHMEM_INFO=1 SHMEM_OFI_PROVIDER=sockets make VERBOSE=1 TEST_RUNNER="${SOS_PM} -np 2" check
cat modules/tests-sos/test/unit/hello.log
- name: Test RPM (${{ matrix.rpm_build }})
if: ${{ matrix.rpm_build }}
Expand Down Expand Up @@ -488,7 +488,7 @@ jobs:
steps:
- name: Checking OS version
run: |
echo "OS_NAME=$(lsb_release -si)-$(ls_release -sr)" >> $GITHUB_ENV
echo "OS_NAME=$(lsb_release -si)-$(lsb_release -sr)" >> $GITHUB_ENV
- uses: actions/checkout@v4
with:
fetch-depth: 0
Expand Down Expand Up @@ -582,7 +582,7 @@ jobs:
steps:
- name: Checking OS version
run: |
echo "OS_NAME=$(lsb_release -si)-$(ls_release -sr)" >> $GITHUB_ENV
echo "OS_NAME=$(lsb_release -si)-$(lsb_release -sr)" >> $GITHUB_ENV
- uses: actions/checkout@v4
with:
fetch-depth: 0
Expand Down Expand Up @@ -663,7 +663,7 @@ jobs:
steps:
- name: Checking OS version
run: |
echo "OS_NAME=$(lsb_release -si)-$(ls_release -sr)" >> $GITHUB_ENV
echo "OS_NAME=$(lsb_release -si)-$(lsb_release -sr)" >> $GITHUB_ENV
- uses: actions/checkout@v4
with:
fetch-depth: 0
Expand Down
14 changes: 14 additions & 0 deletions src/transport_ofi.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ uint64_t* shmem_transport_ofi_external_heap_keys;
uint8_t** shmem_transport_ofi_external_heap_addrs;
#endif

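/* List of MR descriptors by slot: [0] data segment (or the single target MR), [1] heap (NULL when only one target MR is registered), [2] one optional external heap (NULL when unused) */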
struct fid_mr* shmem_transport_ofi_mrfd_list[3];
uint64_t shmem_transport_ofi_max_poll;
long shmem_transport_ofi_put_poll_limit;
long shmem_transport_ofi_get_poll_limit;
Expand Down Expand Up @@ -707,6 +709,8 @@ int ofi_mr_reg_bind(uint64_t flags)
}
#endif /* ENABLE_MR_RMA_EVENT */
#endif /* ENABLE_TARGET_CNTR */
shmem_transport_ofi_mrfd_list[0] = shmem_transport_ofi_target_mrfd;
shmem_transport_ofi_mrfd_list[1] = NULL;

#else
/* Register separate data and heap segments using keys 0 and 1,
Expand Down Expand Up @@ -770,6 +774,10 @@ int ofi_mr_reg_bind(uint64_t flags)
}
#endif /* ENABLE_MR_RMA_EVENT */
#endif /* ENABLE_TARGET_CNTR */

shmem_transport_ofi_mrfd_list[0] = shmem_transport_ofi_target_data_mrfd;
shmem_transport_ofi_mrfd_list[1] = shmem_transport_ofi_target_heap_mrfd;

#endif

return ret;
Expand Down Expand Up @@ -812,8 +820,14 @@ int allocate_recv_cntr_mr(void)
if (shmem_external_heap_pre_initialized) {
ret = ofi_mr_reg_external_heap();
OFI_CHECK_RETURN_STR(ret, "OFI MR registration with HMEM failed");
shmem_transport_ofi_mrfd_list[2] = shmem_transport_ofi_external_heap_mrfd;
} else {
shmem_transport_ofi_mrfd_list[2] = NULL;
}
#else
shmem_transport_ofi_mrfd_list[2] = NULL;
#endif

ret = ofi_mr_reg_bind(flags);
OFI_CHECK_RETURN_STR(ret, "OFI MR registration failed");

Expand Down
97 changes: 76 additions & 21 deletions src/transport_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ extern uint64_t* shmem_transport_ofi_external_heap_keys;
extern uint8_t** shmem_transport_ofi_external_heap_addrs;
#endif

extern struct fid_mr* shmem_transport_ofi_mrfd_list[3];
extern uint64_t shmem_transport_ofi_max_poll;
extern long shmem_transport_ofi_put_poll_limit;
extern long shmem_transport_ofi_get_poll_limit;
Expand Down Expand Up @@ -124,6 +125,51 @@ extern pthread_mutex_t shmem_transport_ofi_progress_lock;
} while (0)


#ifdef ENABLE_MR_SCALABLE
/*
 * Map a local buffer address to its slot in shmem_transport_ofi_mrfd_list
 * (scalable-MR build).
 *
 * With ENABLE_REMOTE_VIRTUAL_ADDRESSING the answer is always slot 0 and the
 * address is not inspected.  Otherwise the address is tested against the
 * half-open ranges [base, base + length) of the data segment (slot 0) and
 * the heap (slot 1).
 *
 * Returns 0 or 1 on a match, or -1 when the address lies in neither range.
 */
static inline
int shmem_transport_ofi_get_mr_desc_index(const void *addr) {
#ifdef ENABLE_REMOTE_VIRTUAL_ADDRESSING
    return 0;
#else
    const uint8_t *probe      = (const uint8_t *) addr;
    const uint8_t *data_begin = (const uint8_t *) shmem_internal_data_base;
    const uint8_t *heap_begin = (const uint8_t *) shmem_internal_heap_base;

    /* Data segment is checked first, matching slot order in the list. */
    if (probe >= data_begin && probe < data_begin + shmem_internal_data_length)
        return 0;

    if (probe >= heap_begin && probe < heap_begin + shmem_internal_heap_length)
        return 1;

    /* Not a registered segment (e.g. a stack or private buffer). */
    return -1;
#endif /* ENABLE_REMOTE_VIRTUAL_ADDRESSING */
}
#else
/*
 * Map a local buffer address to its slot in shmem_transport_ofi_mrfd_list
 * (non-scalable-MR build).
 *
 * The address is tested against half-open ranges [base, base + length):
 *   slot 0 - data segment
 *   slot 1 - heap
 *   slot 2 - external heap (only when built with USE_FI_HMEM and
 *            shmem_external_heap_pre_initialized is set)
 *
 * Returns the matching slot, or -1 when the address lies in none of them.
 */
static inline
int shmem_transport_ofi_get_mr_desc_index(const void *addr) {
    const uint8_t *probe      = (const uint8_t *) addr;
    const uint8_t *data_begin = (const uint8_t *) shmem_internal_data_base;
    const uint8_t *heap_begin = (const uint8_t *) shmem_internal_heap_base;

    if (probe >= data_begin && probe < data_begin + shmem_internal_data_length)
        return 0;

    if (probe >= heap_begin && probe < heap_begin + shmem_internal_heap_length)
        return 1;

#ifdef USE_FI_HMEM
    /* The external heap range is only consulted once it has been set up. */
    if (shmem_external_heap_pre_initialized) {
        const uint8_t *ext_begin = (const uint8_t *) shmem_external_heap_base;

        if (probe >= ext_begin && probe < ext_begin + shmem_external_heap_length)
            return 2;
    }
#endif /* USE_FI_HMEM */

    /* Not a registered segment (e.g. a stack or private buffer). */
    return -1;
}
#endif

#ifdef ENABLE_MR_SCALABLE
static inline
void shmem_transport_ofi_get_mr(const void *addr, int dest_pe,
Expand Down Expand Up @@ -229,6 +275,13 @@ extern fi_addr_t *addr_table;
#define GET_DEST(dest) ((fi_addr_t)(dest))
#endif

/*
 * Local MR descriptor lookup for libfabric calls.
 *
 * GET_MR_DESC(index)      - descriptor stored in slot `index` of
 *                           shmem_transport_ofi_mrfd_list, as a void *.
 * GET_MR_DESC_ADDR(index) - address of that slot, as a void **, for message
 *                           structs whose .desc field is an array of
 *                           descriptors.
 *
 * An index of -1 (no matching region) yields NULL from both macros.  Without
 * USE_FI_HMEM both macros are always NULL and `index` is not evaluated.
 *
 * NOTE: with USE_FI_HMEM, `index` is evaluated twice; pass only
 * side-effect-free expressions.  The argument is parenthesized so that
 * low-precedence expressions (e.g. a conditional) expand correctly.
 */
#ifdef USE_FI_HMEM
#define GET_MR_DESC(index) (((index) == -1) ? NULL : (void *) shmem_transport_ofi_mrfd_list[(index)])
#define GET_MR_DESC_ADDR(index) (((index) == -1) ? NULL : (void **) &shmem_transport_ofi_mrfd_list[(index)])
#else
#define GET_MR_DESC(index) NULL
#define GET_MR_DESC_ADDR(index) NULL
#endif

struct shmem_transport_ofi_frag_t {
shmem_free_list_item_t item;
Expand Down Expand Up @@ -611,7 +664,8 @@ void shmem_transport_ofi_put_large(shmem_transport_ctx_t* ctx, void *target, con

do {
ret = fi_write(ctx->ep,
frag_source, frag_len, NULL,
frag_source, frag_len,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)),
GET_DEST(dst), frag_target,
key, NULL);
} while (try_again(ctx, ret, &polled));
Expand Down Expand Up @@ -652,7 +706,7 @@ void shmem_transport_put_nb(shmem_transport_ctx_t* ctx, void *target, const void
const struct fi_rma_iov rma_iov = { .addr = (uint64_t) addr, .len = len, .key = key };
const struct fi_msg_rma msg = {
.msg_iov = &msg_iov,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rma_iov,
Expand Down Expand Up @@ -700,7 +754,7 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co
};
const struct fi_msg_rma msg = {
.msg_iov = &msg_iov,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rma_iov,
Expand Down Expand Up @@ -730,7 +784,7 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co
};
struct fi_msg_rma msg = {
.msg_iov = &msg_iov,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rma_iov,
Expand Down Expand Up @@ -796,7 +850,7 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co
};
const struct fi_msg_atomic msg_signal = {
.msg_iov = &msg_iov_signal,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index((void *) &signal)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rma_iov_signal,
Expand Down Expand Up @@ -860,7 +914,7 @@ void shmem_transport_get(shmem_transport_ctx_t* ctx, void *target, const void *s
ret = fi_read(ctx->ep,
target,
len,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)),
GET_DEST(dst),
(uint64_t) addr,
key,
Expand All @@ -881,7 +935,8 @@ void shmem_transport_get(shmem_transport_ctx_t* ctx, void *target, const void *s

do {
ret = fi_read(ctx->ep,
frag_target, frag_len, NULL,
frag_target, frag_len,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)),
GET_DEST(dst), frag_source,
key, NULL);
} while (try_again(ctx, ret, &polled));
Expand Down Expand Up @@ -964,7 +1019,7 @@ void shmem_transport_cswap_nbi(shmem_transport_ctx_t* ctx, void *target, const
const struct fi_rma_ioc rmav= { .addr = (uint64_t) addr, .count = 1, .key = key };
const struct fi_msg_atomic msg = {
.msg_iov = &sourcev,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rmav,
Expand All @@ -985,7 +1040,7 @@ void shmem_transport_cswap_nbi(shmem_transport_ctx_t* ctx, void *target, const
NULL,
1,
&resultv,
NULL,
GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)),
1,
FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as
it is implied for fetch atomicmsgs */
Expand Down Expand Up @@ -1023,11 +1078,11 @@ void shmem_transport_cswap(shmem_transport_ctx_t* ctx, void *target, const void
ret = fi_compare_atomic(ctx->ep,
source,
1,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)),
operand,
NULL,
dest,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)),
GET_DEST(dst),
(uint64_t) addr,
key,
Expand Down Expand Up @@ -1062,11 +1117,11 @@ void shmem_transport_mswap(shmem_transport_ctx_t* ctx, void *target, const void
ret = fi_compare_atomic(ctx->ep,
source,
1,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)),
mask,
NULL,
dest,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)),
GET_DEST(dst),
(uint64_t) addr,
key,
Expand Down Expand Up @@ -1170,7 +1225,7 @@ void shmem_transport_atomicv(shmem_transport_ctx_t* ctx, void *target, const voi
const struct fi_rma_ioc rma_iov = { .addr = (uint64_t) addr, .count = len, .key = key };
const struct fi_msg_atomic msg = {
.msg_iov = &msg_iov,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rma_iov,
Expand Down Expand Up @@ -1198,7 +1253,7 @@ void shmem_transport_atomicv(shmem_transport_ctx_t* ctx, void *target, const voi
(void *)((char *)source +
(sent*SHMEM_Dtsize[dt])),
chunksize,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)),
GET_DEST(dst),
((uint64_t) addr +
(sent*SHMEM_Dtsize[dt])),
Expand Down Expand Up @@ -1238,7 +1293,7 @@ void shmem_transport_fetch_atomic_nbi(shmem_transport_ctx_t* ctx, void *target,
const struct fi_rma_ioc rmav= { .addr = (uint64_t) addr, .count = 1, .key = key };
const struct fi_msg_atomic msg = {
.msg_iov = &sourcev,
.desc = NULL,
.desc = GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(source)),
.iov_count = 1,
.addr = GET_DEST(dst),
.rma_iov = &rmav,
Expand All @@ -1256,7 +1311,7 @@ void shmem_transport_fetch_atomic_nbi(shmem_transport_ctx_t* ctx, void *target,
ret = fi_fetch_atomicmsg(ctx->ep,
&msg,
&resultv,
NULL,
GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)),
1,
FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as it's
implied for fetch atomicmsgs */
Expand Down Expand Up @@ -1295,9 +1350,9 @@ void shmem_transport_fetch_atomic(shmem_transport_ctx_t* ctx, void *target,
ret = fi_fetch_atomic(ctx->ep,
source,
1,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)),
dest,
NULL,
GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)),
GET_DEST(dst),
(uint64_t) addr,
key,
Expand Down Expand Up @@ -1353,8 +1408,8 @@ void shmem_transport_atomic_fetch(shmem_transport_ctx_t* ctx, void *target,
shmem_transport_fetch_atomic_nbi(ctx, (void *) source, (const void *) &dummy,
target, len, pe, FI_SUM, datatype);
#else
shmem_transport_fetch_atomic(ctx, (void *) source, (const void *) NULL,
target, len, pe, FI_ATOMIC_READ, datatype);
shmem_transport_fetch_atomic_nbi(ctx, (void *) source, (const void *) NULL,
target, len, pe, FI_ATOMIC_READ, datatype);
davidozog marked this conversation as resolved.
Show resolved Hide resolved
#endif
}

Expand Down
Loading