From 8c9debeb447cf7eed857482c5ad21bacb70b1649 Mon Sep 17 00:00:00 2001 From: Andreas Date: Tue, 9 Jul 2024 15:12:57 +0200 Subject: [PATCH] fix plot() in FairseqHydraTrain() --- fairseq/training.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/training.py b/fairseq/training.py index b965be40..5142e7a6 100755 --- a/fairseq/training.py +++ b/fairseq/training.py @@ -343,8 +343,8 @@ def plot(self): i = 0 while i < len(lines): line = lines[i] - if "begin validation on" in line or "end of epoch" in line: - epoch_dict = eval(lines[i + 1][lines[i + 1].index("{") :]) + if "[train][INFO]" in line or "[valid][INFO]" in line: + epoch_dict = eval(line[line.index("{") :]) try: epoch = int(epoch_dict["epoch"]) losses = {k: {epoch: float(v)} for k, v in epoch_dict.items() if k.endswith("_loss")} @@ -355,12 +355,12 @@ def plot(self): continue if "train_lr" in epoch_dict: learning_rates[epoch] = float(epoch_dict["train_lr"]) - if "begin validation on" in line: + if "[valid][INFO]" in line: for k in losses.keys(): valid_loss[k].update(losses[k]) for k in accuracy.keys(): valid_accuracy[k].update(accuracy[k]) - else: + else: # [train][INFO] is in line for k in losses.keys(): train_loss[k].update(losses[k]) for k in accuracy.keys():