Skip to content

Commit

Permalink
Added checkpoint support advances #16
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent 4b79d4f commit 0a8eb95
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion trainers/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
from utils.plotting import plot_loss
from utils.logging import log_to_csv, log_epoch_results
Expand Down Expand Up @@ -33,6 +34,42 @@ def __init__(self, model, device):
self.optimizer = None
self.scheduler = None
self.metrics = []

def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values):
"""
Saves the current state of the training process.
Args:
save_path (str): Directory to save checkpoint files.
epoch (int): Current epoch number.
train_losses (list): List of training losses up to the current epoch.
val_losses (list): List of validation losses up to the current epoch.
metric_values (dict): Dictionary containing other metric values.
"""
state = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': train_losses,
'val_losses': val_losses,
'metric_values': metric_values
}
if self.scheduler:
state['scheduler_state_dict'] = self.scheduler.state_dict()

torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch}.pth'))

def load_checkpoint(self, load_path):
"""
Loads a checkpoint and resumes training or evaluation.
Args:
load_path (str): Path to the checkpoint file.
"""
state = torch.load(load_path)
self.model.load_state_dict(state['model_state_dict'])
self.optimizer.load_state_dict(state['optimizer_state_dict'])
if 'scheduler_state_dict' in state and self.scheduler:
self.scheduler.load_state_dict(state['scheduler_state_dict'])
return state

def build(self, criterion, optimizer_class, optimizer_params={}, scheduler=None, freeze_until_layer=None, metrics=[]) -> None:
""" Build the model, criterion, optimizer and scheduler. """
Expand Down Expand Up @@ -75,7 +112,7 @@ def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
raise NotImplementedError(
"The train_epoch method must be implemented by the subclass.")

def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, verbose=True) -> None:
def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, checkpoint_dir=None, verbose=True) -> None:
"""
Train the model for a given number of epochs, calculating metrics at the end of each epoch
for both training and validation sets.
Expand All @@ -86,6 +123,7 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
valid_loader: The data loader for the validation set (optional).
log_path: The path to save the training log (optional).
plot_path: The path to save the training plot (optional).
checkpoint_dir: The directory to save model checkpoints (optional).
verbose (bool): Whether to print training progress (default: True).
"""
times = []
Expand Down Expand Up @@ -113,6 +151,9 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
metric_values[metric_name]['train'].append(epoch_metrics_train[metric_name])
metric_values[metric_name]['valid'].append(epoch_metrics_valid.get(metric_name))

if checkpoint_dir and (epoch + 1) % 5 == 0:
self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, metric_values)

epoch_time = time.time() - epoch_start_time
times.append(epoch_time)

Expand Down

0 comments on commit 0a8eb95

Please sign in to comment.