Minor file reorganization
BerengerBerthoul committed Apr 22, 2024
1 parent 96fc77e commit ef93e14
Showing 5 changed files with 150 additions and 143 deletions.
77 changes: 77 additions & 0 deletions pytest_parallel/gather_report.py
@@ -0,0 +1,77 @@
from mpi4py import MPI
from _pytest._code.code import (
    ExceptionChainRepr,
    ReprTraceback,
    ReprEntryNative,
    ReprFileLocation,
)


def gather_report(mpi_reports, n_sub_rank):
    assert len(mpi_reports) == n_sub_rank

    report_init = mpi_reports[0]
    goutcome = report_init.outcome
    glongrepr = report_init.longrepr

    collect_longrepr = []
    # > We need to rebuild a TestReport object; the reported location may be wrong # TODO ?
    for i_sub_rank, test_report in enumerate(mpi_reports):
        if test_report.outcome == "failed":
            goutcome = "failed"

        if test_report.longrepr:
            msg = f"On rank {i_sub_rank} of {n_sub_rank}"
            full_msg = f"\n-------------------------------- {msg} --------------------------------"
            fake_trace_back = ReprTraceback([ReprEntryNative(full_msg)], None, None)
            collect_longrepr.append(
                (fake_trace_back, ReprFileLocation(*report_init.location), None)
            )
            collect_longrepr.append(
                (test_report.longrepr, ReprFileLocation(*report_init.location), None)
            )

    if len(collect_longrepr) > 0:
        glongrepr = ExceptionChainRepr(collect_longrepr)

    return goutcome, glongrepr


def gather_report_on_local_rank_0(report):
    """
    Gather reports from all procs participating in the test on rank 0 of the sub_comm
    """
    sub_comm = report.sub_comm
    del report.sub_comm  # No need to keep it in the report
    # Furthermore, we need to serialize the report,
    # and mpi4py does not know how to serialize report.sub_comm
    i_sub_rank = sub_comm.Get_rank()
    n_sub_rank = sub_comm.Get_size()

    if (
        report.outcome != "skipped"
    ):  # Skipped tests are only known by proc 0 -> no merge required
        # Warning: PyTest reports can actually be quite big
        request = sub_comm.isend(report, dest=0, tag=i_sub_rank)

        if i_sub_rank == 0:
            mpi_reports = n_sub_rank * [None]
            for _ in range(n_sub_rank):
                status = MPI.Status()

                mpi_report = sub_comm.recv(
                    source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status
                )
                mpi_reports[status.Get_source()] = mpi_report

            assert (
                None not in mpi_reports
            )  # should have received from all ranks of `sub_comm`
            goutcome, glongrepr = gather_report(mpi_reports, n_sub_rank)

            report.outcome = goutcome
            report.longrepr = glongrepr

        request.wait()

    sub_comm.barrier()
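
Below is a minimal standalone sketch (not part of the commit) of the point-to-point gather pattern that `gather_report_on_local_rank_0` relies on: every rank posts a non-blocking `isend` to rank 0, and rank 0 receives from `MPI.ANY_SOURCE`, using the `MPI.Status` object to slot each message back in sender order. The payloads here are plain dicts instead of pytest `TestReport` objects, and the file name is made up; run with e.g. `mpirun -n 4 python gather_sketch.py`.

# gather_sketch.py (hypothetical): same isend/recv-by-source pattern, outside pytest
from mpi4py import MPI

comm = MPI.COMM_WORLD
i_rank = comm.Get_rank()
n_rank = comm.Get_size()

payload = {'rank': i_rank, 'outcome': 'passed'}    # stand-in for a TestReport
request = comm.isend(payload, dest=0, tag=i_rank)  # every rank sends to rank 0

if i_rank == 0:
    gathered = n_rank * [None]
    for _ in range(n_rank):
        status = MPI.Status()
        msg = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
        gathered[status.Get_source()] = msg        # order results by sender rank
    assert None not in gathered                    # one message per rank

request.wait()                                     # complete the local send
comm.barrier()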
142 changes: 1 addition & 141 deletions pytest_parallel/mpi_reporter.py
@@ -1,16 +1,11 @@
import numpy as np
import pytest
from _pytest._code.code import (
    ExceptionChainRepr,
    ReprTraceback,
    ReprEntryNative,
    ReprFileLocation,
)
from mpi4py import MPI

from .algo import partition, lower_bound
from .utils import get_n_proc_for_test, add_n_procs, run_item_test, mark_original_index
from .utils_mpi import number_of_working_processes, is_dyn_master_process
from .gather_report import gather_report_on_local_rank_0


def mark_skip(item):
@@ -22,76 +17,6 @@ def mark_skip(item):
    item.marker_mpi_skip = True


def gather_report(mpi_reports, n_sub_rank):
    assert len(mpi_reports) == n_sub_rank

    report_init = mpi_reports[0]
    goutcome = report_init.outcome
    glongrepr = report_init.longrepr

    collect_longrepr = []
    # > We need to rebuild a TestReport object, location can be false # TODO ?
    for i_sub_rank, test_report in enumerate(mpi_reports):
        if test_report.outcome == "failed":
            goutcome = "failed"

        if test_report.longrepr:
            msg = f"On rank {i_sub_rank} of {n_sub_rank}"
            full_msg = f"\n-------------------------------- {msg} --------------------------------"
            fake_trace_back = ReprTraceback([ReprEntryNative(full_msg)], None, None)
            collect_longrepr.append(
                (fake_trace_back, ReprFileLocation(*report_init.location), None)
            )
            collect_longrepr.append(
                (test_report.longrepr, ReprFileLocation(*report_init.location), None)
            )

    if len(collect_longrepr) > 0:
        glongrepr = ExceptionChainRepr(collect_longrepr)

    return goutcome, glongrepr


def gather_report_on_local_rank_0(report):
    """
    Gather reports from all procs participating in the test on rank 0 of the sub_comm
    """
    sub_comm = report.sub_comm
    del report.sub_comm  # No need to keep it in the report
    # Furthermore we need to serialize the report
    # and mpi4py does not know how to serialize report.sub_comm
    i_sub_rank = sub_comm.Get_rank()
    n_sub_rank = sub_comm.Get_size()

    if (
        report.outcome != "skipped"
    ):  # Skipped test are only known by proc 0 -> no merge required
        # Warning: PyTest reports can actually be quite big
        request = sub_comm.isend(report, dest=0, tag=i_sub_rank)

        if i_sub_rank == 0:
            mpi_reports = n_sub_rank * [None]
            for _ in range(n_sub_rank):
                status = MPI.Status()

                mpi_report = sub_comm.recv(
                    source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status
                )
                mpi_reports[status.Get_source()] = mpi_report

            assert (
                None not in mpi_reports
            )  # should have received from all ranks of `sub_comm`
            goutcome, glongrepr = gather_report(mpi_reports, n_sub_rank)

            report.outcome = goutcome
            report.longrepr = glongrepr

        request.wait()

    sub_comm.barrier()


def filter_and_add_sub_comm(items, global_comm):
    i_rank = global_comm.Get_rank()
    n_workers = global_comm.Get_size()
@@ -584,68 +509,3 @@ def pytest_runtest_logreport(self, report):
report.outcome = mpi_report.outcome
report.longrepr = mpi_report.longrepr
report.duration = mpi_report.duration


import socket
import pickle
from . import socket_utils

class ProcessWorker:
    def __init__(self, scheduler_ip_address, scheduler_port, test_idx, detach):
        self.scheduler_ip_address = scheduler_ip_address
        self.scheduler_port = scheduler_port
        self.test_idx = test_idx
        self.detach = detach

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtestloop(self, session) -> bool:
        comm = MPI.COMM_WORLD
        assert len(session.items) == 1, f'INTERNAL FATAL ERROR in pytest_parallel with slurm scheduling: should only have one test per worker'
        item = session.items[0]
        test_comm_size = get_n_proc_for_test(item)

        item.sub_comm = comm
        item.test_info = {'test_idx': self.test_idx, 'fatal_error': None}


        if comm.Get_size() != test_comm_size:  # fatal error, SLURM and MPI do not interoperate correctly
            error_info = f'FATAL ERROR in pytest_parallel with slurm scheduling: test `{item.nodeid}`' \
                         f' uses a `comm` of size {test_comm_size} but was launched with size {comm.Get_size()}.\n' \
                         f' This generally indicates that `srun` does not interoperate correctly with MPI.'

            item.test_info['fatal_error'] = error_info
        else:  # normal case: the test can be run
            nextitem = None
            run_item_test(item, nextitem, session)

        if not self.detach and comm.Get_rank() == 0:  # not detached: proc 0 is expected to send results to scheduling process
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect((self.scheduler_ip_address, self.scheduler_port))
                socket_utils.send(s, pickle.dumps(item.test_info))

        if item.test_info['fatal_error'] is not None:
            assert 0, f'{item.test_info["fatal_error"]}'

        return True

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_makereport(self, item):
        """
        We need to hook to pass the test sub-comm to `pytest_runtest_logreport`,
        and for that we add the sub-comm to the only argument of `pytest_runtest_logreport`, that is, `report`
        We also need to pass `item.test_info` so that we can update it
        """
        result = yield
        report = result.get_result()
        report.sub_comm = item.sub_comm
        report.test_info = item.test_info

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_logreport(self, report):
        assert report.when in ("setup", "call", "teardown")  # only known tags
        gather_report_on_local_rank_0(report)
        report.test_info.update({report.when: {'outcome' : report.outcome,
                                                'longrepr': report.longrepr,
                                                'duration': report.duration, }})

3 changes: 2 additions & 1 deletion pytest_parallel/plugin.py
@@ -106,7 +106,8 @@ def pytest_configure(config):

    else:
        from mpi4py import MPI
        from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler, ProcessWorker
        from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler
        from .process_worker import ProcessWorker
        from .utils_mpi import spawn_master_process, should_enable_terminal_reporter

        global_comm = MPI.COMM_WORLD
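
As a rough sketch (not part of the commit), the relocated `ProcessWorker` would be instantiated and registered in `pytest_configure` along these lines; the option names used here are assumptions, the real ones are defined elsewhere in plugin.py and are not shown in this hunk.

# Hypothetical wiring sketch; `_scheduler_ip_address`, `_scheduler_port`,
# `_test_idx` and `_detach` are assumed option names, not confirmed by this diff.
plugin = ProcessWorker(
    scheduler_ip_address=config.option._scheduler_ip_address,
    scheduler_port=config.option._scheduler_port,
    test_idx=config.option._test_idx,
    detach=config.option._detach,
)
config.pluginmanager.register(plugin, 'pytest_parallel')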
2 changes: 1 addition & 1 deletion pytest_parallel/process_scheduler.py
@@ -99,7 +99,7 @@ def submit_items(items_to_run, socket, main_invoke_params, slurm_ntasks, slurm_c
    sbatch_cmd = slurm_conf['sub_command'] + ' pytest_slurm/job.sh'

    p = subprocess.Popen([sbatch_cmd], shell=True, stdout=subprocess.PIPE)
    print('Submitting tests to SLURM...')
    print('\nSubmitting tests to SLURM...')
    returncode = p.wait()
    assert returncode==0, f'Error when submitting to SLURM with `{sbatch_cmd}`'

69 changes: 69 additions & 0 deletions pytest_parallel/process_worker.py
@@ -0,0 +1,69 @@
import pytest

from mpi4py import MPI

import socket
import pickle
from . import socket_utils
from .utils import get_n_proc_for_test, run_item_test
from .gather_report import gather_report_on_local_rank_0

class ProcessWorker:
    def __init__(self, scheduler_ip_address, scheduler_port, test_idx, detach):
        self.scheduler_ip_address = scheduler_ip_address
        self.scheduler_port = scheduler_port
        self.test_idx = test_idx
        self.detach = detach

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtestloop(self, session) -> bool:
        comm = MPI.COMM_WORLD
        assert len(session.items) == 1, f'INTERNAL FATAL ERROR in pytest_parallel with slurm scheduling: should only have one test per worker, but got {len(session.items)}'
        item = session.items[0]
        test_comm_size = get_n_proc_for_test(item)

        item.sub_comm = comm
        item.test_info = {'test_idx': self.test_idx, 'fatal_error': None}


        if comm.Get_size() != test_comm_size:  # fatal error: SLURM and MPI do not interoperate correctly
            error_info = f'FATAL ERROR in pytest_parallel with slurm scheduling: test `{item.nodeid}`' \
                         f' uses a `comm` of size {test_comm_size} but was launched with size {comm.Get_size()}.\n' \
                         f' This generally indicates that `srun` does not interoperate correctly with MPI.'

            item.test_info['fatal_error'] = error_info
        else:  # normal case: the test can be run
            nextitem = None
            run_item_test(item, nextitem, session)

        if not self.detach and comm.Get_rank() == 0:  # not detached: proc 0 is expected to send results to the scheduling process
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect((self.scheduler_ip_address, self.scheduler_port))
                socket_utils.send(s, pickle.dumps(item.test_info))

        if item.test_info['fatal_error'] is not None:
            assert 0, f'{item.test_info["fatal_error"]}'

        return True

    @pytest.hookimpl(hookwrapper=True)
    def pytest_runtest_makereport(self, item):
        """
        We hook here to pass the test sub-comm to `pytest_runtest_logreport`:
        since `report` is its only argument, we attach the sub-comm to it.
        We also attach `item.test_info` so that it can be updated there.
        """
        result = yield
        report = result.get_result()
        report.sub_comm = item.sub_comm
        report.test_info = item.test_info

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_logreport(self, report):
        assert report.when in ("setup", "call", "teardown")  # only known tags
        gather_report_on_local_rank_0(report)
        report.test_info.update({report.when: {'outcome' : report.outcome,
                                                'longrepr': report.longrepr,
                                                'duration': report.duration, }})
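
As a rough illustration (not part of the commit), this is the shape of the `test_info` dict that `ProcessWorker` pickles and sends back over the socket, and how a receiving side might summarize it. The function `summarize_test_info` and its `raw_bytes` argument are made-up names for this sketch; the actual receive helper lives in `socket_utils`, which this commit does not show.

import pickle

def summarize_test_info(raw_bytes):
    # raw_bytes: the pickled payload produced by `pickle.dumps(item.test_info)` above
    test_info = pickle.loads(raw_bytes)
    if test_info['fatal_error'] is not None:
        return f"test {test_info['test_idx']}: FATAL - {test_info['fatal_error']}"
    # One entry per pytest phase, filled in by `pytest_runtest_logreport`
    phases = {when: test_info[when]['outcome']
              for when in ('setup', 'call', 'teardown')
              if when in test_info}
    return f"test {test_info['test_idx']}: {phases}"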
