Skip to content

Commit

Permalink
Added callbacks support (advances #17)
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent 3c9d80a commit 3058a04
Showing 1 changed file with 37 additions and 26 deletions.
63 changes: 37 additions & 26 deletions trainers/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import torch
from utils.plotting import plot_loss
from utils.logging import log_to_csv, log_epoch_results
from abc import ABC, abstractmethod
import time
from typing import Tuple
Expand Down Expand Up @@ -35,23 +34,23 @@ def __init__(self, model, device):
self.scheduler = None
self.metrics = []

def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None:
def save_checkpoint(self, save_path, epoch, train_losses, val_losses, logs) -> None:
"""
Saves the current state of the training process.
Args:
save_path (str): Directory to save checkpoint files.
epoch (int): Current epoch number.
train_losses (list): List of training losses up to the current epoch.
val_losses (list): List of validation losses up to the current epoch.
metric_values (dict): Dictionary containing other metric values.
logs (dict): Dictionary containing other metric values.
"""
state = {
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'train_losses': train_losses,
'val_losses': val_losses,
'metric_values': metric_values
'logs': logs
}
if self.scheduler:
state['scheduler_state_dict'] = self.scheduler.state_dict()
Expand Down Expand Up @@ -96,23 +95,22 @@ def unfreeze_all_layers(self) -> None:
param.requires_grad = True

@abstractmethod
def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
def _train_epoch(self, train_loader, epoch, num_epochs) -> float:
"""
Trains the model for one epoch using the provided train_loader.
Args:
train_loader (DataLoader): The data loader for training data.
epoch (int): The current epoch number.
num_epochs (int): The total number of epochs.
verbose (bool, optional): Whether to print training progress. Defaults to True.
Returns:
float: The loss value for the epoch.
"""
raise NotImplementedError(
"The train_epoch method must be implemented by the subclass.")

def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, checkpoint_dir=None, verbose=True) -> None:
def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
"""
Train the model for a given number of epochs, calculating metrics at the end of each epoch
for both training and validation sets.
Expand All @@ -121,51 +119,64 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
train_loader: The data loader for the training set.
num_epochs (int): The number of epochs to train the model.
valid_loader: The data loader for the validation set (optional).
log_path: The path to save the training log (optional).
plot_path: The path to save the training plot (optional).
checkpoint_dir: The directory to save model checkpoints (optional).
verbose (bool): Whether to print training progress (default: True).
callbacks: List of callback objects to use during training (optional).
"""
logs = {}
times = []
training_epoch_losses = []
validation_epoch_losses = []
metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics}

if callbacks is None:
callbacks = []

start_time = time.time()

for callback in callbacks:
callback.on_train_begin(logs=logs)

for epoch in range(num_epochs):
epoch_start_time = time.time()
epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs, verbose)

logs['epoch'] = epoch
epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs)
training_epoch_losses.append(epoch_loss_train)

_, epoch_metrics_train = self.evaluate(train_loader, self.metrics, verbose=False)
logs['train_loss'] = epoch_loss_train
logs['train_metrics'] = epoch_metrics_train

if valid_loader is not None:
epoch_loss_valid, epoch_metrics_valid = self.evaluate(valid_loader, self.metrics, verbose=False)
validation_epoch_losses.append(epoch_loss_valid)
logs['val_loss'] = epoch_loss_valid
logs['val_metrics'] = epoch_metrics_valid
else:
epoch_metrics_valid = {metric.name: None for metric in self.metrics}
logs['val_loss'] = None
logs['val_metrics'] = {}

for callback in callbacks:
callback.on_epoch_begin(epoch, logs=logs)
callback.on_epoch_end(epoch, logs=logs)

for metric_name in metric_values.keys():
metric_values[metric_name]['train'].append(epoch_metrics_train[metric_name])
metric_values[metric_name]['valid'].append(epoch_metrics_valid.get(metric_name))

if checkpoint_dir and (epoch + 1) % 5 == 0:
self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, metric_values)

epoch_time = time.time() - epoch_start_time
times.append(epoch_time)

if verbose:
log_epoch_results(epoch, num_epochs, epoch_loss_train, epoch_metrics_train, epoch_metrics_valid)
if checkpoint_dir and (epoch + 1) % 5 == 0:
self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, logs)

if log_path is not None:
log_to_csv(training_epoch_losses, validation_epoch_losses, metric_values, times, log_path)
if not all(callback.should_continue(logs=logs) for callback in callbacks):
print(f"Training stopped early at epoch {epoch + 1}.")
break

logs['times'] = times

for callback in callbacks:
callback.on_train_end(logs=logs)

elapsed_time = time.time() - start_time

if verbose:
print(f"Training completed in: {elapsed_time:.2f} seconds")
print(f"Training completed in: {elapsed_time:.2f} seconds")

if plot_path is not None:
plot_loss(training_epoch_losses, validation_epoch_losses, plot_path)
Expand Down

0 comments on commit 3058a04

Please sign in to comment.