
Commit

Implemented find_lr
pab1s committed May 28, 2024
1 parent 7416cdb commit 5534eb0
Showing 2 changed files with 75 additions and 0 deletions.
24 changes: 24 additions & 0 deletions utils/plotting.py
@@ -30,3 +30,27 @@ def plot_loss(training_epoch_losses, validation_epoch_losses, plot_path):
    plt.grid(True)
    plt.savefig(plot_path)
    plt.close()

def plot_lr_vs_loss(log_lrs, losses, plot_path):
    """
    Plots the learning rate range test curve: loss against log learning rate.

    Args:
        log_lrs (list): List of base-10 log learning rates.
        losses (list): List of (smoothed) losses recorded at those rates.
        plot_path (str): Path to save the plot.

    Returns:
        None
    """
    plt.figure(figsize=(10, 5))

    plt.plot(log_lrs, losses, label="Learning Rate vs Loss", marker='o')

    plt.title("Learning Rate vs Loss")
    plt.xlabel("Log Learning Rate")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(plot_path)
    plt.close()
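
A minimal usage sketch for the new plotting helper; the synthetic values and the output file name lr_range_test.png are illustrative assumptions, not part of this commit.

import math
from utils.plotting import plot_lr_vs_loss

# Hypothetical data: the loss falls as the learning rate grows, then diverges.
log_lrs = [-8 + 0.5 * i for i in range(20)]                               # log10 learning rates
losses = [2.3 - 0.08 * i for i in range(15)] + [1.2 + 0.6 * i for i in range(5)]
plot_lr_vs_loss(log_lrs, losses, "lr_range_test.png")
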
51 changes: 51 additions & 0 deletions utils/training.py
@@ -0,0 +1,51 @@
import math
import torch

def find_lr(model, train_loader, criterion, optimizer_class, optimizer_params,
            init_value=1e-8, final_value=10, beta=0.98, device=None):
    """
    Runs a learning rate range test: the learning rate grows exponentially from
    init_value to final_value over one pass through train_loader, and an
    exponentially smoothed loss is recorded for each batch.

    Returns:
        (log_lrs, losses): base-10 log learning rates and the smoothed losses.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    num = len(train_loader) - 1
    if num <= 0:
        raise ValueError("The training loader must contain more than one batch to compute the learning rate range test.")

    lr = init_value
    optimizer_params['lr'] = lr
    optimizer = optimizer_class(model.parameters(), **optimizer_params)

    # Multiplicative step so the learning rate reaches final_value after num batches.
    mult = (final_value / init_value) ** (1 / num)
    avg_loss = 0.
    best_loss = float('inf')
    batch_num = 0
    losses = []
    log_lrs = []

    for data in train_loader:
        batch_num += 1
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Exponential moving average of the loss, with bias correction.
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smoothed_loss = avg_loss / (1 - beta**batch_num)

        # Stop early once the loss clearly diverges.
        if batch_num > 1 and smoothed_loss > 4 * best_loss:
            break

        if smoothed_loss < best_loss or batch_num == 1:
            best_loss = smoothed_loss

        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        loss.backward()
        optimizer.step()

        # Increase the learning rate for the next batch.
        lr *= mult
        optimizer.param_groups[0]['lr'] = lr

    return log_lrs, losses
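
For context, a sketch of how find_lr and plot_lr_vs_loss might be used together; the toy model, random dataset, and output file name below are hypothetical placeholders, not part of this commit. Since the range test itself steps the optimizer, the model would normally be re-initialized before real training.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from utils.training import find_lr
from utils.plotting import plot_lr_vs_loss

# Hypothetical toy setup: a small classifier on random image-like data.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
dataset = TensorDataset(torch.randn(512, 1, 28, 28), torch.randint(0, 10, (512,)))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

log_lrs, losses = find_lr(
    model, train_loader, nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.SGD,
    optimizer_params={"momentum": 0.9},
)
plot_lr_vs_loss(log_lrs, losses, "lr_vs_loss.png")
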
