
Commit

Created method find_lr advances #26
pab1s committed May 28, 2024
1 parent ecec316 commit f0600f0
Showing 3 changed files with 85 additions and 1 deletion.
3 changes: 2 additions & 1 deletion factories/callback_factory.py
@@ -1,11 +1,12 @@
from factories.factory import Factory
from callbacks import CSVLogging, EarlyStopping
from callbacks import CSVLogging, EarlyStopping, Checkpoint

class CallbackFactory(Factory):
    def __init__(self):
        super().__init__()
        self.register("CSVLogging", CSVLogging)
        self.register("EarlyStopping", EarlyStopping)
        self.register("Checkpoint", Checkpoint)

    def create(self, name, **kwargs):
        creator = self._creators.get(name)
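With Checkpoint registered alongside CSVLogging and EarlyStopping, it can be requested through the same factory interface. A minimal usage sketch follows; the constructor arguments are illustrative assumptions, since the Checkpoint class itself is not part of this diff:

from factories.callback_factory import CallbackFactory

factory = CallbackFactory()
# "checkpoint_dir" and "save_freq" are assumed argument names for illustration only;
# the real Checkpoint signature is defined in the callbacks package, not shown here.
checkpoint = factory.create("Checkpoint", checkpoint_dir="outputs/checkpoints", save_freq=1)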
4 changes: 4 additions & 0 deletions factories/optimizer_factory.py
@@ -12,3 +12,7 @@ def create(self, name, **kwargs):
        if not creator:
            raise ValueError(f"Unknown configuration: {name}")
        return creator

    def update(self, optimizer, new_lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
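The new update method rewrites the learning rate on every parameter group of an existing torch.optim optimizer. A minimal usage sketch under assumed inputs (the model and SGD hyperparameters are placeholders, not taken from the repository):

import torch
from factories.optimizer_factory import OptimizerFactory

model = torch.nn.Linear(10, 2)  # placeholder model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

factory = OptimizerFactory()
# Lower the learning rate in place instead of constructing a new optimizer.
factory.update(optimizer, new_lr=1e-3)
assert all(group["lr"] == 1e-3 for group in optimizer.param_groups)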
79 changes: 79 additions & 0 deletions find_learning_rates.py
@@ -0,0 +1,79 @@
import torch
import yaml
from datetime import datetime
from torch.utils.data import DataLoader, random_split
from datasets.dataset import get_dataset
from datasets.transformations import get_transforms
from utils.training import find_lr
from utils.plotting import plot_lr_vs_loss
from factories.model_factory import ModelFactory
from factories.loss_factory import LossFactory
from factories.optimizer_factory import OptimizerFactory
from os import path

def main(config_path, optimizer_type, optimizer_params, batch_size):
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load and transform data
    transforms = get_transforms(config['data']['transforms'])
    data = get_dataset(config['data']['name'], config['data']['dataset_path'], train=True, transform=transforms)

    # Split data
    total_size = len(data)
    test_size = int(total_size * config['data']['test_size'])
    val_size = int((total_size - test_size) * config['data']['val_size'])
    train_size = total_size - test_size - val_size

    data_train, _ = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed']))
    data_train, _ = random_split(data_train, [train_size, val_size], generator=torch.Generator().manual_seed(config['random_seed']))

    # Data loaders using the given batch_size
    train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)

    # Model setup
    model_factory = ModelFactory()
    model = model_factory.create(config['model']['type'], **config['model']['parameters']).to(device)
    print(model)

    # Loss setup
    loss_factory = LossFactory()
    criterion = loss_factory.create(config['training']['loss_function']['type'])

    # Optimizer setup with given parameters
    optimizer_factory = OptimizerFactory()
    optimizer = optimizer_factory.create(optimizer_type, params=model.parameters(), **optimizer_params)

    # Find learning rate
    print("Finding learning rate...")
    log_lrs, losses = find_lr(model, train_loader, criterion, optimizer, optimizer_params, device=device)

    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    plot_filename = f"lr_vs_loss_{config['model']['type']}_{current_time}_batch{batch_size}_{optimizer_type}.png"
    plot_lr_vs_loss(log_lrs, losses, path.join(config['paths']['plot_path'], plot_filename))

if __name__ == "__main__":
    batch_sizes = [8, 16, 32, 64]
    optimizer_types = ["SGD", "Adam"]
    adam_params = {
        "lr": 0.01,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "weight_decay": 0,
        "amsgrad": False
    }
    sgd_params = {
        "lr": 0.001,
        "momentum": 0.9,
        "weight_decay": 0,
        "nesterov": False
    }

    config_path = "config/fine_tuning_config.yaml"

    for optimizer_type in optimizer_types:
        for batch_size in batch_sizes:
            optimizer_params = adam_params if optimizer_type == "Adam" else sgd_params
            main(config_path, optimizer_type, optimizer_params, batch_size)

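The script depends on utils.training.find_lr, which is not included in this commit. Below is a minimal sketch of a typical exponential learning-rate range test that matches the call signature used above; init_value, final_value and beta are assumed defaults, and optimizer_params is accepted only to mirror the caller, not a confirmed part of the real implementation:

import math
import torch

def find_lr(model, train_loader, criterion, optimizer, optimizer_params,
            device="cpu", init_value=1e-8, final_value=10.0, beta=0.98):
    """Sweep the learning rate exponentially from init_value to final_value
    over one pass of train_loader and record the smoothed loss at each step."""
    num_steps = max(len(train_loader) - 1, 1)
    mult = (final_value / init_value) ** (1.0 / num_steps)
    lr = init_value
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    avg_loss, best_loss = 0.0, float('inf')
    log_lrs, losses = [], []

    model.train()
    for batch_num, (inputs, targets) in enumerate(train_loader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Exponentially smoothed loss gives a less noisy curve.
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta ** batch_num)

        # Stop early once the loss clearly diverges.
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            return log_lrs, losses
        best_loss = min(best_loss, smoothed_loss)

        log_lrs.append(math.log10(lr))
        losses.append(smoothed_loss)

        loss.backward()
        optimizer.step()

        # Increase the learning rate for the next batch.
        lr *= mult
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    return log_lrs, losses

plot_lr_vs_loss can then plot losses against log_lrs; a common heuristic is to pick a learning rate roughly an order of magnitude below the point where the smoothed loss reaches its minimum.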