-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
96fc77e
commit ef93e14
Showing
5 changed files
with
150 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from mpi4py import MPI | ||
from _pytest._code.code import ( | ||
ExceptionChainRepr, | ||
ReprTraceback, | ||
ReprEntryNative, | ||
ReprFileLocation, | ||
) | ||
|
||
|
||
def gather_report(mpi_reports, n_sub_rank): | ||
assert len(mpi_reports) == n_sub_rank | ||
|
||
report_init = mpi_reports[0] | ||
goutcome = report_init.outcome | ||
glongrepr = report_init.longrepr | ||
|
||
collect_longrepr = [] | ||
# > We need to rebuild a TestReport object, location can be false # TODO ? | ||
for i_sub_rank, test_report in enumerate(mpi_reports): | ||
if test_report.outcome == "failed": | ||
goutcome = "failed" | ||
|
||
if test_report.longrepr: | ||
msg = f"On rank {i_sub_rank} of {n_sub_rank}" | ||
full_msg = f"\n-------------------------------- {msg} --------------------------------" | ||
fake_trace_back = ReprTraceback([ReprEntryNative(full_msg)], None, None) | ||
collect_longrepr.append( | ||
(fake_trace_back, ReprFileLocation(*report_init.location), None) | ||
) | ||
collect_longrepr.append( | ||
(test_report.longrepr, ReprFileLocation(*report_init.location), None) | ||
) | ||
|
||
if len(collect_longrepr) > 0: | ||
glongrepr = ExceptionChainRepr(collect_longrepr) | ||
|
||
return goutcome, glongrepr | ||
|
||
|
||
def gather_report_on_local_rank_0(report): | ||
""" | ||
Gather reports from all procs participating in the test on rank 0 of the sub_comm | ||
""" | ||
sub_comm = report.sub_comm | ||
del report.sub_comm # No need to keep it in the report | ||
# Furthermore we need to serialize the report | ||
# and mpi4py does not know how to serialize report.sub_comm | ||
i_sub_rank = sub_comm.Get_rank() | ||
n_sub_rank = sub_comm.Get_size() | ||
|
||
if ( | ||
report.outcome != "skipped" | ||
): # Skipped test are only known by proc 0 -> no merge required | ||
# Warning: PyTest reports can actually be quite big | ||
request = sub_comm.isend(report, dest=0, tag=i_sub_rank) | ||
|
||
if i_sub_rank == 0: | ||
mpi_reports = n_sub_rank * [None] | ||
for _ in range(n_sub_rank): | ||
status = MPI.Status() | ||
|
||
mpi_report = sub_comm.recv( | ||
source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status | ||
) | ||
mpi_reports[status.Get_source()] = mpi_report | ||
|
||
assert ( | ||
None not in mpi_reports | ||
) # should have received from all ranks of `sub_comm` | ||
goutcome, glongrepr = gather_report(mpi_reports, n_sub_rank) | ||
|
||
report.outcome = goutcome | ||
report.longrepr = glongrepr | ||
|
||
request.wait() | ||
|
||
sub_comm.barrier() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pytest | ||
|
||
from mpi4py import MPI | ||
|
||
import socket | ||
import pickle | ||
from . import socket_utils | ||
from .utils import get_n_proc_for_test, run_item_test | ||
from .gather_report import gather_report_on_local_rank_0 | ||
|
||
class ProcessWorker: | ||
def __init__(self, scheduler_ip_address, scheduler_port, test_idx, detach): | ||
self.scheduler_ip_address = scheduler_ip_address | ||
self.scheduler_port = scheduler_port | ||
self.test_idx = test_idx | ||
self.detach = detach | ||
|
||
@pytest.hookimpl(tryfirst=True) | ||
def pytest_runtestloop(self, session) -> bool: | ||
comm = MPI.COMM_WORLD | ||
assert len(session.items) == 1, f'INTERNAL FATAL ERROR in pytest_parallel with slurm scheduling: should only have one test per worker, but got {len(session.items)}' | ||
item = session.items[0] | ||
test_comm_size = get_n_proc_for_test(item) | ||
|
||
item.sub_comm = comm | ||
item.test_info = {'test_idx': self.test_idx, 'fatal_error': None} | ||
|
||
|
||
if comm.Get_size() != test_comm_size: # fatal error, SLURM and MPI do not interoperate correctly | ||
error_info = f'FATAL ERROR in pytest_parallel with slurm scheduling: test `{item.nodeid}`' \ | ||
f' uses a `comm` of size {test_comm_size} but was launched with size {comm.Get_size()}.\n' \ | ||
f' This generally indicates that `srun` does not interoperate correctly with MPI.' | ||
|
||
item.test_info['fatal_error'] = error_info | ||
else: # normal case: the test can be run | ||
nextitem = None | ||
run_item_test(item, nextitem, session) | ||
|
||
if not self.detach and comm.Get_rank() == 0: # not detached: proc 0 is expected to send results to scheduling process | ||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | ||
s.connect((self.scheduler_ip_address, self.scheduler_port)) | ||
socket_utils.send(s, pickle.dumps(item.test_info)) | ||
|
||
if item.test_info['fatal_error'] is not None: | ||
assert 0, f'{item.test_info["fatal_error"]}' | ||
|
||
return True | ||
|
||
@pytest.hookimpl(hookwrapper=True) | ||
def pytest_runtest_makereport(self, item): | ||
""" | ||
We need to hook to pass the test sub-comm to `pytest_runtest_logreport`, | ||
and for that we add the sub-comm to the only argument of `pytest_runtest_logreport`, that is, `report` | ||
We also need to pass `item.test_info` so that we can update it | ||
""" | ||
result = yield | ||
report = result.get_result() | ||
report.sub_comm = item.sub_comm | ||
report.test_info = item.test_info | ||
|
||
@pytest.hookimpl(tryfirst=True) | ||
def pytest_runtest_logreport(self, report): | ||
assert report.when in ("setup", "call", "teardown") # only known tags | ||
gather_report_on_local_rank_0(report) | ||
report.test_info.update({report.when: {'outcome' : report.outcome, | ||
'longrepr': report.longrepr, | ||
'duration': report.duration, }}) | ||
|
||
|