Updated find_lr advances #26
pab1s committed Jun 24, 2024
1 parent 129f81e commit d4c8ba2
Showing 2 changed files with 10 additions and 4 deletions.
find_learning_rates.py: 12 changes (9 additions, 3 deletions)

@@ -1,5 +1,6 @@
 import torch
 import yaml
+import argparse
 from datetime import datetime
 from torch.utils.data import DataLoader, random_split
 from datasets.dataset import get_dataset
@@ -24,7 +25,7 @@ def main(config_path, optimizer_type, optimizer_params, batch_size):
     # Split data
     total_size = len(data)
     test_size = int(total_size * config['data']['test_size'])
-    val_size = int((total_size - test_size) * config['data']['val_size'])
+    val_size = int(total_size * config['data']['val_size'])
     train_size = total_size - test_size - val_size

     data_train, _ = random_split(data, [train_size + val_size, test_size], generator=torch.Generator().manual_seed(config['random_seed']))
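For illustration, with hypothetical config values (not taken from this repository's config file), the old and new formulas give different validation sizes:

    total_size = 1000                               # hypothetical dataset size
    test_size = int(total_size * 0.20)              # 200
    val_old = int((total_size - test_size) * 0.10)  # 80  (old formula)
    val_new = int(total_size * 0.10)                # 100 (new formula)

The validation fraction is now taken from the full dataset rather than from the post-test remainder.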
@@ -36,7 +37,6 @@ def main(config_path, optimizer_type, optimizer_params, batch_size):
     # Model setup
     model_factory = ModelFactory()
     model = model_factory.create(config['model']['type'], **config['model']['parameters']).to(device)
-    print(model)

     # Loss setup
     loss_factory = LossFactory()
@@ -55,6 +55,11 @@ def main(config_path, optimizer_type, optimizer_params, batch_size):
     plot_lr_vs_loss(log_lrs, losses, path.join(config['paths']['plot_path'], plot_filename))

 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Process configuration file, optimizer types, and batch sizes.')
+    parser.add_argument('config_filename', type=str, help='Filename of the configuration file within the "config" directory')
+
+    args = parser.parse_args()
+
     batch_sizes = [8, 16, 32, 64]
     optimizer_types = ["SGD", "Adam"]
     adam_params = {
@@ -71,7 +76,8 @@ def main(config_path, optimizer_type, optimizer_params, batch_size):
         "nesterov": False
     }

-    config_path = "config/fine_tuning_config.yaml"
+    # Build the path to the configuration file within the 'config' directory
+    config_path = f"config/{args.config_filename}"

     for optimizer_type in optimizer_types:
         for batch_size in batch_sizes:
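With the new argparse entry point, the script takes the config filename as a positional argument instead of hard-coding it. Using the previous hard-coded default as an example:

    python find_learning_rates.py fine_tuning_config.yaml

The script resolves this to config/fine_tuning_config.yaml before running the learning-rate sweeps over every optimizer type and batch size.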
utils/training.py: 2 changes (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 import math
 import torch

-def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params, init_value=1e-8, final_value=10, beta=0.98, device=None):
+def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params, init_value=1e-8, final_value=1e-1, beta=0.98, device=None):
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
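The new default caps the sweep at 1e-1 instead of 10, so the test ends before the learning rate reaches values that almost always blow up the loss. The body of find_lr is not shown in this diff; below is a minimal sketch of the standard exponential learning-rate range test (the Leslie Smith / fastai recipe) that a function with this signature typically implements. Everything beyond the signature is an assumption, not the repository's actual code.

    import math
    import torch

    def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params,
                init_value=1e-8, final_value=1e-1, beta=0.98, device=None):
        """Sweep the LR exponentially from init_value to final_value over one
        pass, recording a smoothed loss per batch (sketch, not the repo code)."""
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        optimizer = optimizer_class(model.parameters(), lr=init_value, **optimizer_params)
        num_batches = len(train_loader) - 1
        mult = (final_value / init_value) ** (1 / num_batches)  # per-batch LR multiplier
        lr, avg_loss, best_loss = init_value, 0.0, float("inf")
        log_lrs, losses = [], []
        for batch_num, (inputs, targets) in enumerate(train_loader, start=1):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            # Bias-corrected exponential moving average of the loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed = avg_loss / (1 - beta ** batch_num)
            if batch_num > 1 and smoothed > 4 * best_loss:
                break  # loss is diverging; stop the sweep early
            best_loss = min(best_loss, smoothed)
            log_lrs.append(math.log10(lr))
            losses.append(smoothed)
            loss.backward()
            optimizer.step()
            lr *= mult
            for group in optimizer.param_groups:
                group["lr"] = lr  # raise the LR for the next batch
        return log_lrs, losses

With init_value=1e-8 and final_value=1e-1 the sweep still covers seven decades, which is consistent with the log_lrs/plot_lr_vs_loss usage in find_learning_rates.py above.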