From 3dbea12ca7fcff0916e203bac0891a4711dc4947 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Tue, 28 May 2024 21:04:28 +0200 Subject: [PATCH] Update fine_tuning_config.yaml --- config/fine_tuning_config.yaml | 53 +++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/config/fine_tuning_config.yaml b/config/fine_tuning_config.yaml index 5808821..90b43bd 100644 --- a/config/fine_tuning_config.yaml +++ b/config/fine_tuning_config.yaml @@ -2,20 +2,27 @@ trainer: "BasicTrainer" random_seed: 43 model: - type: "efficientnet_b1" + type: "efficientnet_b0" parameters: - num_classes: 394 + num_classes: 34 pretrained: true training: batch_size: 32 epochs: - initial: 1 - fine_tuning: 1 + initial: 100 + fine_tuning: 10 + loss_function: + type: "CrossEntropyLoss" + parameters: {} + optimizer: + type: "Adam" + parameters: + learning_rate: 0.005 learning_rates: - initial: 0.001 - fine_tuning: 0.0001 - final_fine_tuning: 0.00001 + initial: 0.01 + fine_tuning: 0.001 + final_fine_tuning: 0.0001 freeze_until_layer: "classifier.1.weight" metrics: @@ -24,20 +31,50 @@ metrics: - type: "Recall" - type: "F1Score" +callbacks: + CSVLogging: + parameters: + csv_path: "dinamically/set/by/date.csv" + Checkpoint: + parameters: + save_freq: 1 + EarlyStopping: + parameters: + monitor: "val_loss" + patience: 10 + delta: 0.01 + verbose: true + data: name: "CarDataset" - dataset_path: "./data/processed/BD-461" + dataset_path: "./data/processed/DB_Marca" test_size: 0.2 val_size: 0.1 transforms: - type: "Resize" parameters: size: [240, 240] + - type: "AutoAugment" + parameters: {} - type: "ToTensor" parameters: {} - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_transforms: + - type: "Resize" + parameters: + size: [240, 240] + - type: "ToTensor" parameters: {} + - type: "Normalize" + parameters: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] paths: + model_path: "./outputs/models/" log_path: "./logs/" plot_path: "./outputs/figures/" + checkpoint_path: "./outputs/checkpoints/"