diff --git a/callbacks/__init__.py b/callbacks/__init__.py new file mode 100644 index 0000000..81213aa --- /dev/null +++ b/callbacks/__init__.py @@ -0,0 +1,4 @@ +from callbacks.csv_logging import CSVLogging +from callbacks.epoch_results_logging import EpochResultsLogging +from callbacks.early_stopping import EarlyStopping +from callbacks.checkpoint import Checkpoint \ No newline at end of file diff --git a/callbacks/callback.py b/callbacks/callback.py new file mode 100644 index 0000000..f92b9e1 --- /dev/null +++ b/callbacks/callback.py @@ -0,0 +1,73 @@ +class Callback: + """ + A base class for defining callbacks in a training process. + + Callbacks are functions that can be executed at various stages during training. + They can be used to perform additional actions or modify the behavior of the training process. + + Methods: + should_continue(logs=None) -> bool: + Determines whether the training process should continue or stop. + + on_epoch_begin(epoch, logs=None) -> None: + Executed at the beginning of each epoch. + + on_epoch_end(epoch, logs=None) -> None: + Executed at the end of each epoch. + + on_train_begin(logs=None) -> None: + Executed at the beginning of the training process. + + on_train_end(logs=None) -> None: + Executed at the end of the training process. + """ + + def should_continue(self, logs=None) -> bool: + """ + Determines whether the training process should continue or stop. + + Args: + logs (dict): Optional dictionary containing training logs. + + Returns: + bool: True if the training process should continue, False otherwise. + """ + return True + + def on_epoch_begin(self, epoch, logs=None) -> None: + """ + Executed at the beginning of each epoch. + + Args: + epoch (int): The current epoch number. + logs (dict): Optional dictionary containing training logs. + """ + pass + + def on_epoch_end(self, epoch, logs=None) -> None: + """ + Executed at the end of each epoch. + + Args: + epoch (int): The current epoch number. + logs (dict): Optional dictionary containing training logs. + """ + pass + + def on_train_begin(self, logs=None) -> None: + """ + Executed at the beginning of the training process. + + Args: + logs (dict): Optional dictionary containing training logs. + """ + pass + + def on_train_end(self, logs=None) -> None: + """ + Executed at the end of the training process. + + Args: + logs (dict): Optional dictionary containing training logs. + """ + pass diff --git a/callbacks/checkpoint.py b/callbacks/checkpoint.py new file mode 100644 index 0000000..0ca0ff7 --- /dev/null +++ b/callbacks/checkpoint.py @@ -0,0 +1,60 @@ +from callbacks.callback import Callback +import os +import torch + +class Checkpoint(Callback): + """ + Callback class for saving model checkpoints during training. + + Args: + checkpoint_dir (str): Directory to save the checkpoints. + model (torch.nn.Module): The model to be saved. + optimizer (torch.optim.Optimizer): The optimizer to be saved. + scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to be saved. Default is None. + save_freq (int, optional): Frequency of saving checkpoints. Default is 1. + verbose (bool, optional): Whether to print the checkpoint save path. Default is False. 
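+
+    Example:
+        Illustrative sketch only; assumes a trainer whose train() method accepts
+        a callbacks list (see trainers.base_trainer.BaseTrainer.train):
+
+            checkpoint_cb = Checkpoint(
+                checkpoint_dir='./outputs/checkpoints/',
+                model=model,
+                optimizer=optimizer,
+                save_freq=5,
+                verbose=True,
+            )
+            trainer.train(train_loader=train_loader, num_epochs=10,
+                          callbacks=[checkpoint_cb])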
+ """ + + def __init__(self, checkpoint_dir, model, optimizer, scheduler=None, save_freq=1, verbose=False): + self.checkpoint_dir = checkpoint_dir + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.save_freq = save_freq + self.verbose = verbose + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir, exist_ok=True) + + def on_epoch_end(self, epoch, logs=None): + """ + Callback function called at the end of each epoch. + + Args: + epoch (int): The current epoch number. + logs (dict, optional): Dictionary containing training and validation losses. Default is None. + """ + if (epoch + 1) % self.save_freq == 0: + self.save_checkpoint(epoch, logs) + + def save_checkpoint(self, epoch, logs=None): + """ + Save the model checkpoint. + + Args: + epoch (int): The current epoch number. + logs (dict, optional): Dictionary containing training and validation losses. Default is None. + """ + state = { + 'epoch': epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'train_losses': logs.get('training_losses', []), + 'val_losses': logs.get('validation_losses', []), + } + if self.scheduler: + state['scheduler_state_dict'] = self.scheduler.state_dict() + + save_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth') + torch.save(state, save_path) + if self.verbose: + print(f"Checkpoint saved at {save_path}") diff --git a/callbacks/csv_logging.py b/callbacks/csv_logging.py new file mode 100644 index 0000000..805f13c --- /dev/null +++ b/callbacks/csv_logging.py @@ -0,0 +1,54 @@ +from callbacks.callback import Callback +import csv + +class CSVLogging(Callback): + """ + Callback for logging training and validation metrics to a CSV file. + + Args: + csv_path (str): The path to the CSV file. + + Attributes: + csv_path (str): The path to the CSV file. + headers_written (bool): Flag indicating whether the headers have been written to the CSV file. + """ + + def __init__(self, csv_path): + self.csv_path = csv_path + self.headers_written = False + + def on_epoch_end(self, epoch, logs=None): + """ + Method called at the end of each epoch during training. + + Args: + epoch (int): The current epoch number. + logs (dict): Dictionary containing the training and validation metrics. 
+
+        Returns:
+            None
+        """
+        if logs is None:
+            return
+
+        epoch_data = logs.get('epoch')
+        train_loss = logs.get('train_loss')
+        val_loss = logs.get('val_loss')
+        train_metrics = logs.get('train_metrics', {})
+        val_metrics = logs.get('val_metrics', {})
+
+        metrics = {'train_loss': train_loss, 'val_loss': val_loss}
+        metrics.update({f'train_{key}': value for key, value in train_metrics.items()})
+        metrics.update({f'val_{key}': value for key, value in val_metrics.items()})
+
+        headers = ['epoch'] + list(metrics.keys())
+        if not self.headers_written:
+            with open(self.csv_path, 'w', newline='') as file:
+                writer = csv.writer(file)
+                writer.writerow(headers)
+            self.headers_written = True
+
+        values = [epoch_data] + [metrics[key] for key in headers[1:]]  # Ensure the order matches headers
+        with open(self.csv_path, 'a', newline='') as file:
+            writer = csv.writer(file)
+            writer.writerow(values)
diff --git a/callbacks/early_stopping.py b/callbacks/early_stopping.py
new file mode 100644
index 0000000..d8659e0
--- /dev/null
+++ b/callbacks/early_stopping.py
@@ -0,0 +1,34 @@
+from callbacks.callback import Callback
+
+class EarlyStopping(Callback):
+    def __init__(self, monitor='val_loss', patience=5, verbose=False, delta=0):
+        self.monitor = monitor
+        self.patience = patience
+        self.verbose = verbose
+        self.delta = delta
+        self.best_score = None
+        self.epochs_no_improve = 0
+        self.stopped_epoch = 0
+        self.early_stop = False
+
+    def on_epoch_end(self, epoch, logs=None):
+        current = logs.get(self.monitor)
+        if current is None:
+            return
+
+        score = -current if 'loss' in self.monitor else current
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.epochs_no_improve += 1
+            if self.epochs_no_improve >= self.patience:
+                self.early_stop = True
+                self.stopped_epoch = epoch + 1
+                if self.verbose:
+                    print(f"Early stopping triggered at epoch {self.stopped_epoch}")
+        else:
+            self.best_score = score
+            self.epochs_no_improve = 0
+
+    def should_continue(self, logs=None):
+        return not self.early_stop
diff --git a/callbacks/epoch_results_logging.py b/callbacks/epoch_results_logging.py
new file mode 100644
index 0000000..0feb950
--- /dev/null
+++ b/callbacks/epoch_results_logging.py
@@ -0,0 +1,39 @@
+from callbacks.callback import Callback
+
+class EpochResultsLogging(Callback):
+    """
+    Callback for logging epoch results during training.
+
+    This callback prints the training loss, training metrics, and validation metrics at the end of each epoch.
+
+    Methods:
+        on_epoch_end: Called at the end of each epoch to print the epoch results.
+    """
+
+    def on_epoch_end(self, epoch, logs=None):
+        """
+        Prints the epoch results at the end of each epoch.
+
+        Args:
+            epoch (int): The current epoch number.
+            logs (dict): A dictionary containing the training and validation metrics.
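+                Expected keys: 'train_loss', 'train_metrics', 'val_metrics' and
+                'num_epochs' (defaults to 0 if missing).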
+ + Returns: + None + """ + epoch_loss_train = logs.get('train_loss') + epoch_metrics_train = logs.get('train_metrics', {}) + epoch_metrics_valid = logs.get('val_metrics', {}) + num_epochs = logs.get('num_epochs', 0) + + print(f"Epoch {epoch+1}/{num_epochs}, Training loss: {epoch_loss_train:.4f}") + + for metric_name, value in epoch_metrics_train.items(): + print(f"Training {metric_name}: {value:.4f}") + + if epoch_metrics_valid: + for metric_name, value in epoch_metrics_valid.items(): + print(f"Validation {metric_name}: {value if value is not None else 'N/A'}") diff --git a/fine_tuning.py b/fine_tuning.py index 960f435..9205056 100644 --- a/fine_tuning.py +++ b/fine_tuning.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, random_split from datasets.dataset import get_dataset from datasets.transformations import get_transforms -from utils.metrics import Accuracy, Precision, Recall, F1Score +from utils.metrics import Accuracy, Precision from models import get_model from trainers import get_trainer from os import path @@ -62,7 +62,6 @@ def main(config_path): train_loader=train_loader, valid_loader=valid_loader, num_epochs=config['training']['initial_epochs'], - log_path=log_filename, plot_path=plot_filename ) @@ -83,7 +82,6 @@ def main(config_path): train_loader=train_loader, valid_loader=valid_loader, num_epochs=config['training']['fine_tuning_epochs'], - log_path=log_filename, plot_path=plot_filename ) diff --git a/main.py b/main.py index 83ab540..60c9135 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ from datetime import datetime from torch.utils.data import DataLoader, random_split from datasets.dataset import get_dataset +from callbacks import EarlyStopping, CSVLogging, EpochResultsLogging from datasets.transformations import get_transforms from utils.metrics import Accuracy, Precision, Recall, F1Score from models import get_model @@ -58,12 +59,22 @@ def main(config_path): metrics=metrics ) + callbacks = [ + CSVLogging(log_filename), + EpochResultsLogging(), + EarlyStopping( + monitor='val_loss', + patience=2, + delta=0.1, + verbose=True + ) + ] + trainer.train( train_loader=train_loader, valid_loader=valid_loader, num_epochs=config['training']['num_epochs'], - log_path=log_filename, - plot_path=plot_filename + callbacks=callbacks, ) trainer.evaluate(data_loader=test_loader) diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py index c8513b3..c8d8ea4 100644 --- a/tests/test_checkpoints.py +++ b/tests/test_checkpoints.py @@ -1,19 +1,20 @@ import pytest import os +import yaml +import torch from trainers import get_trainer from utils.metrics import Accuracy, Precision, Recall, F1Score from datasets.transformations import get_transforms from datasets.dataset import get_dataset from models import get_model -import torch -import yaml +from callbacks import Checkpoint CONFIG_TEST = {} with open("./config/config_test.yaml", 'r') as file: CONFIG_TEST = yaml.safe_load(file) -def test_checkpoint_functionality(): +def test_checkpoint(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") transforms = get_transforms(CONFIG_TEST) @@ -34,44 +35,48 @@ def test_checkpoint_functionality(): model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device) criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.Adam - optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']} + optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG_TEST['training']['learning_rate']) metrics 
= [Accuracy(), Precision(), Recall(), F1Score()] trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device) checkpoint_dir = "./outputs/checkpoints/" os.makedirs(checkpoint_dir, exist_ok=True) - + checkpoint_callback = Checkpoint( + checkpoint_dir=checkpoint_dir, + model=model, + optimizer=optimizer, + save_freq=5, + verbose=False + ) + trainer.build( criterion=criterion, - optimizer_class=optimizer, - optimizer_params=optimizer_params, + optimizer_class=torch.optim.Adam, + optimizer_params={'lr': CONFIG_TEST['training']['learning_rate']}, metrics=metrics ) + + # Train the model and automatically save the checkpoint at the specified interval trainer.train( train_loader=train_loader, num_epochs=6, - checkpoint_dir=checkpoint_dir, - verbose=False + valid_loader=None, + callbacks=[checkpoint_callback] ) checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth') - assert os.path.exists(checkpoint_dir), "Checkpoint file was not created." + assert os.path.exists(checkpoint_path), "Checkpoint file was not created." + # Zero out the model parameters to simulate a restart for param in model.parameters(): param.data.zero_() + # Load the checkpoint trainer.load_checkpoint(checkpoint_path) - trainer.train( - train_loader=train_loader, - num_epochs=2, - checkpoint_dir=checkpoint_dir, - verbose=False - ) - + # Continue training or perform evaluation _, metrics_results = trainer.evaluate(test_loader, verbose=False) assert all([v >= 0 for v in metrics_results.values()]), "Metrics after resuming are not valid." -test_checkpoint_functionality() +test_checkpoint() diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py new file mode 100644 index 0000000..0069291 --- /dev/null +++ b/tests/test_early_stopping.py @@ -0,0 +1,63 @@ +import pytest +from trainers import get_trainer +from callbacks import EarlyStopping +from utils.metrics import Accuracy +from datasets.transformations import get_transforms +from datasets.dataset import get_dataset +from models import get_model +import torch +import yaml + +CONFIG_TEST = {} + +with open("./config/config_test.yaml", 'r') as file: + CONFIG_TEST = yaml.safe_load(file) + +def test_early_stopping(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + transforms = get_transforms(CONFIG_TEST) + data = get_dataset( + name=CONFIG_TEST['data']['name'], + root_dir=CONFIG_TEST['data']['dataset_path'], + train=True, + transform=transforms + ) + + # Use a NaiveTrainer to test the early stopping + CONFIG_TEST['trainer'] = 'NaiveTrainer' + + train_size = int(0.64 * len(data)) + test_size = len(data) - train_size + data_train, data_test = torch.utils.data.random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(42)) + + train_loader = torch.utils.data.DataLoader(data_train, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=True) + test_loader = torch.utils.data.DataLoader(data_test, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=False) + + model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device) + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam + optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']} + metrics = [Accuracy()] + + trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device) + + trainer.build( + criterion=criterion, + optimizer_class=optimizer, + optimizer_params=optimizer_params, + metrics=metrics + ) + + early_stopping_callback = 
EarlyStopping(patience=2, verbose=True, monitor='val_loss', delta=0.1) + trainer.train( + train_loader=train_loader, + num_epochs=3, # Intentionally, one more epoch than patience as early stopping should trigger + valid_loader=test_loader, + callbacks=[early_stopping_callback], + ) + + assert early_stopping_callback.early_stop, "Early stopping did not trigger as expected." + +test_early_stopping() diff --git a/tests/test_fine_tuning_pipeline.py b/tests/test_fine_tuning_pipeline.py index 3e41aae..a23e51f 100644 --- a/tests/test_fine_tuning_pipeline.py +++ b/tests/test_fine_tuning_pipeline.py @@ -79,11 +79,11 @@ def test_fine_tuning_loop(): freeze_until_layer=CONFIG_TEST['training'].get('freeze_until_layer'), metrics=metrics ) + trainer.train( train_loader=train_loader, valid_loader=valid_loader, num_epochs=CONFIG_TEST['training']['num_epochs'], - verbose=False ) trainer.unfreeze_all_layers() @@ -99,7 +99,6 @@ def test_fine_tuning_loop(): train_loader=train_loader, valid_loader=valid_loader, num_epochs=CONFIG_TEST['training']['num_epochs'], - verbose=False ) _, metrics_results = trainer.evaluate( diff --git a/tests/test_training_pipeline.py b/tests/test_training_pipeline.py index 0976acb..ea25796 100644 --- a/tests/test_training_pipeline.py +++ b/tests/test_training_pipeline.py @@ -80,7 +80,6 @@ def test_training_loop(): train_loader=train_loader, valid_loader=valid_loader, num_epochs=CONFIG_TEST['training']['num_epochs'], - verbose=False ) _, metrics_results = trainer.evaluate( data_loader=test_loader, diff --git a/trainers/__init__.py b/trainers/__init__.py index e5d9ec2..87299e4 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -1,4 +1,5 @@ from trainers.basic_trainer import BasicTrainer +from trainers.naive_trainer import NaiveTrainer def get_trainer(trainer_name, **kwargs): """ @@ -16,5 +17,7 @@ def get_trainer(trainer_name, **kwargs): """ if trainer_name == "BasicTrainer": return BasicTrainer(**kwargs) + if trainer_name == "NaiveTrainer": + return NaiveTrainer(**kwargs) else: raise ValueError(f"Trainer {trainer_name} not recognized.") diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index a7f3b0d..51b3e07 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,7 +1,4 @@ -import os import torch -from utils.plotting import plot_loss -from utils.logging import log_to_csv, log_epoch_results from abc import ABC, abstractmethod import time from typing import Tuple @@ -34,31 +31,8 @@ def __init__(self, model, device): self.optimizer = None self.scheduler = None self.metrics = [] - - def save_checkpoint(self, save_path, epoch, train_losses, val_losses, metric_values) -> None: - """ - Saves the current state of the training process. - Args: - save_path (str): Directory to save checkpoint files. - epoch (int): Current epoch number. - train_losses (list): List of training losses up to the current epoch. - val_losses (list): List of validation losses up to the current epoch. - metric_values (dict): Dictionary containing other metric values. 
- """ - state = { - 'epoch': epoch, - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'train_losses': train_losses, - 'val_losses': val_losses, - 'metric_values': metric_values - } - if self.scheduler: - state['scheduler_state_dict'] = self.scheduler.state_dict() - - torch.save(state, os.path.join(save_path, f'checkpoint_epoch_{epoch+1}.pth')) - def load_checkpoint(self, load_path): + def load_checkpoint(self, load_path) -> dict: """ Loads a checkpoint and resumes training or evaluation. Args: @@ -96,7 +70,7 @@ def unfreeze_all_layers(self) -> None: param.requires_grad = True @abstractmethod - def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: + def _train_epoch(self, train_loader, epoch, num_epochs) -> float: """ Trains the model for one epoch using the provided train_loader. @@ -104,7 +78,6 @@ def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: train_loader (DataLoader): The data loader for training data. epoch (int): The current epoch number. num_epochs (int): The total number of epochs. - verbose (bool, optional): Whether to print training progress. Defaults to True. Returns: float: The loss value for the epoch. @@ -112,7 +85,7 @@ def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: raise NotImplementedError( "The train_epoch method must be implemented by the subclass.") - def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot_path=None, checkpoint_dir=None, verbose=True) -> None: + def train(self, train_loader, num_epochs, valid_loader=None, callbacks=None) -> None: """ Train the model for a given number of epochs, calculating metrics at the end of each epoch for both training and validation sets. @@ -121,54 +94,61 @@ def train(self, train_loader, num_epochs, valid_loader=None, log_path=None, plot train_loader: The data loader for the training set. num_epochs (int): The number of epochs to train the model. valid_loader: The data loader for the validation set (optional). - log_path: The path to save the training log (optional). - plot_path: The path to save the training plot (optional). - checkpoint_dir: The directory to save model checkpoints (optional). - verbose (bool): Whether to print training progress (default: True). + callbacks: List of callback objects to use during training (optional). 
""" + logs = {} times = [] training_epoch_losses = [] validation_epoch_losses = [] - metric_values = {metric.name: {'train': [], 'valid': []} for metric in self.metrics} + if callbacks is None: + callbacks = [] + start_time = time.time() + for callback in callbacks: + callback.on_train_begin(logs=logs) + for epoch in range(num_epochs): epoch_start_time = time.time() - epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs, verbose) + + for callback in callbacks: + callback.on_epoch_begin(epoch, logs=logs) + + logs['epoch'] = epoch + epoch_loss_train = self._train_epoch(train_loader, epoch, num_epochs) training_epoch_losses.append(epoch_loss_train) _, epoch_metrics_train = self.evaluate(train_loader, self.metrics, verbose=False) + logs['train_loss'] = epoch_loss_train + logs['train_metrics'] = epoch_metrics_train if valid_loader is not None: epoch_loss_valid, epoch_metrics_valid = self.evaluate(valid_loader, self.metrics, verbose=False) validation_epoch_losses.append(epoch_loss_valid) + logs['val_loss'] = epoch_loss_valid + logs['val_metrics'] = epoch_metrics_valid else: - epoch_metrics_valid = {metric.name: None for metric in self.metrics} - - for metric_name in metric_values.keys(): - metric_values[metric_name]['train'].append(epoch_metrics_train[metric_name]) - metric_values[metric_name]['valid'].append(epoch_metrics_valid.get(metric_name)) - - if checkpoint_dir and (epoch + 1) % 5 == 0: - self.save_checkpoint(checkpoint_dir, epoch, training_epoch_losses, validation_epoch_losses, metric_values) - + logs['val_loss'] = None + logs['val_metrics'] = {} + + for callback in callbacks: + callback.on_epoch_end(epoch, logs=logs) + epoch_time = time.time() - epoch_start_time times.append(epoch_time) - if verbose: - log_epoch_results(epoch, num_epochs, epoch_loss_train, epoch_metrics_train, epoch_metrics_valid) + if not all(callback.should_continue(logs=logs) for callback in callbacks): + print(f"Training stopped early at epoch {epoch + 1}.") + break - if log_path is not None: - log_to_csv(training_epoch_losses, validation_epoch_losses, metric_values, times, log_path) + logs['times'] = times - elapsed_time = time.time() - start_time - - if verbose: - print(f"Training completed in: {elapsed_time:.2f} seconds") + for callback in callbacks: + callback.on_train_end(logs=logs) - if plot_path is not None: - plot_loss(training_epoch_losses, validation_epoch_losses, plot_path) + elapsed_time = time.time() - start_time + print(f"Training completed in: {elapsed_time:.2f} seconds") def predict(self, instance) -> torch.Tensor: """ diff --git a/trainers/naive_trainer.py b/trainers/naive_trainer.py new file mode 100644 index 0000000..0f9e574 --- /dev/null +++ b/trainers/naive_trainer.py @@ -0,0 +1,34 @@ +from trainers.base_trainer import BaseTrainer + + +class NaiveTrainer(BaseTrainer): + """ + A complete naive trainer class for training a model. It main purpose is to + be used as a boilerplate to test some functionalities of the trainer. + + Args: + model (nn.Module): The model to be trained. + device (torch.device): The device to be used for training. + + Attributes: + model (nn.Module): The model to be trained. + device (torch.device): The device to be used for training. + """ + + def __init__(self, model, device): + super().__init__(model, device) + + def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True) -> float: + """ + Simulates training the model for one epoch. It actually does nothing. + + Args: + train_loader (DataLoader): The data loader for training data. 
+ epoch (int): The current epoch number. + num_epochs (int): The total number of epochs. + verbose (bool, optional): Whether to display progress bar. Defaults to True. + + Returns: + float: A mock loss value for the epoch of 0.0. + """ + return 0.0 \ No newline at end of file diff --git a/utils/logging.py b/utils/logging.py deleted file mode 100644 index 84adb9b..0000000 --- a/utils/logging.py +++ /dev/null @@ -1,67 +0,0 @@ -import csv - - -import csv - -def log_to_csv(training_losses, validation_losses, metric_values, times, csv_path) -> None: - """ - Logs the training and validation losses, along with metric values, to a CSV file. - - Args: - training_losses (list): A list of training losses for each epoch. - validation_losses (list): A list of validation losses for each epoch. - metric_values (dict): A dictionary containing metric values for each epoch. - The keys are the names of the metrics, and the values are dictionaries - with 'train' and 'valid' keys, representing the metric values for training - and validation sets, respectively. - times (list): A list of times taken for each epoch. - csv_path (str): The path to the CSV file where the log will be saved. - """ - headers = ['epoch', 'train_loss', 'val_loss'] - metric_names = list(metric_values.keys()) - for name in metric_names: - headers.append(f'train_{name}') - headers.append(f'val_{name}') - headers.append('time') - - with open(csv_path, 'w', newline='') as csv_file: - writer = csv.writer(csv_file) - writer.writerow(headers) - - num_epochs = max(len(training_losses), len(validation_losses)) - - for epoch in range(num_epochs): - row = [epoch + 1] - row.append(training_losses[epoch] if epoch < len(training_losses) else 'N/A') - row.append(validation_losses[epoch] if epoch < len(validation_losses) else 'N/A') - - for name in metric_names: - train_metrics = metric_values[name]['train'] - valid_metrics = metric_values[name]['valid'] - - row.append(train_metrics[epoch] if epoch < len(train_metrics) else 'N/A') - row.append(valid_metrics[epoch] if epoch < len(valid_metrics) else 'N/A') - - row.append(times[epoch]) - - writer.writerow(row) - -def log_epoch_results(epoch, num_epochs, epoch_loss_train, epoch_metrics_train, epoch_metrics_valid=None) -> None: - """ - Logs the results of an epoch during training. - - Args: - epoch (int): The current epoch number. - num_epochs (int): The total number of epochs. - epoch_loss_train (float): The training loss for the epoch. - epoch_metrics_train (dict): A dictionary containing the training metrics for the epoch. - epoch_metrics_valid (dict, optional): A dictionary containing the validation metrics for the epoch. Defaults to None. - """ - print(f"Epoch {epoch+1}/{num_epochs}, Training loss: {epoch_loss_train:.4f}") - - for metric_name, value in epoch_metrics_train.items(): - print(f"Training {metric_name}: {value:.4f}") - - if epoch_metrics_valid: - for metric_name, value in epoch_metrics_valid.items(): - print(f"Validation {metric_name}: {value if value is not None else 'N/A'}")