Skip to content

Commit

Permalink
Making things more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
BerengerBerthoul committed Mar 26, 2024
1 parent 6a43cdb commit 7ecaa23
Show file tree
Hide file tree
Showing 24 changed files with 66 additions and 34 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
68 changes: 44 additions & 24 deletions pytest_parallel/mpi_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,25 +599,40 @@ def pytest_runtest_logreport(self, report):
from pathlib import Path

class ProcessWorker:
def __init__(self, scheduler_ip_address, scheduler_port, test_idx):
def __init__(self, scheduler_ip_address, scheduler_port, test_idx, detach):
self.scheduler_ip_address = scheduler_ip_address
self.scheduler_port = scheduler_port
self.test_idx = test_idx
self.detach = detach

@pytest.hookimpl(tryfirst=True)
def pytest_runtestloop(self, session) -> bool:
comm = MPI.COMM_WORLD
item = session.items[self.test_idx]
test_comm_size = get_n_proc_for_test(item)

item.sub_comm = comm
item.test_info = {'test_idx': self.test_idx}
nextitem = None
run_item_test(item, nextitem, session)
item.test_info = {'test_idx': self.test_idx, 'fatal_error': None}


if comm.Get_size() != test_comm_size: # fatal error, SLURM and MPI do not interoperate correctly
error_info = f'FATAL ERROR in pytest_parallel with slurm scheduling: test `{item.nodeid}`' \
f' uses a `comm` of size {test_comm_size} but was launched with size {comm.Get_size()}.\n' \
f' This generally indicates that `srun` does not interoperate correctly with MPI.'

item.test_info['fatal_error'] = error_info
else: # normal case: the test can be run
nextitem = None
run_item_test(item, nextitem, session)

if comm.Get_rank() == 0:
if not self.detach and comm.Get_rank() == 0: # not detached: proc 0 is expected to send results to scheduling process
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect((self.scheduler_ip_address, self.scheduler_port))
socket_utils.send(s, pickle.dumps(item.test_info))

if item.test_info['fatal_error'] is not None:
assert 0, f'{item.test_info["fatal_error"]}'

return True

@pytest.hookimpl(hookwrapper=True)
Expand Down Expand Up @@ -649,7 +664,7 @@ def replace_sub_strings(s, subs, replacement):
def remove_exotic_chars(s):
return replace_sub_strings(str(s), ['[',']','/', ':'], '_')

def submit_items(items_to_run, socket, main_invoke_params, slurm_options):
def submit_items(items_to_run, socket, main_invoke_params, slurm_additional_cmds, slurm_options):
# Find IP our address
r = subprocess.run(['hostname','-I'], stdout=subprocess.PIPE)
assert r.returncode==0, f'SLURM scheduler: error getting IP address of {socket.gethostname()} with `hostname -I`'
Expand All @@ -671,39 +686,39 @@ def submit_items(items_to_run, socket, main_invoke_params, slurm_options):
items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True)

# launch srun for each item
cmds = f'WORKER_FLAGS="--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port}"\n'
cmds += f'INVOKE_FLAGS="{main_invoke_params}"\n\n'
worker_flags=f"--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port}"
cmds = ''
for item in items:
test_idx = item.original_index
test_out_file_base = f'pytest_slurm/{remove_exotic_chars(item.nodeid)}'
cmd = f'srun --exclusive --ntasks={item.n_proc} -l'
cmd += f' python3 -u -m pytest $WORKER_FLAGS $INVOKE_FLAGS --_test_idx={test_idx}'
cmd += f' > {test_idx}.out 2> {test_idx}.err'
#test_out_file_base = f'pytest_parallel_slurm/{remove_exotic_chars(item.nodeid)}'
#cmd += f' > {test_out_file_base}.out 2> {test_out_file_base}.err'
cmd += f' python3 -u -m pytest {worker_flags} {main_invoke_params} --_test_idx={test_idx}'
cmd += f' > {test_out_file_base} 2>&1'
cmd += ' &\n' # launch everything in parallel
cmds += cmd
cmds += 'wait\n'

pytest_slurm = f'''#!/bin/bash
#SBATCH --job-name=pytest_parallel
#SBATCH --output=pytest_parallel_slurm/slurm.%j.out
#SBATCH --error=pytest_parallel_slurm/slurm.%j.err
#SBATCH --output=pytest_slurm/slurm.%j.out
#SBATCH --error=pytest_slurm/slurm.%j.err
{slurm_header}
{slurm_additional_cmds}
{cmds}
'''
Path('pytest_parallel_slurm').mkdir(exist_ok=True)
with open('pytest_parallel_slurm/job.sh','w') as f:
Path('pytest_slurm').mkdir(exist_ok=True)
with open('pytest_slurm/job.sh','w') as f:
f.write(pytest_slurm)

## submit SLURM job
sbatch_cmd = 'sbatch --parsable pytest_parallel_slurm/job.sh'
sbatch_cmd = 'sbatch --parsable pytest_slurm/job.sh'
p = subprocess.Popen([sbatch_cmd], shell=True, stdout=subprocess.PIPE)
print('Submitting tests to SLURM...')
returncode = p.wait()
slurm_job_id = int(p.stdout.read())
assert returncode==0, f'Error when submitting to SLURM with `{sbatch_cmd}`'
slurm_job_id = int(p.stdout.read())
print(f'SLURM job {slurm_job_id} has been submitted')
return slurm_job_id

Expand All @@ -714,6 +729,8 @@ def receive_items(items, session, socket, n_item_to_recv):
msg = socket_utils.recv(conn)
test_info = pickle.loads(msg) # the worker is supposed to have send a dict with the correct structured information
test_idx = test_info['test_idx']
if test_info['fatal_error'] is not None:
assert 0, f'{test_info["fatal_error"]}'
item = items[test_idx]
item.sub_comm = MPI.COMM_NULL
item.info = test_info
Expand All @@ -724,10 +741,12 @@ def receive_items(items, session, socket, n_item_to_recv):
n_item_to_recv -= 1

class ProcessScheduler:
def __init__(self, main_invoke_params, slurm_ntasks, slurm_options):
self.main_invoke_params = main_invoke_params
self.slurm_ntasks = slurm_ntasks
self.slurm_options = slurm_options
def __init__(self, main_invoke_params, slurm_ntasks, slurm_options, slurm_additional_cmds, detach):
self.main_invoke_params = main_invoke_params
self.slurm_ntasks = slurm_ntasks
self.slurm_options = slurm_options
self.slurm_additional_cmds = slurm_additional_cmds
self.detach = detach

self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TODO close at the end
self.slurm_job_id = None
Expand Down Expand Up @@ -773,8 +792,9 @@ def pytest_runtestloop(self, session) -> bool:
# schedule tests to run
n_item_to_receive = len(items_to_run)
if n_item_to_receive > 0:
self.slurm_job_id = submit_items(items_to_run, self.socket, self.main_invoke_params, self.slurm_options)
receive_items(session.items, session, self.socket, n_item_to_receive)
self.slurm_job_id = submit_items(items_to_run, self.socket, self.main_invoke_params, self.slurm_additional_cmds, self.slurm_options)
if not self.detach: # The job steps are supposed to send their reports
receive_items(session.items, session, self.socket, n_item_to_receive)

return True

Expand Down
32 changes: 22 additions & 10 deletions pytest_parallel/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ def pytest_addoption(parser):
'--scheduler',
dest='scheduler',
type='choice',
choices=['sequential', 'static', 'dynamic', 'shell', 'slurm'],
choices=['sequential', 'static', 'dynamic', 'slurm'],
default='sequential',
help='Method used by pytest_parallel to schedule tests',
)

parser.addoption('--slurm_options', dest='slurm_options', type=str, help='list of SLURM options e.g. "--time=00:30:00 --qos=my_queue --n_tasks=4"')
parser.addoption('--slurm-options', dest='slurm_options', type=str, help='list of SLURM options e.g. "--time=00:30:00 --qos=my_queue --n_tasks=4"')
parser.addoption('--slurm-additional-cmds', dest='slurm_additional_cmds', type=str, help='list of commands to pass to SLURM job e.g. "source my_env.sh"')
parser.addoption('--detach', dest='detach', action='store_true', help='Detach SLURM jobs: do not send reports to the scheduling process (useful to launch slurm job.sh separately)')

# Private to SLURM scheduler
parser.addoption('--_worker', dest='_worker', action='store_true', help='Internal pytest_parallel option')
Expand All @@ -45,7 +47,7 @@ def parse_slurm_options(opt_str):
assert ntasks_val[0]==' ' or ntasks_val[0]=='=', 'pytest_parallel SLURM scheduler: parsing error for `--ntasks`'
return ntasks, opts

assert 0, 'pytest_parallel SLURM scheduler: you need specify --ntasks in slurm_options'
assert 0, 'pytest_parallel SLURM scheduler: you need specify `--ntasks` in `--slurm-options`'

# --------------------------------------------------------------------------
@pytest.hookimpl(trylast=True)
Expand All @@ -55,11 +57,13 @@ def pytest_configure(config):
# Get options and check dependent/incompatible options
scheduler = config.getoption('scheduler')
slurm_options = config.getoption('slurm_options')
slurm_additional_cmds = config.getoption('slurm_additional_cmds')
slurm_worker = config.getoption('_worker')
## !slurm => !slurm_worker
detach = config.getoption('detach')
if scheduler != 'slurm':
assert not slurm_worker
assert not slurm_options
assert not slurm_additional_cmds

if scheduler == 'sequential':
plugin = SequentialScheduler(global_comm)
Expand All @@ -73,18 +77,26 @@ def pytest_configure(config):
scheduler_ip_address = config.getoption('_scheduler_ip_address')
scheduler_port = config.getoption('_scheduler_port')
test_idx = config.getoption('_test_idx')
plugin = ProcessWorker(scheduler_ip_address, scheduler_port, test_idx)
plugin = ProcessWorker(scheduler_ip_address, scheduler_port, test_idx, detach)
else: # scheduler
assert global_comm.Get_size() == 1, 'pytest_parallel usage error: \
when scheduling with SLURM, \
do not launch the scheduling itself in parallel \
(do NOT use `mpirun -np n pytest...`)'

# List of all invoke options except slurm options
main_invoke_params = ' '.join(config.invocation_params.args)
main_invoke_params = ''.join( main_invoke_params.split(f'--slurm_options={slurm_options}') )
## reconstruct complete invoke string
quoted_invoke_params = []
for arg in config.invocation_params.args:
if ' ' in arg and not '--slurm-options' in arg:
quoted_invoke_params.append("'"+arg+"'")
else:
quoted_invoke_params.append(arg)
main_invoke_params = ' '.join(quoted_invoke_params)
## pull `--slurm-options` appart for special treatement
main_invoke_params = ''.join( main_invoke_params.split(f'--slurm-options={slurm_options}') )
slurm_ntasks, slurm_options = parse_slurm_options(slurm_options)
plugin = ProcessScheduler(main_invoke_params, slurm_ntasks, slurm_options)
plugin = ProcessScheduler(main_invoke_params, slurm_ntasks, slurm_options, slurm_additional_cmds, detach)
else:
assert 0

Expand All @@ -100,9 +112,9 @@ def pytest_configure(config):
@pytest.fixture
def comm(request):
'''
Only return a previous MPI Communicator (build at prepare step )
Returns the MPI Communicator created by pytest_parallel
'''
return request.node.sub_comm # TODO clean
return request.node.sub_comm


# --------------------------------------------------------------------------
Expand Down

0 comments on commit 7ecaa23

Please sign in to comment.