Skip to content

Commit

Permalink
sequential scheduler: quickfix of mpi_tmpdir fixture to run with MPI_…
Browse files Browse the repository at this point in the history
…COMM_NULL, but should change the scheduler so that fixtures are not even called in non-participating tests
  • Loading branch information
BerengerBerthoul authored and sonics committed Aug 29, 2024
1 parent d89cbaa commit fd30676
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 76 deletions.
84 changes: 16 additions & 68 deletions pytest_parallel/mpi_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_sub_comm_of_size(global_comm, n_proc, mpi_comm_creation_function):
color = MPI.UNDEFINED
return global_comm.Split(color, key=i_rank)
else:
assert 0, 'unknown MPI communicator creation function'
assert 0, 'Unknown MPI communicator creation function. Available: `MPI_Comm_create`, `MPI_Comm_split`'

def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
i_rank = global_comm.Get_rank()
Expand All @@ -45,36 +45,6 @@ def create_sub_comms_for_each_size(global_comm, mpi_comm_creation_function):
return sub_comms



def filter_and_add_sub_comm__old(items, global_comm):
i_rank = global_comm.Get_rank()
n_workers = global_comm.Get_size()

filtered_items = []
for item in items:
n_proc_test = get_n_proc_for_test(item)

if n_proc_test > n_workers: # not enough procs: will be skipped
if global_comm.Get_rank() == 0:
item.sub_comm = MPI.COMM_SELF
mark_skip(item)
filtered_items += [item]
else:
item.sub_comm = MPI.COMM_NULL # TODO this should not be needed
else:
if i_rank < n_proc_test:
color = 1
else:
color = MPI.UNDEFINED

sub_comm = global_comm.Split(color)

if sub_comm != MPI.COMM_NULL:
item.sub_comm = sub_comm
filtered_items += [item]

return filtered_items

def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_function):
i_rank = global_comm.Get_rank()
n_rank = global_comm.Get_size()
Expand All @@ -91,73 +61,51 @@ def add_sub_comm(items, global_comm, test_comm_creation, mpi_comm_creation_funct
if n_proc_test > n_rank: # not enough procs: mark as to be skipped
mark_skip(item)
item.sub_comm = MPI.COMM_NULL
#if n_proc_test > n_workers: # not enough procs: will be skipped
# if global_comm.Get_rank() == 0:
# item.sub_comm = MPI.COMM_SELF
# mark_skip(item)
# else:
# item.sub_comm = MPI.COMM_NULL # TODO this should not be needed
else:
if test_comm_creation == 'by_rank':
item.sub_comm = sub_comms[n_proc_test-1]
elif test_comm_creation == 'by_test':
item.sub_comm = create_sub_comm_of_size(global_comm, n_proc_test, mpi_comm_creation_function)
else:
assert 0, 'unknown test MPI communicator creation strategy'

assert 0, 'Unknown test MPI communicator creation strategy. Available: `by_rank`, `by_test`'

class SequentialScheduler:
def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=False):
def __init__(self, global_comm, test_comm_creation='by_rank', mpi_comm_creation_function='MPI_Comm_create', barrier_at_test_start=True, barrier_at_test_end=True):
self.global_comm = global_comm.Dup() # ensure that all communications within the framework are private to the framework
self.test_comm_creation = test_comm_creation
self.mpi_comm_creation_function = mpi_comm_creation_function

self.barrier_at_test_start = barrier_at_test_start
self.barrier_at_test_end = barrier_at_test_end

@pytest.hookimpl(trylast=True)
def pytest_collection_modifyitems(self, config, items):
add_sub_comm(items, self.global_comm, self.test_comm_creation, self.mpi_comm_creation_function)

#@pytest.hookimpl(tryfirst=True)
#def pytest_runtest_protocol(self, item, nextitem):
# #i_rank = self.global_comm.Get_rank()
# #n_proc_test = get_n_proc_for_test(item)
# #if i_rank < n_proc_test:
# # sub_comm = sub_comm_from_ranks(self.global_comm, range(0,n_proc_test))
# #else:
# # sub_comm = MPI.COMM_NULL
# #item.sub_comm = sub_comm
# n_proc_test = get_n_proc_for_test(item)
# #if n_proc_test <= self.global_comm.Get_size():
# #if n_proc_test < self.global_comm.rank:
# item.sub_comm = self.sub_comms[n_proc_test-1]
# #else:
# # item.sub_comm = MPI.COMM_NULL

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(self, item, nextitem):
if self.barrier_at_test_start:
self.global_comm.barrier()
#print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
_ = yield
#print(f'pytest_runtest_protocol end {MPI.COMM_WORLD.rank=}')
if self.barrier_at_test_end:
self.global_comm.barrier()

#@pytest.hookimpl(tryfirst=True)
#def pytest_runtest_protocol(self, item, nextitem):
# pass
# #return True
# #if item.sub_comm != MPI.COMM_NULL:
# # _ = yield
# #else:
# # return True
# if self.barrier_at_test_start:
# self.global_comm.barrier()
# print(f'pytest_runtest_protocol beg {MPI.COMM_WORLD.rank=}')
# if item.sub_comm == MPI.COMM_NULL:
# return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
if pyfuncitem.sub_comm != MPI.COMM_NULL:
_ = yield
else: # the rank does not participate in the test, so do nothing
return True
#print(f'pytest_pyfunc_call {MPI.COMM_WORLD.rank=}')
# This is where the test is normally run.
# Only run the test for the ranks that do participate in the test
if pyfuncitem.sub_comm == MPI.COMM_NULL:
return True # for this hook, `firstresult=True` so returning a non-None will stop other hooks to run

@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtestloop(self, session) -> bool:
Expand Down
21 changes: 13 additions & 8 deletions pytest_parallel/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest
from pathlib import Path
import argparse

#from mpi4py import MPI
#from logger import consoleLogger

# --------------------------------------------------------------------------
def pytest_addoption(parser):
Expand Down Expand Up @@ -164,15 +165,19 @@ def __init__(self, comm):
self.tmp_path = None

def __enter__(self):
rank = self.comm.Get_rank()
self.tmp_dir = tempfile.TemporaryDirectory() if rank == 0 else None
self.tmp_path = Path(self.tmp_dir.name) if rank == 0 else None
return self.comm.bcast(self.tmp_path, root=0)
from mpi4py import MPI
if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either
rank = self.comm.Get_rank()
self.tmp_dir = tempfile.TemporaryDirectory() if rank == 0 else None
self.tmp_path = Path(self.tmp_dir.name) if rank == 0 else None
return self.comm.bcast(self.tmp_path, root=0)

def __exit__(self, type, value, traceback):
self.comm.barrier()
if self.comm.Get_rank() == 0:
self.tmp_dir.cleanup()
from mpi4py import MPI
if self.comm != MPI.COMM_NULL: # TODO DEL once non-participating rank do not participate in fixtures either
self.comm.barrier()
if self.comm.Get_rank() == 0:
self.tmp_dir.cleanup()


@pytest.fixture
Expand Down

0 comments on commit fd30676

Please sign in to comment.