diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py
index a7f3b0d..c55a2a9 100644
--- a/trainers/base_trainer.py
+++ b/trainers/base_trainer.py
@@ -1,7 +1,6 @@
 import os
 import torch
 from utils.plotting import plot_loss
-from utils.logging import log_to_csv, log_epoch_results
 from abc import ABC, abstractmethod
 import time
 from typing import Tuple
@@ -35,7 +34,7 @@ def __init__(self, model, device):
         self.scheduler = None
         self.metrics = []
 
-    def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None:
+    def save_checkpoint(self, save_path, epoch, train_losses, val_losses, logs) -> None:
         """
         Saves the current state of the training process.
         Args:
@@ -43,7 +42,7 @@ def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_val
             epoch (int): Current epoch number.
             train_losses (list): List of training losses up to the current epoch.
             val_losses (list): List of validation losses up to the current epoch.
-            metric_values (dict): Dictionary containing other metric values.
+            logs (dict): Dictionary containing other metric values.
         """
         state = {
             'epoch': epoch,
@@ -51,7 +50,7 @@ def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_val
             'optimizer_state_dict': self.optimizer.state_dict(),
             'train_losses': train_losses,
             'val_losses': val_losses,
-            'metric_values': metric_values
+            'logs': logs
         }
         if self.scheduler:
             state['scheduler_state_dict'] = self.scheduler.state_dict()
@@ -96,7 +95,7 @@ def unfreeze_all_layers(self) -> None:
             param.requires_grad = True
 
     @abstractmethod
-    def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
+    def _train_epoch(self, train_loader, epoch, num_epochs) -> float:
         """
         Trains the model for one epoch using the provided train_loader.
 
@@ -104,7 +103,6 @@ def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
             train_loader (DataLoader): The data loader for training data.
             epoch (int): The current epoch number.
             num_epochs (int): The total number of epochs.
-            verbose (bool, optional): Whether to print training progress. Defaults to True.
 
         Returns:
             float: The loss value for the epoch.
@@ -112,7 +110,7 @@ def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float:
         raise NotImplementedError(
             "The train_epoch method must be implemented by the subclass.")
 
-    def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, checkpoint_dir=None, verbose=True) -> None:
+    def train(self, train_loader, num_epochs, valid_loader=None, plot_path=None, checkpoint_dir=None, callbacks=None):
         """
         Train the model for a given number of epochs, calculating metrics at the end of each epoch
         for both training and validation sets.
@@ -121,51 +119,64 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot
             train_loader: The data loader for the training set.
             num_epochs (int): The number of epochs to train the model.
             valid_loader: The data loader for the validation set (optional).
-            log_path: The path to save the training log (optional).
             plot_path: The path to save the training plot (optional).
             checkpoint_dir: The directory to save model checkpoints (optional).
-            verbose (bool): Whether to print training progress (default: True).
+            callbacks: List of callback objects to use during training (optional).
 
         """
 
+        logs = {}
         times = []
        training_epoch_losses = []
        validation_epoch_losses = []
-        metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics}
+        if callbacks is None:
+            callbacks = []
+
         start_time = time.time()
 
+        for callback in callbacks:
+            callback.on_train_begin(logs=logs)
+
         for epoch in range(num_epochs):
             epoch_start_time = time.time()
-            epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs, verbose)
+
+            logs['epoch'] = epoch
+            epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs)
             training_epoch_losses.append(epoch_loss_train)
             _, epoch_metrics_train = self.evaluate(train_loader, self.metrics, verbose=False)
+            logs['train_loss'] = epoch_loss_train
+            logs['train_metrics'] = epoch_metrics_train
 
             if valid_loader is not None:
                 epoch_loss_valid, epoch_metrics_valid = self.evaluate(valid_loader, self.metrics, verbose=False)
                 validation_epoch_losses.append(epoch_loss_valid)
+                logs['val_loss'] = epoch_loss_valid
+                logs['val_metrics'] = epoch_metrics_valid
             else:
-                epoch_metrics_valid = {metric.name: None for metric in self.metrics}
+                logs['val_loss'] = None
+                logs['val_metrics'] = {}
+
+            for callback in callbacks:
+                callback.on_epoch_begin(epoch, logs=logs)
+                callback.on_epoch_end(epoch, logs=logs)
 
-            for metric_name in metric_values.keys():
-                metric_values[metric_name]['train'].append(epoch_metrics_train[metric_name])
-                metric_values[metric_name]['valid'].append(epoch_metrics_valid.get(metric_name))
-
-            if checkpoint_dir and (epoch + 1) % 5 == 0:
-                self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, metric_values)
-
             epoch_time = time.time() - epoch_start_time
             times.append(epoch_time)
 
-            if verbose:
-                log_epoch_results(epoch, num_epochs, epoch_loss_train, epoch_metrics_train, epoch_metrics_valid)
+            if checkpoint_dir and (epoch + 1) % 5 == 0:
+                self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, logs)
 
-        if log_path is not None:
-            log_to_csv(training_epoch_losses, validation_epoch_losses, metric_values, times, log_path)
+            if not all(callback.should_continue(logs=logs) for callback in callbacks):
+                print(f"Training stopped early at epoch {epoch + 1}.")
+                break
+
+        logs['times'] = times
+
+        for callback in callbacks:
+            callback.on_train_end(logs=logs)
 
         elapsed_time = time.time() - start_time
-
-        if verbose:
-            print(f"Training completed in: {elapsed_time:.2f} seconds")
+        print(f"Training completed in: {elapsed_time:.2f} seconds")
 
         if plot_path is not None:
             plot_loss(training_epoch_losses, validation_epoch_losses, plot_path)