Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
BerengerBerthoul committed Mar 15, 2024
1 parent 9c6cc33 commit 61e37dc
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 68 deletions.
75 changes: 46 additions & 29 deletions pytest_parallel/mpi_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,13 @@ def pytest_runtest_logreport(self, report):
import socket
import pickle
from . import socket_utils
from pathlib import Path

class ProcessWorker:
def __init__(self, n_working_procs, scheduler_ip_address, scheduler_port, test_idx):
self.n_working_procs = n_working_procs
def __init__(self, scheduler_ip_address, scheduler_port, slurm_ntasks, test_idx):
self.scheduler_ip_address = scheduler_ip_address
self.scheduler_port = scheduler_port
self.slurm_ntasks = slurm_ntasks
self.test_idx = test_idx

@pytest.hookimpl(tryfirst=True)
Expand Down Expand Up @@ -640,59 +641,67 @@ def pytest_runtest_logreport(self, report):



LOCALHOST = '127.0.0.1'
def submit_items(items_to_run, socket, n_working_procs, main_invoke_params):
def replace_sub_strings(s, subs, replacement):
    """Return `s` with every occurrence of each string of `subs` replaced by `replacement`.

    The substitutions are applied one after the other, in the iteration
    order of `subs`; each one operates on the result of the previous one.
    """
    for needle in subs:
        s = s.replace(needle, replacement)
    return s

def remove_exotic_chars(s):
    """Return `str(s)` with each `[`, `]`, `/` and `:` character turned into `_`.

    `s` may be any object; it is converted with `str()` first.
    """
    # One translation pass; equivalent to replacing the four characters in
    # turn because the replacement `_` is not itself one of them.
    to_underscore = str.maketrans('[]/:', '____')
    return str(s).translate(to_underscore)

def submit_items(items_to_run, socket, main_invoke_params, slurm_ntasks, slurm_options):
# setup master's socket
SCHEDULER_IP_ADDRESS='10.33.240.8' # spiro07-clu
socket.bind((SCHEDULER_IP_ADDRESS, 0)) # 0: let the OS choose an available port
socket.listen()
port = socket.getsockname()[1]


# sort item by comm size
items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True)

# launch srun for each item
cmds = ''

cmds = f'WORKER_FLAGS=--_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS} --_scheduler_port={port} --_slurm_ntasks={slurm_ntasks}\n\n'
for item in items:
test_idx = item.original_index
test_out_file_base = f'pytest_parallel_slurm/{remove_exotic_chars(item.nodeid)}'
cmd = f'srun --exclusive --ntasks={item.n_proc} -l'
cmd += f' python3 -u -m pytest {main_invoke_params}'
cmd += f' --_worker --_scheduler_ip_address={SCHEDULER_IP_ADDRESS}'
cmd += f' --_scheduler_port={port} --_test_idx={test_idx}'
cmd += f' python3 -u -m pytest $WORKER_FLAGS {main_invoke_params} --_test_idx={test_idx}\n'

#cmd += f' python3 -u ~/dev/pytest_parallel/slurm/worker.py {SCHEDULER_IP_ADDRESS} {port} {test_idx}'
cmd += f' > out_{test_idx}.txt 2> err_{test_idx}.txt'
cmd += f' > {test_out_file_base}.out 2> {test_out_file_base}.err'
cmd += ' &' # launch everything in parallel
cmds += cmd + '\n'
cmds += 'wait\n'

pytest_slurm = f'''#!/bin/bash
#SBATCH --job-name=pytest_par
#SBATCH --time 00:30:00
#SBATCH --job-name=pytest_parallel
#SBATCH --ntasks={slurm_ntasks}
#SBATCH --time 00:10:00
#SBATCH --qos=co_short_std
#SBATCH --ntasks={n_working_procs}
##SBATCH --nodes=2-2
##SBATCH --nodes=4-4
#SBATCH --nodes=1-1
#SBATCH --output=slurm.%j.out
#SBATCH --error=slurm.%j.err
#SBATCH --output=pytest_parallel_slurm/slurm.%j.out
#SBATCH --error=pytest_parallel_slurm/slurm.%j.err
#source /scratchm/sonics/dist/source.sh --env maia --compiler gcc@12 --mpi intel-oneapi
module load socle-cfd/6.0-intel2220-impi
export PYTHONPATH=/stck/bberthou/dev/pytest_parallel:$PYTHONPATH
export PYTEST_PLUGINS=pytest_parallel.plugin
{cmds}
wait
'''

with open('pytest_slurm.sh','w') as f:
Path('pytest_parallel_slurm').mkdir(exist_ok=True)
with open('pytest_parallel_slurm/job.sh','w') as f:
f.write(pytest_slurm)

## submit SLURM job
sbatch_cmd = 'sbatch pytest_slurm.sh'
p = subprocess.Popen([sbatch_cmd], shell=True)
sbatch_cmd = 'sbatch --parsable pytest_parallel_slurm/job.sh'
p = subprocess.Popen([sbatch_cmd], shell=True, stdout=subprocess.PIPE)
print('Submitting tests to SLURM...')
returncode = p.wait()
slurm_job_id = int(p.stdout.read())
assert returncode==0, f'Error when submitting to SLURM with `{sbatch_cmd}`'
print(f'SLURM job {slurm_job_id} has been submitted')
return slurm_job_id

def receive_items(items, session, socket):
n = len(items)
Expand All @@ -713,11 +722,13 @@ def receive_items(items, session, socket):
n -= 1

class ProcessScheduler:
def __init__(self, n_working_procs, main_invoke_params):
self.n_working_procs = n_working_procs
self.current_item_requests = []
def __init__(self, main_invoke_params, slurm_ntasks, slurm_options):
self.main_invoke_params = main_invoke_params
self.slurm_ntasks = slurm_ntasks
self.slurm_options = slurm_options

self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TODO close at the end
self.slurm_job_id = None

@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(self, pyfuncitem):
Expand Down Expand Up @@ -747,7 +758,7 @@ def pytest_runtestloop(self, session) -> bool:
add_n_procs(session.items)

# isolate skips
has_enough_procs = lambda item: item.n_proc <= self.n_working_procs
has_enough_procs = lambda item: item.n_proc <= self.slurm_ntasks
items_to_run, items_to_skip = partition(session.items, has_enough_procs)

# run skipped
Expand All @@ -758,11 +769,17 @@ def pytest_runtestloop(self, session) -> bool:
run_item_test(item, nextitem, session)

# schedule tests to run
submit_items(items_to_run, self.socket, self.n_working_procs, self.main_invoke_params)
self.slurm_job_id = submit_items(items_to_run, self.socket, self.main_invoke_params, self.slurm_ntasks, self.slurm_options)
receive_items(session.items, session, self.socket)

return True

@pytest.hookimpl()
def pytest_keyboard_interrupt(self):
    """Cancel the submitted SLURM job when the pytest session is interrupted.

    `self.slurm_job_id` is `None` until `pytest_runtestloop` stores the id
    returned by `submit_items`; in that case there is nothing to cancel.

    The single parameter is the plugin instance. It used to be named
    `excinfo`, which suggested that the job id was read from pytest's
    exception info; being the first positional parameter of a method, it
    always received the instance. The hook's own `excinfo` argument is not
    needed here and is therefore not requested.
    """
    job_id = self.slurm_job_id
    if job_id is not None:
        print(f'Calling `scancel {job_id}`')
        subprocess.run(['scancel', str(job_id)])

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(self, item):
"""
Expand Down
49 changes: 37 additions & 12 deletions pytest_parallel/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,47 @@ def pytest_addoption(parser):
type='choice',
choices=['sequential', 'static', 'dynamic', 'shell', 'slurm'],
default='sequential',
help='Method used by pytest_parallel to schedule tests',
)

parser.addoption('--max_n_proc', dest='max_n_proc', type=int)
parser.addoption('--slurm_options', dest='slurm_options', type=str, help='list of SLURM options e.g. "--time=00:30:00 --qos=my_queue --n_tasks=4"')

# Private to SLURM scheduler
parser.addoption('--_worker', dest='_worker', action='store_true')
parser.addoption('--_scheduler_ip_address', dest='_scheduler_ip_address', type=str)
parser.addoption('--_scheduler_port', dest='_scheduler_port', type=int)
parser.addoption('--_test_idx' , dest='_test_idx' , type=int)
parser.addoption('--_worker', dest='_worker', action='store_true', help='Internal pytest_parallel option')
parser.addoption('--_scheduler_ip_address', dest='_scheduler_ip_address', type=str, help='Internal pytest_parallel option')
parser.addoption('--_scheduler_port', dest='_scheduler_port', type=int, help='Internal pytest_parallel option')
parser.addoption('--_test_idx' , dest='_test_idx' , type=int, help='Internal pytest_parallel option')
parser.addoption('--_slurm_ntasks', dest='_slurm_ntasks', type=int)

# create SBATCH header
def parse_slurm_options(opt_str):
    """Extract the number of tasks from a string of SLURM options.

    `opt_str` is split on whitespace. The first `--ntasks=N` or `--ntasks N`
    option found gives the number of tasks. Options that merely start with
    the same characters (e.g. `--ntasks-per-node=2`) are ignored.

    Parameters:
        opt_str: the SLURM options, e.g. "--time=00:30:00 --ntasks=4".
            `None` or an empty string is treated as "no option given".

    Returns:
        A pair `(ntasks, opts)`: `ntasks` is the task count as an `int`,
        `opts` is the list of all whitespace-separated tokens of `opt_str`
        (the `--ntasks` option included).

    Raises:
        ValueError: if `--ntasks` is absent, has no value, or its value is
            not an integer.
    """
    flag = '--ntasks'
    parsing_error = 'pytest_parallel SLURM scheduler: parsing error for `--ntasks`'

    opts = opt_str.split() if opt_str else []
    for pos, opt in enumerate(opts):
        if opt == flag:
            # `--ntasks N`: whitespace splitting put the value in the next token
            if pos + 1 == len(opts):
                raise ValueError(parsing_error)
            ntasks_str = opts[pos + 1]
        elif opt.startswith(flag + '='):
            # `--ntasks=N`
            ntasks_str = opt[len(flag) + 1:]
        else:
            continue

        try:
            ntasks = int(ntasks_str)
        except ValueError:
            raise ValueError(parsing_error) from None
        return ntasks, opts

    raise ValueError('pytest_parallel SLURM scheduler: you need specify --ntasks in slurm_options')

# --------------------------------------------------------------------------
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
global_comm = MPI.COMM_WORLD

# Get options and check dependent/incompatible options
scheduler = config.getoption('scheduler')
slurm_worker = config.getoption('_worker') # only meaningful if scheduler == 'slurm'
slurm_options = config.getoption('slurm_options')
slurm_worker = config.getoption('_worker')
## !slurm => !slurm_worker
if scheduler != 'slurm':
assert not slurm_worker
assert not slurm_options

if scheduler == 'sequential':
plugin = SequentialScheduler(global_comm)
Expand All @@ -46,22 +70,23 @@ def pytest_configure(config):
inter_comm = spawn_master_process(global_comm)
plugin = DynamicScheduler(global_comm, inter_comm)
elif scheduler == 'slurm':
if config.getoption('_worker'):
n_working_procs = config.getoption('max_n_proc')
if slurm_worker:
scheduler_ip_address = config.getoption('_scheduler_ip_address')
scheduler_port = config.getoption('_scheduler_port')
test_idx = config.getoption('_test_idx')
plugin = ProcessWorker(n_working_procs, scheduler_ip_address, scheduler_port, test_idx)
slurm_ntasks = config.getoption('slurm_ntasks')
plugin = ProcessWorker(slurm_ntasks, scheduler_ip_address, scheduler_port, test_idx)
else: # scheduler
assert global_comm.Get_size() == 1, 'pytest_parallel usage error: \
when scheduling with SLURM, \
do not launch the scheduling itself in parallel \
(do NOT use `mpirun -np n pytest...`)'

n_working_procs = int(config.getoption('max_n_proc'))

# List of all invoke options except slurm options
main_invoke_params = ' '.join(config.invocation_params.args)
plugin = ProcessScheduler(n_working_procs, main_invoke_params)
main_invoke_params = ''.join( main_invoke_params.split(f'--slurm_options={slurm_options}') )
slurm_ntasks, slurm_options = parse_slurm_options(slurm_options)
plugin = ProcessScheduler(main_invoke_params, slurm_ntasks, slurm_options)
else:
assert 0

Expand Down
19 changes: 0 additions & 19 deletions pytest_slurm.sh

This file was deleted.

File renamed without changes.
11 changes: 4 additions & 7 deletions run.sh → slurm/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
#SBATCH --job-name=pytest_par
#SBATCH --time 00:30:00
#SBATCH --qos=co_short_std
#SBATCH --ntasks=88
#SBATCH --nodes=2-2
#SBATCH --ntasks=1
##SBATCH --nodes=2-2
#SBATCH --output=slurm.%j.out
#SBATCH --error=slurm.%j.err

#echo $TOTO
whoami
echo "before 88"
srun --exclusive --ntasks=88 -l hostname &
echo "after 88"
wait
echo "after wait"
#srun --exclusive --ntasks=1 -l hostname
nproc --all
3 changes: 2 additions & 1 deletion test/test_pytest_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def run_pytest_parallel_test(test_name, n_workers, scheduler, capfd, suffix=""):


param_scheduler = (
["sequential", "static", "dynamic"]
["static", "dynamic"]
#["sequential"]
if sys.platform != "win32"
else ["sequential", "static"]
)
Expand Down

0 comments on commit 61e37dc

Please sign in to comment.