Skip to content

Commit

Permalink
AverageTorchCheckpointsJob (#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz authored Jan 3, 2024
1 parent 11fb91b commit d48aa82
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion returnn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import shutil
import subprocess as sp
from typing import Dict, Iterable, List, Optional, Union
from typing import Dict, Sequence, Iterable, List, Optional, Union

from sisyphus import *

Expand Down Expand Up @@ -916,3 +916,44 @@ def run(self):

# The env override is needed if this job is run locally on a node with a GPU installed
sp.check_call(args, env={"CUDA_VISIBLE_DEVICES": ""})


class AverageTorchCheckpointsJob(Job):
"""
average Torch model checkpoints
"""

def __init__(
self,
*,
checkpoints: Sequence[Union[tk.Path, PtCheckpoint]],
returnn_python_exe: tk.Path,
returnn_root: tk.Path,
):
"""
:param checkpoints: input checkpoints
:param returnn_python_exe: file path to the executable for running returnn (python binary or .sh)
:param returnn_root: file path to the RETURNN repository root folder
"""
self.checkpoints = [ckpt if isinstance(ckpt, PtCheckpoint) else PtCheckpoint(ckpt) for ckpt in checkpoints]
self.returnn_python_exe = returnn_python_exe
self.returnn_root = returnn_root

self.out_checkpoint = PtCheckpoint(self.output_path("model/average.pt"))

self.rqmt = {"cpu": 1, "time": 0.5, "mem": 5}

def tasks(self):
yield Task("run", rqmt=self.rqmt)

def run(self):
os.makedirs(os.path.dirname(self.out_checkpoint.path.get_path()), exist_ok=True)
args = [
self.returnn_python_exe.get_path(),
os.path.join(self.returnn_root.get_path(), "tools/torch_avg_checkpoints.py"),
"--checkpoints",
*[ckpt.path.get_path() for ckpt in self.checkpoints],
"--output_path",
self.out_checkpoint.path.get_path(),
]
sp.check_call(args)

0 comments on commit d48aa82

Please sign in to comment.