Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt SLURM engine for multi-node jobs #212

Merged
merged 9 commits into from
Nov 4, 2024
Merged
2 changes: 1 addition & 1 deletion sisyphus/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _sis_path(self, path_type=None, task_id=None, abspath=False):

# Add task id as suffix
if task_id is not None:
path += ".%i" % task_id
path += f".{task_id}"

if abspath and not os.path.isabs(path):
path = os.path.join(gs.BASE_DIR, path)
Expand Down
39 changes: 31 additions & 8 deletions sisyphus/simple_linux_utility_for_resource_management_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def options(self, rqmt):
out.append("--time=%s" % task_time)
out.append("--export=all")

if rqmt.get("multi_node_slots", None):
if rqmt.get("multi_node_slots", 1) > 1:
out.append("--ntasks=%s" % rqmt["multi_node_slots"])
out.append("--nodes=%s" % rqmt["multi_node_slots"])

sbatch_args = rqmt.get("sbatch_args", [])
if isinstance(sbatch_args, str):
Expand Down Expand Up @@ -232,11 +233,15 @@ def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id,
:param int step_size:
"""
name = self.process_task_name(name)
sbatch_call = ["sbatch", "-J", name, "-o", logpath + "/%x.%A.%a", "--mail-type=None"]
out_log_file = f"{logpath}/%x.%A.%a"
if rqmt.get("multi_node_slots", 1) > 1:
out_log_file += ".%t"
sbatch_call = ["sbatch", "-J", name, "--mail-type=None"]
sbatch_call += self.options(rqmt)
sbatch_call += ["-o", f"{out_log_file}.batch"]
albertz marked this conversation as resolved.
Show resolved Hide resolved
sbatch_call += ["-a", f"{start_id}-{end_id}:{step_size}"]
sbatch_call += [f"--wrap=srun -o {out_log_file} {' '.join(call)}"]

sbatch_call += ["-a", "%i-%i:%i" % (start_id, end_id, step_size)]
sbatch_call += ["--wrap=%s" % " ".join(call)]
while True:
try:
out, err, retval = self.system_call(sbatch_call)
Expand Down Expand Up @@ -393,20 +398,38 @@ def get_default_rqmt(self, task):

def init_worker(self, task):
# setup log file by linking to engine logfile
task_id = self.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))

# Naming ambiguity: what sisyphus calls a "task" is run as one element of a SLURM job array.
#
# A SLURM "task", in contrast, is one of the parallel processes that srun launches within a single job
# (e.g. a multi-node multi-GPU training runs one SLURM task per node).
slurm_num_tasks = int(
next(filter(None, (os.getenv(var, None) for var in ["SLURM_NTASKS", "SLURM_NPROCS"])), "1")
)
slurm_task_id = int(os.getenv("SLURM_PROCID", "0"))
array_task_id = self.get_task_id(None)

# keep backwards compatibility: only change output file name for multi-SLURM-task jobs
log_suffix = array_task_id if slurm_num_tasks <= 1 else f"{array_task_id}.{slurm_task_id}"
albertz marked this conversation as resolved.
Show resolved Hide resolved
logpath = os.path.relpath(task.path(gs.JOB_LOG, log_suffix))
if os.path.isfile(logpath):
os.unlink(logpath)

job_id = next(
filter(None, (os.getenv(name, None) for name in ["SLURM_JOB_ID", "SLURM_JOBID", "SLURM_ARRAY_JOB_ID"])), "0"
)
engine_logpath = (
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ os.getenv("SLURM_ARRAY_JOB_ID")
+ job_id
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID")
+ os.getenv("SLURM_ARRAY_TASK_ID", "1")
)
if slurm_num_tasks > 1:
engine_logpath += f".{slurm_task_id}"

try:
if os.path.isfile(engine_logpath):
os.link(engine_logpath, logpath)
Expand Down
Loading