Skip to content

Commit

Permalink
Updated checkpoint filename advances #16
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent 0a8eb95 commit 6c8f728
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, model, device):
self.scheduler = None
self.metrics = []

def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values):
def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None:
"""
Saves the current state of the training process.
Args:
Expand All @@ -56,7 +56,7 @@ def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_val
if self.scheduler:
state['scheduler_state_dict'] = self.scheduler.state_dict()

torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth'))
torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth'))

def load_checkpoint(self, load_path):
"""
Expand Down Expand Up @@ -130,7 +130,6 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
training_epoch_losses = []
validation_epoch_losses = []
metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics}
metric_values['time'] = {'train': [], 'valid': []}

start_time = time.time()

Expand Down

0 comments on commit 6c8f728

Please sign in to comment.