From 6c8f72822469b52884d6d1324dd378decb19953f Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Mon, 22 Apr 2024 12:04:31 +0200 Subject: [PATCH] Updated checkpoint filename advances #16 --- trainers/base_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index cd7ac07..a7f3b0d 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -35,7 +35,7 @@ def __init__(self, model, device): self.scheduler = None self.metrics = [] - def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values): + def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None: """ Saves the current state of the training process. Args: @@ -56,7 +56,7 @@ def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_val if self.scheduler: state['scheduler_state_dict'] = self.scheduler.state_dict() - torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')) + torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth')) def load_checkpoint(self, load_path): """ @@ -130,7 +130,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot training_epoch_losses = [] validation_epoch_losses = [] metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics} - metric_values['time'] = {'train': [], 'valid': []} start_time = time.time()