From 4e26b5740d0cafef633e2e90f54ff246d215312b Mon Sep 17 00:00:00 2001
From: Pablo Olivares <65406121+pab1s@users.noreply.github.com>
Date: Mon, 15 Jul 2024 22:22:03 +0200
Subject: [PATCH] Topological trainer and config

closes #29
---
 config/config_efficientnet_topological.yaml |  80 +++++++++
 config/mm_config_densenet_topological.yaml  |  80 +++++++++
 experiments/train_topological.py            | 180 ++++++++++++++++++++
 scripts/train_topological.sh                |  31 ++++
 trainers/topological_trainer.py             |  90 ++++++++++
 5 files changed, 461 insertions(+)
 create mode 100644 config/config_efficientnet_topological.yaml
 create mode 100644 config/mm_config_densenet_topological.yaml
 create mode 100644 experiments/train_topological.py
 create mode 100644 scripts/train_topological.sh
 create mode 100644 trainers/topological_trainer.py

diff --git a/config/config_efficientnet_topological.yaml b/config/config_efficientnet_topological.yaml
new file mode 100644
index 0000000..3aa4504
--- /dev/null
+++ b/config/config_efficientnet_topological.yaml
@@ -0,0 +1,80 @@
+trainer: "TopologicalTrainer"
+random_seed: 43
+
+model:
+  type: "efficientnet_b0"
+  parameters:
+    num_classes: 152
+    pretrained: true
+
+training:
+  batch_size: 64
+  epochs:
+    initial: 30
+    fine_tuning: 30
+  loss_function:
+    type: "CrossEntropyLoss"
+    parameters: {}
+  optimizer:
+    type: "SGD"
+    parameters:
+      lr: 0.05
+  learning_rates:
+    initial: 0.05
+    fine_tuning: 0.05
+    final_fine_tuning: 0.05
+  freeze_until_layer: "classifier.1.0.weight"
+
+metrics:
+  - type: "Accuracy"
+  - type: "Precision"
+  - type: "Recall"
+  - type: "F1Score"
+
+callbacks:
+  CSVLogging:
+    parameters:
+      csv_path: "dynamically/set/by/date.csv"
+  Checkpoint:
+    parameters:
+      save_freq: 5
+  EarlyStopping:
+    parameters:
+      monitor: "val_loss"
+      patience: 5
+      delta: 0
+      verbose: true
+
+data:
+  name: "CarDataset"
+  dataset_path: "./data/processed/DB_Marca_Modelo"
+  test_size: 0.1
+  val_size: 0.1
+  transforms:
+    - type: "Resize"
+      parameters:
+        size: [224, 224]
+    #- type: "TrivialAugmentWide"
+    #  parameters: {}
+    - type: "ToTensor"
+      parameters: {}
+    - type: "Normalize"
+      parameters:
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+  eval_transforms:
+    - type: "Resize"
+      parameters:
+        size: [224, 224]
+    - type: "ToTensor"
+      parameters: {}
+    - type: "Normalize"
+      parameters:
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+
+paths:
+  model_path: "./outputs/models/"
+  log_path: "./logs/"
+  plot_path: "./outputs/figures/"
+  checkpoint_path: "./outputs/checkpoints/"
diff --git a/config/mm_config_densenet_topological.yaml b/config/mm_config_densenet_topological.yaml
new file mode 100644
index 0000000..50a0df4
--- /dev/null
+++ b/config/mm_config_densenet_topological.yaml
@@ -0,0 +1,80 @@
+trainer: "TopologicalTrainer"
+random_seed: 43
+
+model:
+  type: "densenet121"
+  parameters:
+    num_classes: 34
+    pretrained: true
+
+training:
+  batch_size: 32
+  epochs:
+    initial: 30
+    fine_tuning: 30
+  loss_function:
+    type: "CrossEntropyLoss"
+    parameters: {}
+  optimizer:
+    type: "SGD"
+    parameters:
+      lr: 0.01
+  learning_rates:
+    initial: 0.05
+    fine_tuning: 0.05
+    final_fine_tuning: 0.05
+  freeze_until_layer: "classifier.0.weight"
+
+metrics:
+  - type: "Accuracy"
+  - type: "Precision"
+  - type: "Recall"
+  - type: "F1Score"
+
+callbacks:
+  CSVLogging:
+    parameters:
+      csv_path: "dynamically/set/by/date.csv"
+  Checkpoint:
+    parameters:
+      save_freq: 5
+  EarlyStopping:
+    parameters:
+      monitor: "val_loss"
+      patience: 5
+      delta: 0
+      verbose: true
+
+data:
+  name: "CarDataset"
+  dataset_path: "./data/processed/DB_Marca"
+  test_size: 0.1
+  val_size: 0.1
+  transforms:
+    - type: "Resize"
+      parameters:
+        size: [224, 224]
+    #- type: "TrivialAugmentWide"
+    #  parameters: {}
+    - type: "ToTensor"
+      parameters: {}
+    - type: "Normalize"
+      parameters:
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+  eval_transforms:
+    - type: "Resize"
+      parameters:
+        size: [224, 224]
+    - type: "ToTensor"
+      parameters: {}
+    - type: "Normalize"
+      parameters:
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+
+paths:
+  model_path: "./outputs/models/"
+  log_path: "./logs/"
+  plot_path: "./outputs/figures/"
+  checkpoint_path: "./outputs/checkpoints/"
diff --git a/experiments/train_topological.py b/experiments/train_topological.py
new file mode 100644
index 0000000..302420b
--- /dev/null
+++ b/experiments/train_topological.py
@@ -0,0 +1,180 @@
+import torch
+import yaml
+import argparse
+from datetime import datetime
+from torch.utils.data import DataLoader, random_split
+from datasets.dataset import get_dataset
+from datasets.transformations import get_transforms
+from utils.metrics import Accuracy, Precision, Recall, F1Score
+from factories.model_factory import ModelFactory
+from factories.loss_factory import LossFactory
+from factories.optimizer_factory import OptimizerFactory
+from factories.callback_factory import CallbackFactory
+from trainers import get_trainer
+from os import path
+
+def main(config_path, model_path, alpha):
+    """
+    Train a model using the given configuration file.
+
+    Args:
+        config_path (str): Path to the configuration file.
+        model_path (str): Path to the trained model file (.pth).
+        alpha (float): Alpha value for the topological loss.
+    """
+
+    with open(config_path, 'r') as file:
+        config = yaml.safe_load(file)
+
+    # If CUDA is not available, finish execution
+    if not torch.cuda.is_available():
+        print("CUDA is not available. Exiting...")
+        exit()
+    device = torch.device("cuda")
+
+    # Load and transform data
+    transforms = get_transforms(config['data']['transforms'])
+    eval_transforms = get_transforms(config['data']['eval_transforms'])
+    data = get_dataset(config['data']['name'], config['data']['dataset_path'], train=True, transform=transforms)
+
+    # Split data
+    total_size = len(data)
+    test_size = int(total_size * config['data']['test_size'])
+    val_size = int(total_size * config['data']['val_size'])
+    train_size = total_size - test_size - val_size
+    assert train_size > 0 and val_size > 0 and test_size > 0, "One of the splits has zero or negative size."
+    data_train, data_test = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed']))
+    data_train, data_val = random_split(data_train, [train_size, val_size], generator=torch.Generator().manual_seed(config['random_seed']))
+
+    # Apply evaluation transforms to validation and test datasets
+    data_test.dataset.transform = eval_transforms
+    data_val.dataset.transform = eval_transforms
+
+    # Data loaders using the given batch_size
+    train_loader = DataLoader(data_train, batch_size=config['training']['batch_size'], shuffle=True)
+    valid_loader = DataLoader(data_val, batch_size=config['training']['batch_size'], shuffle=False)
+    test_loader = DataLoader(data_test, batch_size=config['training']['batch_size'], shuffle=False)
+
+    # Model setup
+    model_factory = ModelFactory()
+    # Initialize the backbone; the classifier head is rebuilt below, so this class count is only a placeholder
+    model = model_factory.create(config['model']['type'], num_classes=152, pretrained=True).to(device)
+
+    # Load the pretrained model weights
+    pretrained_dict = torch.load(model_path)
+    model_dict = model.state_dict()
+
+    # Remove pretrained classifier weights (since we are modifying the classifier)
+    pretrained_dict = {k: v for k, v in pretrained_dict.items() if 'classifier' not in k}
+    model_dict.update(pretrained_dict)
+    model.load_state_dict(model_dict, strict=False)
+
+    # Reinitialize the classifier for the new number of classes
+    num_ftrs = model.classifier[0].in_features
+
+    model.classifier = torch.nn.Sequential(
+        torch.nn.Dropout(p=0.2, inplace=True),
+        torch.nn.Sequential(
+            torch.nn.Linear(num_ftrs, 256),
+            torch.nn.ReLU(),
+            torch.nn.Dropout(p=0.4, inplace=False),
+            torch.nn.Linear(256, config['model']['parameters']['num_classes'])
+        )
+    ).to(device)
+
+    # Ensure the model has been updated correctly
+    print("Updated model structure: ", model)
+
+    # Loss setup
+    loss_factory = LossFactory()
+    criterion = loss_factory.create(config['training']['loss_function']['type'])
+
+    # Optimizer setup with given parameters
+    optimizer_factory = OptimizerFactory()
+    optimizer = optimizer_factory.create(config['training']['optimizer']['type'])
+    optimizer_params = config['training']['optimizer']['parameters']
+    print("Using optimizer: ", optimizer, " with params: ", optimizer_params)
+    print("Batch size: ", config['training']['batch_size'])
+
+    # Training stages setup
+    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    model_dataset_time = f"TOP_{config['model']['type']}_{config['data']['name']}_{config['training']['optimizer']['type']}_{config['training']['batch_size']}_{current_time}"
+    log_filename = path.join(config['paths']['log_path'], f"log_finetuning_{model_dataset_time}.csv")
+
+    # Callbacks setup
+    callbacks_config = config['callbacks']
+    if "CSVLogging" in callbacks_config:
+        callbacks_config["CSVLogging"]["parameters"]["csv_path"] = log_filename
+
+    # Metrics and trainer setup
+    metrics = [Accuracy(), Precision(), Recall(), F1Score()]
+    trainer = get_trainer(config['trainer'], model=model, device=device)
+
+    # Initial training stage
+    print("Starting initial training stage with frozen layers...")
+    trainer.build(
+        criterion=criterion,
+        optimizer_class=optimizer,
+        optimizer_params=optimizer_params,
+        metrics=metrics
+    )
+
+    callback_factory = CallbackFactory()
+    callbacks = []
+    for name, params in callbacks_config.items():
+        if name == "Checkpoint":
+            params["parameters"]["checkpoint_dir"] = path.join(config['paths']['checkpoint_path'], model_dataset_time)
+            params["parameters"]["model"] = model
+            params["parameters"]["optimizer"] = trainer.optimizer
+            params["parameters"]["scheduler"] = trainer.scheduler
+
+        callback = callback_factory.create(name, **params["parameters"])
+
+        if name == "EarlyStopping":
+            callback.set_model_and_optimizer(model, trainer.optimizer)
+
+        callbacks.append(callback)
+
+    kwargs = {'alpha': alpha}
+
+    trainer.train(
+        train_loader=train_loader,
+        valid_loader=valid_loader,
+        num_epochs=config['training']['epochs']['initial'],
+        callbacks=callbacks,
+        **kwargs
+    )
+
+    # Fine-tuning stage with all layers unfrozen
+    #print("Unfreezing all layers for fine-tuning...")
+    #trainer.unfreeze_all_layers()
+
+    #optimizer_instance = trainer.optimizer
+    #optimizer_factory.update(optimizer_instance, config['training']['learning_rates']['initial'])
+
+    #print("Starting full model fine-tuning...")
+    #trainer.train(
+    #    train_loader=train_loader,
+    #    valid_loader=valid_loader,
+    #    num_epochs=config['training']['epochs']['fine_tuning'],
+    #    callbacks=callbacks
+    #)
+
+    # Save model
+    model_path = path.join(config['paths']['model_path'], f"{model_dataset_time}.pth")
+    torch.save(model.state_dict(), model_path)
+
+    # Evaluate
+    trainer.evaluate(data_loader=test_loader)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Train a model using the given configuration file.')
+    parser.add_argument('config_filename', type=str, help='Filename of the configuration file within the "config" directory')
+    parser.add_argument('model_path', type=str, help='Path to the trained model file (.pth)')
+    parser.add_argument('alpha', type=float, help='Alpha value for the topological loss')
+
+    args = parser.parse_args()
+
+    config_path = f"config/{args.config_filename}"
+
+    main(config_path, args.model_path, args.alpha)
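
The alpha positional argument is forwarded through trainer.train(..., **kwargs) into TopologicalTrainer._train_epoch (defined further down in this patch), where each batch is optimized with cross-entropy plus alpha times the topological term. Run from the repository root with the package on PYTHONPATH (as the SLURM script that follows arranges), a direct invocation would look roughly like this; the checkpoint filename and the value 0.1 are placeholders, not values from the patch:

    python experiments/train_topological.py config_efficientnet_topological.yaml outputs/models/pretrained_baseline.pth 0.1

The freeze_until_layer keys in the configs ("classifier.1.0.weight" / "classifier.0.weight") have to match a parameter name of the model actually being built; the nested Sequential head constructed above is what produces keys such as classifier.1.0.weight. Below is a minimal sketch for listing the available keys, using torchvision's efficientnet_b0 only as a stand-in for whatever ModelFactory returns (an assumption, since the factory is not part of this patch):

    # Sketch only: list parameter keys so freeze_until_layer refers to a real parameter.
    # torchvision's efficientnet_b0 is an assumed stand-in; ModelFactory may build the head differently.
    from torchvision.models import efficientnet_b0

    model = efficientnet_b0(weights=None)
    for name, param in model.named_parameters():
        print(name, tuple(param.shape))
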
diff --git a/scripts/train_topological.sh b/scripts/train_topological.sh
new file mode 100644
index 0000000..55b65c6
--- /dev/null
+++ b/scripts/train_topological.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+#SBATCH --job-name=trainTopological   # Process name
+#SBATCH --partition=dios              # Queue for execution
+#SBATCH -w dionisio                   # Node to execute the job
+#SBATCH --gres=gpu:1                  # Number of GPUs to use
+#SBATCH --mail-type=END,FAIL          # Notifications for job done & fail
+#SBATCH --mail-user=user@mail.com     # Where to send notification
+
+# Load necessary paths
+export PATH="/opt/anaconda/anaconda3/bin:$PATH"
+export PATH="/opt/anaconda/bin:$PATH"
+export PYTHONPATH=$(dirname $(dirname "$0"))
+
+# Setup Conda environment
+eval "$(conda shell.bash hook)"
+conda activate tda-nn-analysis
+export TFHUB_CACHE_DIR=.
+
+# Check if the correct number of arguments is passed
+if [ "$#" -ne 3 ]; then
+    echo "Usage: $0 <config_file> <model_path> <alpha>"
+    exit 1
+fi
+
+config_file=$1
+model_path=$2
+alpha=$3
+
+# Call the Python script with the provided arguments
+python train_topological.py "$config_file" "$model_path" "$alpha"
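
For reference, a SLURM submission with this script would look roughly as follows; the argument values are placeholders, and the partition, node, and mail address in the #SBATCH headers above are site-specific:

    sbatch scripts/train_topological.sh config_efficientnet_topological.yaml outputs/models/pretrained_baseline.pth 0.1
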
diff --git a/trainers/topological_trainer.py b/trainers/topological_trainer.py
new file mode 100644
index 0000000..813fd01
--- /dev/null
+++ b/trainers/topological_trainer.py
@@ -0,0 +1,90 @@
+from trainers.base_trainer import BaseTrainer
+from tqdm import tqdm
+import torch
+from torch_topological.nn import VietorisRipsComplex
+
+class TopologicalTrainer(BaseTrainer):
+    """
+    A trainer class for training models with a topological regularization term.
+
+    Args:
+        model (nn.Module): The model to be trained.
+        device (torch.device): The device to be used for training.
+        model_type (str, optional): The type of the model. Defaults to None.
+    """
+
+    def __init__(self, model, device, model_type=None):
+        super().__init__(model, device)
+        self.features = None
+        self.model_type = model_type
+        self._register_feature_hook()
+
+    def _register_feature_hook(self):
+        """
+        Registers a forward hook to extract features from the model.
+        """
+        def hook(module, input, output):
+            self.features = output
+
+        if self.model_type == 'efficientnet':
+            self.model._avg_pooling.register_forward_hook(hook)
+        elif self.model_type == 'densenet':
+            self.model.features.norm5.register_forward_hook(hook)
+
+    def _topological_regularizer(self, features):
+        """
+        Computes the topological regularization loss based on the features.
+
+        Args:
+            features (torch.Tensor): The extracted features from the model.
+
+        Returns:
+            torch.Tensor: The topological regularization loss.
+        """
+        diagram_computator = VietorisRipsComplex(dim=1, keep_infinite_features=False, p=2)
+        pd = diagram_computator(features)
+        pd = torch.cat((pd[0].diagram, pd[1].diagram), 0)
+        L = torch.max(pd[:, 1] - pd[:, 0])
+        loss = torch.sum(pd[:, 1] - pd[:, 0]) / L
+        return loss
+
+    def _train_epoch(self, train_loader, epoch, num_epochs, verbose=True, **kwargs) -> float:
+        """
+        Trains the model for one epoch.
+
+        Args:
+            train_loader (DataLoader): The data loader for training data.
+            epoch (int): The current epoch number.
+            num_epochs (int): The total number of epochs.
+            verbose (bool, optional): Whether to display progress bar. Defaults to True.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            float: The average loss for the epoch.
+        """
+        self.model.train()
+        running_loss = 0.0
+        alpha = kwargs.get('alpha', 0.0)
+
+        if verbose:
+            progress_bar = tqdm(enumerate(train_loader, 1), total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}")
+        else:
+            progress_bar = enumerate(train_loader, 1)
+
+        for batch_idx, (images, labels) in progress_bar:
+            images, labels = images.to(self.device), labels.to(self.device)
+            self.optimizer.zero_grad()
+            outputs = self.model(images)
+            loss = self.criterion(outputs, labels)
+            topological_loss = self._topological_regularizer(outputs)
+            total_loss = loss + alpha * topological_loss
+            total_loss.backward()
+            self.optimizer.step()
+            running_loss += total_loss.item()
+
+            if verbose:
+                progress_bar.set_postfix({'loss': running_loss / batch_idx})
+
+        epoch_loss = running_loss / len(train_loader)
+
+        return epoch_loss
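
As a self-contained illustration of _topological_regularizer, the sketch below evaluates the same normalized total-persistence expression on a random point cloud. It mirrors the calls made in the trainer (VietorisRipsComplex with dim=1, concatenating the dimension-0 and dimension-1 diagrams, then dividing the summed lifetimes by the longest one); the batch size and feature dimension are arbitrary placeholders, and torch-topological is assumed to be installed. Note that in _train_epoch the term is applied to the classifier outputs, while the registered forward hook stores intermediate features in self.features.

    # Sketch of the normalized total-persistence term used by TopologicalTrainer.
    # Shapes are placeholders; torch-topological provides VietorisRipsComplex.
    import torch
    from torch_topological.nn import VietorisRipsComplex

    features = torch.randn(64, 152)  # e.g. one batch of logits (batch_size x num_classes)

    vr = VietorisRipsComplex(dim=1, keep_infinite_features=False, p=2)
    persistence_info = vr(features)  # one entry per homology dimension (0 and 1)
    diagram = torch.cat((persistence_info[0].diagram, persistence_info[1].diagram), 0)

    lifetimes = diagram[:, 1] - diagram[:, 0]   # death - birth of every pair
    loss = lifetimes.sum() / lifetimes.max()    # total persistence scaled by the longest bar
    print(float(loss))
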