diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index cd7ac07..a7f3b0d 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -35,7 +35,7 @@ def __init__(self, model, device): self.scheduler = None self.metrics = [] - def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values): + def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None: """ Saves the current state of the training process. Args: @@ -56,7 +56,7 @@ def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_val if self.scheduler: state['scheduler_state_dict'] = self.scheduler.state_dict() - torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth')) + torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth')) def load_checkpoint(self, load_path): """ @@ -130,7 +130,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot training_epoch_losses = [] validation_epoch_losses = [] metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics} - metric_values['time'] = {'train': [], 'valid': []} start_time = time.time()