diff --git a/find_learning_rates.py b/find_learning_rates.py index eaa7df6..ba00f20 100644 --- a/find_learning_rates.py +++ b/find_learning_rates.py @@ -1,5 +1,6 @@ import torch import yaml +import argparse from datetime import datetime from torch.utils.data import DataLoader, random_split from datasets.dataset import get_dataset @@ -24,7 +25,7 @@ def main(config_path, optimizer_type, optimizer_params, batch_size): # Split data total_size = len(data) test_size = int(total_size * config['data']['test_size']) - val_size = int((total_size - test_size) * config['data']['val_size']) + val_size = int(total_size * config['data']['val_size']) train_size = total_size - test_size - val_size data_train, _ = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed'])) @@ -36,7 +37,6 @@ def main(config_path, optimizer_type, optimizer_params, batch_size): # Model setup model_factory = ModelFactory() model = model_factory.create(config['model']['type'], **config['model']['parameters']).to(device) - print(model) # Loss setup loss_factory = LossFactory() @@ -55,6 +55,11 @@ def main(config_path, optimizer_type, optimizer_params, batch_size): plot_lr_vs_loss(log_lrs, losses, path.join(config['paths']['plot_path'], plot_filename)) if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Process configuration file, optimizer types, and batch sizes.') + parser.add_argument('config_filename', type=str, help='Filename of the configuration file within the "config" directory') + + args = parser.parse_args() + batch_sizes = [8, 16, 32, 64] optimizer_types = ["SGD", "Adam"] adam_params = { @@ -71,7 +76,8 @@ def main(config_path, optimizer_type, optimizer_params, batch_size): "nesterov": False } - config_path = "config/fine_tuning_config.yaml" + # Build the path to the configuration file within the 'config' directory + config_path = f"config/{args.config_filename}" for optimizer_type in optimizer_types: for batch_size in batch_sizes: diff --git a/utils/training.py b/utils/training.py index 3126f99..5c9d5c6 100644 --- a/utils/training.py +++ b/utils/training.py @@ -1,7 +1,7 @@ import math import torch -def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params, init_value=1e-8, final_value=10, beta=0.98, device=None): +def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params, init_value=1e-8, final_value=1e-1, beta=0.98, device=None): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device)