Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt SLURM engine for multi-node jobs #212

Merged
merged 9 commits into from
Nov 4, 2024
Merged
2 changes: 1 addition & 1 deletion sisyphus/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _sis_path(self, path_type=None, task_id=None, abspath=False):

# Add task id as suffix
if task_id is not None:
path += ".%i" % task_id
path += f".{task_id}"

if abspath and not os.path.isabs(path):
path = os.path.join(gs.BASE_DIR, path)
Expand Down
81 changes: 61 additions & 20 deletions sisyphus/simple_linux_utility_for_resource_management_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def options(self, rqmt):
out.append("--time=%s" % task_time)
out.append("--export=all")

if rqmt.get("multi_node_slots", None):
if rqmt.get("multi_node_slots", 1) > 1:
out.append("--ntasks=%s" % rqmt["multi_node_slots"])
out.append("--nodes=%s" % rqmt["multi_node_slots"])

sbatch_args = rqmt.get("sbatch_args", [])
if isinstance(sbatch_args, str):
Expand Down Expand Up @@ -232,11 +233,13 @@ def submit_helper(self, call, logpath, rqmt, name, task_name, start_id, end_id,
:param int step_size:
"""
name = self.process_task_name(name)
sbatch_call = ["sbatch", "-J", name, "-o", logpath + "/%x.%A.%a", "--mail-type=None"]
out_log_file = logpath + "/%x.%A.%t.%a"
NeoLegends marked this conversation as resolved.
Show resolved Hide resolved
sbatch_call = ["sbatch", "-J", name, "--mail-type=None"]
sbatch_call += self.options(rqmt)
sbatch_call += ["-o", f"{out_log_file}.batch"]
albertz marked this conversation as resolved.
Show resolved Hide resolved
sbatch_call += ["-a", f"{start_id}-{end_id}:{step_size}"]
sbatch_call += [f"--wrap=srun -o {out_log_file} {' '.join(call)}"]

sbatch_call += ["-a", "%i-%i:%i" % (start_id, end_id, step_size)]
sbatch_call += ["--wrap=%s" % " ".join(call)]
while True:
try:
out, err, retval = self.system_call(sbatch_call)
Expand Down Expand Up @@ -393,28 +396,66 @@ def get_default_rqmt(self, task):

def init_worker(self, task):
# setup log file by linking to engine logfile
task_id = self.get_task_id(None)
logpath = os.path.relpath(task.path(gs.JOB_LOG, task_id))

# Naming ambiguity: sis "tasks" are what SLURM calls array jobs.
#
# SLURM tasks represent jobs that span multiple nodes at the same time
# (e.g. multi-node multi-GPU trainings consist of one SLURM task per node).
slurm_num_tasks = int(
next(filter(None, (os.getenv(var, None) for var in ["SLURM_NTASKS", "SLURM_NPROCS"])), "1")
)
slurm_task_id = int(os.getenv("SLURM_PROCID", "0"))

array_task_id = self.get_task_id(None)
# keep backwards compatibility: only change output file name for multi-SLURM-task jobs
log_suffix = array_task_id if slurm_num_tasks <= 1 else f"{array_task_id}.{slurm_task_id}"
albertz marked this conversation as resolved.
Show resolved Hide resolved
logpath = os.path.relpath(task.path(gs.JOB_LOG, log_suffix))
if os.path.isfile(logpath):
os.unlink(logpath)

engine_logpath = (
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ os.getenv("SLURM_ARRAY_JOB_ID")
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID")
job_id = next(
filter(None, (os.getenv(name, None) for name in ["SLURM_JOB_ID", "SLURM_JOBID", "SLURM_ARRAY_JOB_ID"])), "0"
)
try:
if os.path.isfile(engine_logpath):
has_linked_logfile = False
engine_logpath_candidates = [
(
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ job_id
+ "."
+ str(slurm_task_id)
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID", "1")
),
(
os.path.dirname(logpath)
+ "/engine/"
+ os.getenv("SLURM_JOB_NAME")
+ "."
+ job_id
+ "."
+ os.getenv("SLURM_ARRAY_TASK_ID", "1")
),
]
for engine_logpath in engine_logpath_candidates:
NeoLegends marked this conversation as resolved.
Show resolved Hide resolved
if not os.path.isfile(engine_logpath):
continue
try:
os.link(engine_logpath, logpath)
else:
logging.warning("Could not find engine logfile: %s Create soft link anyway." % engine_logpath)
has_linked_logfile = True
break
except FileExistsError:
pass

if not has_linked_logfile:
engine_logpath = engine_logpath_candidates[0]
logging.warning("Could not find engine logfile: %s Create soft link anyway." % engine_logpath)
try:
os.symlink(os.path.relpath(engine_logpath, os.path.dirname(logpath)), logpath)
except FileExistsError:
pass
except FileExistsError:
pass

def get_logpath(self, logpath_base, task_name, task_id):
"""Returns log file for the currently running task"""
Expand Down
Loading