Skip to content

Commit

Permalink
cleanup models, sorted
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertz committed Jan 7, 2025
1 parent 4f2cb47 commit e420e38
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions users/zeyer/sis_tools/cleanup_unused_train_models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,32 +50,43 @@ def main():
from sisyphus import gs
from i6_core.returnn.training import ReturnnTrainingJob
from i6_experiments.users.zeyer.utils import job_aliases_from_log
from i6_experiments.users.zeyer.utils.set_insert_order import SetInsertOrder
from returnn.util.basic import human_bytes_size

# HACK: Replace the set() by SetInsertOrder() to make the order deterministic.
graph.graph._targets = SetInsertOrder()

gs.WARNING_ABSPATH = False
gs.GRAPH_WORKER = 1 # makes the order deterministic, easier to reason about

sisyphus.logging_format.add_coloring_to_logging()
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=args.log_level)

print("Loading Sisyphus configs...")
config_manager.load_configs(args.config)

print("Checking active train jobs of the Sisyphus graph...")
active_train_job_paths = set()
for job in graph.graph.jobs():
if not isinstance(job, ReturnnTrainingJob):
continue
# print("active train job:", job._sis_path())
job_path = job._sis_path()
if os.path.isdir(job_path):
active_train_job_paths.add(job._sis_path())
active_train_job_paths.add(job_path)
aliases = job_aliases_from_log.get_job_aliases(job_path)
print("Active train job:", aliases[0] if aliases else job)
else:
print("Active train job not created yet:", job)
print("Num active train jobs:", len(active_train_job_paths))

print("Now checking all train jobs in work dir...")
total_model_size_to_remove = 0
total_train_job_count = 0
total_train_job_with_models_to_remove_count = 0
train_job_with_models_to_remove = []
unused_train_jobs = {} # key: alias (or basename as fallback), value: job path filename
model_fns_to_remove = []
found_active_count = 0 # as a sanity check
found_active_fns = set() # as a sanity check.
for basename in os.listdir("work/i6_core/returnn/training"):
if not basename.startswith("ReturnnTrainingJob."):
continue
Expand All @@ -97,13 +108,22 @@ def main():
pass

if fn in active_train_job_paths:
print("Active train job:", alias or basename)
found_active_count += 1
found_active_fns.add(fn)
continue

model_dir = fn + "/output/models"
if not os.path.isdir(model_dir):
continue # can happen when there was an early error, e.g. at file creation
# First collect all, and then go through them in sorted order below.
# We do this because here the listdir order is totally arbitrary
# (due to FS, but sorting by hash also would not help),
# and to inspect the output, it's much more helpful when this is sorted in some way.
unused_train_jobs[alias or basename] = fn

print("Collecting model checkpoint files to remove...")
# Now go sorted.
for name, fn in sorted(unused_train_jobs.items()):
model_dir = fn + "/output/models"
model_count = 0
model_size = 0
with os.scandir(model_dir) as it:
Expand All @@ -117,14 +137,24 @@ def main():
model_count += 1
if model_count == 0:
continue
print("Unused train job:", alias or basename, "model size:", human_bytes_size(model_size))
print("Unused train job:", name, "model size:", human_bytes_size(model_size))
total_model_size_to_remove += model_size
total_train_job_with_models_to_remove_count += 1
train_job_with_models_to_remove.append(name)

print("Total train job count:", total_train_job_count)
print("Total train job with models to remove count:", total_train_job_with_models_to_remove_count)
print("Total train job with models to remove count:", len(train_job_with_models_to_remove))
print("List of train jobs with models to remove:")
for alias in train_job_with_models_to_remove:
print(f" {alias}")
if not train_job_with_models_to_remove:
print(" (none)")
print("Can remove total model size:", human_bytes_size(total_model_size_to_remove))
assert found_active_count == len(active_train_job_paths), (found_active_count, len(active_train_job_paths))
if len(found_active_fns) != len(active_train_job_paths):
print("ERROR: Did not find some active jobs:")
for fn in active_train_job_paths:
if fn not in found_active_fns:
print(" ", fn)
raise Exception("Did not find some active jobs.")

if args.mode == "remove":
for fn in model_fns_to_remove:
Expand Down

0 comments on commit e420e38

Please sign in to comment.