From 3e1d20b23422861eb679ace54142d9d1d921c379 Mon Sep 17 00:00:00 2001 From: Pablo Olivares <65406121+pab1s@users.noreply.github.com> Date: Mon, 15 Jul 2024 22:18:57 +0200 Subject: [PATCH] Train data augmentation scripts and config files closes #26 --- config/m_color_jitter.yaml | 81 +++++++++++++++ config/m_gaussian_blur.yaml | 79 +++++++++++++++ config/m_random_horizontal_flip.yaml | 78 ++++++++++++++ config/m_random_resized_crop.yaml | 80 +++++++++++++++ config/m_random_rotation.yaml | 78 ++++++++++++++ config/mm_color_jitter.yaml | 81 +++++++++++++++ config/mm_gaussian_blur.yaml | 79 +++++++++++++++ config/mm_random_horizontal_flip.yaml | 78 ++++++++++++++ config/mm_random_resized_crop.yaml | 80 +++++++++++++++ config/mm_random_rotation.yaml | 78 ++++++++++++++ scripts/train_transforms.sh | 29 ++++++ train_transforms.py | 141 -------------------------- 12 files changed, 821 insertions(+), 141 deletions(-) create mode 100644 config/m_color_jitter.yaml create mode 100644 config/m_gaussian_blur.yaml create mode 100644 config/m_random_horizontal_flip.yaml create mode 100644 config/m_random_resized_crop.yaml create mode 100644 config/m_random_rotation.yaml create mode 100644 config/mm_color_jitter.yaml create mode 100644 config/mm_gaussian_blur.yaml create mode 100644 config/mm_random_horizontal_flip.yaml create mode 100644 config/mm_random_resized_crop.yaml create mode 100644 config/mm_random_rotation.yaml create mode 100644 scripts/train_transforms.sh delete mode 100644 train_transforms.py diff --git a/config/m_color_jitter.yaml b/config/m_color_jitter.yaml new file mode 100644 index 0000000..b3ef70d --- /dev/null +++ b/config/m_color_jitter.yaml @@ -0,0 +1,81 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "efficientnet_b0" + parameters: + num_classes: 34 + pretrained: true + +training: + batch_size: 64 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.05 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ColorJitter" + parameters: + brightness: 0.3 + contrast: 0.3 + saturation: 0.3 + hue: 0.1 + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/m_gaussian_blur.yaml b/config/m_gaussian_blur.yaml new file mode 100644 index 0000000..a33fbe7 --- /dev/null +++ b/config/m_gaussian_blur.yaml @@ -0,0 +1,79 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "efficientnet_b0" + parameters: + num_classes: 34 + pretrained: true + +training: + batch_size: 64 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.05 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "GaussianBlur" + parameters: + kernel_size: [5, 5] + sigma: [0.1, 2.0] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/m_random_horizontal_flip.yaml b/config/m_random_horizontal_flip.yaml new file mode 100644 index 0000000..3b25573 --- /dev/null +++ b/config/m_random_horizontal_flip.yaml @@ -0,0 +1,78 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "efficientnet_b0" + parameters: + num_classes: 34 + pretrained: true + +training: + batch_size: 64 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.05 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomHorizontalFlip" + parameters: + p: 0.5 + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/m_random_resized_crop.yaml b/config/m_random_resized_crop.yaml new file mode 100644 index 0000000..d55369a --- /dev/null +++ b/config/m_random_resized_crop.yaml @@ -0,0 +1,80 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "efficientnet_b0" + parameters: + num_classes: 34 + pretrained: true + +training: + batch_size: 64 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.05 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomResizedCrop" + parameters: + size: [224, 224] + scale: [0.8, 1.0] + ratio: [0.75, 1.33] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [240, 240] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/m_random_rotation.yaml b/config/m_random_rotation.yaml new file mode 100644 index 0000000..bcc2ac8 --- /dev/null +++ b/config/m_random_rotation.yaml @@ -0,0 +1,78 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "efficientnet_b0" + parameters: + num_classes: 34 + pretrained: true + +training: + batch_size: 64 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.05 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomRotation" + parameters: + degrees: [-10, 10] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/mm_color_jitter.yaml b/config/mm_color_jitter.yaml new file mode 100644 index 0000000..156d591 --- /dev/null +++ b/config/mm_color_jitter.yaml @@ -0,0 +1,81 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "densenet121" + parameters: + num_classes: 152 + pretrained: true + +training: + batch_size: 32 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.025 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ColorJitter" + parameters: + brightness: 0.3 + contrast: 0.3 + saturation: 0.3 + hue: 0.1 + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/mm_gaussian_blur.yaml b/config/mm_gaussian_blur.yaml new file mode 100644 index 0000000..382520f --- /dev/null +++ b/config/mm_gaussian_blur.yaml @@ -0,0 +1,79 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "densenet121" + parameters: + num_classes: 152 + pretrained: true + +training: + batch_size: 32 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.025 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "GaussianBlur" + parameters: + kernel_size: [5, 5] + sigma: [0.1, 2.0] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/mm_random_horizontal_flip.yaml b/config/mm_random_horizontal_flip.yaml new file mode 100644 index 0000000..2fa0196 --- /dev/null +++ b/config/mm_random_horizontal_flip.yaml @@ -0,0 +1,78 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "densenet121" + parameters: + num_classes: 152 + pretrained: true + +training: + batch_size: 32 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.025 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomHorizontalFlip" + parameters: + p: 0.5 + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/mm_random_resized_crop.yaml b/config/mm_random_resized_crop.yaml new file mode 100644 index 0000000..85a250c --- /dev/null +++ b/config/mm_random_resized_crop.yaml @@ -0,0 +1,80 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "densenet121" + parameters: + num_classes: 152 + pretrained: true + +training: + batch_size: 32 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.025 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomResizedCrop" + parameters: + size: [224, 224] + scale: [0.8, 1.0] + ratio: [0.75, 1.33] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [240, 240] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/config/mm_random_rotation.yaml b/config/mm_random_rotation.yaml new file mode 100644 index 0000000..9acc2ff --- /dev/null +++ b/config/mm_random_rotation.yaml @@ -0,0 +1,78 @@ +trainer: "BasicTrainer" +random_seed: 43 + +model: + type: "densenet121" + parameters: + num_classes: 152 + pretrained: true + +training: + batch_size: 32 + epochs: + initial: 200 + fine_tuning: 200 + loss_function: + type: "CrossEntropyLoss" + optimizer: + type: "SGD" + parameters: + lr: 0.025 + learning_rates: + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 + freeze_until_layer: "classifier.1.weight" + +metrics: + - type: "Accuracy" + - type: "Precision" + - type: "Recall" + - type: "F1Score" + +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 5 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 5 + delta: 0 + verbose: true + +data: + name: "CarDataset" + dataset_path: "./data/processed/DB_Marca" + test_size: 0.1 + val_size: 0.1 + transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "RandomRotation" + parameters: + degrees: [-10, 10] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [224, 224] + - type: "ToTensor" + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +paths: + model_path: "./outputs/models/" + log_path: "./logs/" + plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/" diff --git a/scripts/train_transforms.sh b/scripts/train_transforms.sh new file mode 100644 index 0000000..b115f55 --- /dev/null +++ b/scripts/train_transforms.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +#SBATCH --job-name=trainTransforms # Process name +#SBATCH --partition=dios # Queue for execution +#SBATCH -w dionisio # Node to execute the job +#SBATCH --gres=gpu:1 # Number of GPUs to use +#SBATCH --mail-type=END,FAIL # Notifications for job done & fail +#SBATCH --mail-user=user@mail.com # Where to send notification + +# Load necessary paths +export PATH="/opt/anaconda/anaconda3/bin:$PATH" +export PATH="/opt/anaconda/bin:$PATH" +export PYTHONPATH=$(dirname $(dirname "$0")) + +# Setup Conda environment +eval "$(conda shell.bash hook)" +conda activate tda-nn-analysis +export TFHUB_CACHE_DIR=. + +# Check if correct number of arguments is passed +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + exit 1 +fi + +config_file=$1 + +# Call the Python script with the provided arguments +python train_transforms.py $config_file diff --git a/train_transforms.py b/train_transforms.py deleted file mode 100644 index e593a0f..0000000 --- a/train_transforms.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import yaml -import argparse -from datetime import datetime -from torch.utils.data import DataLoader, random_split, WeightedRandomSampler -from datasets.dataset import get_dataset -from datasets.transformations import get_transforms -from utils.metrics import Accuracy, Precision, Recall, F1Score -from factories.model_factory import ModelFactory -from factories.loss_factory import LossFactory -from factories.optimizer_factory import OptimizerFactory -from factories.callback_factory import CallbackFactory -from trainers import get_trainer -from os import path - -def main(config_path, transform_type): - with open(config_path, 'r') as file: - config = yaml.safe_load(file) - - # If CUDA not available, finish execution - if not torch.cuda.is_available(): - print("CUDA is not available. Exiting...") - exit() - device = torch.device("cuda") - - # Load and transform data - transforms = get_transforms(transform_type) - eval_transforms = get_transforms(config['data']['eval_transforms']) - data = get_dataset(config['data']['name'], config['data']['dataset_path'], train=True, transform=transforms) - - # Split data - total_size = len(data) - test_size = int(total_size * config['data']['test_size']) - val_size = int(total_size * config['data']['val_size']) - train_size = total_size - test_size - val_size - - data_train, data_test = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed'])) - data_train, data_val = random_split(data_train, [train_size, val_size], generator=torch.Generator().manual_seed(config['random_seed'])) - - # Apply evaluation transforms to validation and test datasets - data_test.dataset.transform = eval_transforms - data_val.dataset.transform = eval_transforms - - # Data loaders using the given batch_size - train_loader = DataLoader(data_train, batch_size=config['training']['batch_size'], shuffle=True) - valid_loader = DataLoader(data_val, batch_size=config['training']['batch_size'], shuffle=False) - test_loader = DataLoader(data_test, batch_size=config['training']['batch_size'], shuffle=False) - - # Model setup - model_factory = ModelFactory() - model = model_factory.create(config['model']['type'], num_classes=config['model']['parameters']['num_classes'], pretrained=config['model']['parameters']['pretrained']).to(device) - print(model) - - # Loss setup - loss_factory = LossFactory() - criterion = loss_factory.create(config['training']['loss_function']['type']) - - # Optimizer setup with given parameters - optimizer_factory = OptimizerFactory() - optimizer = optimizer_factory.create(config['training']['optimizer']['type']) - print("Using optimizer: ", optimizer, " with params: ", config['training']['optimizer']['parameters']) - - # Training stages setup - current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - model_dataset_time = f"{config['model']['type']}_{config['data']['name']}_{transform_type}_{current_time}" - log_filename = path.join(config['paths']['log_path'], f"log_finetuning_{model_dataset_time}.csv") - - # Callbacks setup - callbacks_config = config['callbacks'] - if "CSVLogging" in callbacks_config: - callbacks_config["CSVLogging"]["parameters"]["csv_path"] = log_filename - - # Metrics and trainer setup - metrics = [Accuracy(), Precision(), Recall(), F1Score()] - trainer = get_trainer(config['trainer'], model=model, device=device) - - # Initial training stage - print("Starting initial training stage with frozen layers...") - trainer.build( - criterion=criterion, - optimizer_class=optimizer, - optimizer_params=config['training']['optimizer']['parameters'], - # freeze_until_layer=config['training']['freeze_until_layer'], - metrics=metrics - ) - - callback_factory = CallbackFactory() - callbacks = [] - for name, params in callbacks_config.items(): - if name == "Checkpoint": - params["parameters"]["checkpoint_dir"] = path.join(config['paths']['checkpoint_path'], model_dataset_time) - params["parameters"]["model"] = model - params["parameters"]["optimizer"] = trainer.optimizer - params["parameters"]["scheduler"] = trainer.scheduler - - callback = callback_factory.create(name, **params["parameters"]) - - if name == "EarlyStopping": - callback.set_model_and_optimizer(model, trainer.optimizer) - - callbacks.append(callback) - - #trainer.train( - # train_loader=train_loader, - # valid_loader=valid_loader, - # num_epochs=config['training']['epochs']['initial'], - # callbacks=callbacks - #) - - # Fine-tuning stage with all layers unfrozen - print("Unfreezing all layers for fine-tuning...") - trainer.unfreeze_all_layers() - - # optimizer_instance = trainer.optimizer - # optimizer_factory.update(optimizer_instance, config['training']['learning_rates']['initial']) - - print("Starting full model fine-tuning...") - trainer.train( - train_loader=train_loader, - valid_loader=valid_loader, - num_epochs=config['training']['epochs']['fine_tuning'], - callbacks=callbacks - ) - - # Save model - model_path = path.join(config['paths']['model_path'], f"{model_dataset_time}.pth") - torch.save(model.state_dict(), model_path) - - # Evaluate - trainer.evaluate(data_loader=test_loader) - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, help="Path to the configuration file") - parser.add_argument("--transform_type", type=str, help="Type of transformation to apply to the dataset") - args = parser.parse_args() - config_path = args.config_path - transform_type = args.transform_type - - config_path = "config/fine_tuning_config.yaml" - main(config_path)