FIX - Save model snapshot advances #17
pab1s committed May 28, 2024
1 parent 0216d59 commit 798ba24
Showing 3 changed files with 58 additions and 12 deletions.
33 changes: 33 additions & 0 deletions callbacks/early_stopping.py
@@ -1,3 +1,5 @@
import torch
from copy import deepcopy
from callbacks.callback import Callback

class EarlyStopping(Callback):
@@ -10,15 +12,27 @@ def __init__(self, monitor='val_loss', patience=5, verbose=False, delta=0):
self.epochs_no_improve = 0
self.stopped_epoch = 0
self.early_stop = False
self.best_model_state = None
self.best_optimizer_state = None
self.model = None
self.optimizer = None

def set_model_and_optimizer(self, model, optimizer):
self.model = model
self.optimizer = optimizer

def on_epoch_end(self, epoch, logs=None):
if self.model is None or self.optimizer is None:
raise ValueError("Model and optimizer must be set before calling on_epoch_end.")

current = logs.get(self.monitor)
if current is None:
return

score = -current if 'loss' in self.monitor else current
if self.best_score is None:
self.best_score = score
self.save_checkpoint()
elif score < self.best_score + self.delta:
self.epochs_no_improve += 1
if self.epochs_no_improve >= self.patience:
@@ -29,6 +43,25 @@ def on_epoch_end(self, epoch, logs=None):
else:
self.best_score = score
self.epochs_no_improve = 0
self.save_checkpoint()

def save_checkpoint(self):
self.best_model_state = deepcopy(self.model.state_dict())
self.best_optimizer_state = deepcopy(self.optimizer.state_dict())

def should_continue(self, logs=None):
return not self.early_stop

def load_best_checkpoint(self):
if self.best_model_state is None or self.best_optimizer_state is None:
raise ValueError("No best checkpoint available to load.")

self.model.load_state_dict(self.best_model_state)
self.optimizer.load_state_dict(self.best_optimizer_state)
if self.verbose:
print(f"Loaded best checkpoint")

def on_train_end(self, logs=None):
self.load_best_checkpoint()
if self.verbose:
print(f"Training ended. Best model checkpoint has been loaded.")
32 changes: 21 additions & 11 deletions tests/test_early_stopping.py
@@ -8,14 +8,17 @@
import torch
import yaml

CONFIG_TEST = {}

# Load test configuration
with open("./config/config_test.yaml", 'r') as file:
CONFIG_TEST = yaml.safe_load(file)

def test_early_stopping():
@pytest.mark.parametrize("patience,delta,num_epochs", [
(2, 0.1, 5), # Patience of 2, delta of 0.1, and training for 5 epochs
])
def test_early_stopping(patience, delta, num_epochs):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data transformations and loading
transforms = get_transforms(CONFIG_TEST['data']['transforms'])
data = get_dataset(
name=CONFIG_TEST['data']['name'],
@@ -24,40 +27,47 @@ def test_early_stopping():
transform=transforms
)

# Use a NaiveTrainer to test the early stopping
CONFIG_TEST['trainer'] = 'NaiveTrainer'

# Split data into training and testing sets
train_size = int(0.64 * len(data))
test_size = len(data) - train_size
data_train, data_test = torch.utils.data.random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(42))

train_loader = torch.utils.data.DataLoader(data_train, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=True)
test_loader = torch.utils.data.DataLoader(data_test, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=False)

# Initialize model
model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device)

# Initialize criterion and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam
optimizer_class = torch.optim.Adam
optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']}
metrics = [Accuracy()]

# Get trainer and build it
trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device)

trainer.build(
criterion=criterion,
optimizer_class=optimizer,
optimizer_class=optimizer_class,
optimizer_params=optimizer_params,
metrics=metrics
)

early_stopping_callback = EarlyStopping(patience=2, verbose=True, monitor='val_loss', delta=0.1)
# Initialize EarlyStopping callback
early_stopping_callback = EarlyStopping(patience=patience, verbose=True, monitor='val_loss', delta=delta)

# Train the model
trainer.train(
train_loader=train_loader,
num_epochs=3, # Intentionally, one more epoch than patience as early stopping should trigger
num_epochs=num_epochs,
valid_loader=test_loader,
callbacks=[early_stopping_callback],
)

# Assert that early stopping was triggered
assert early_stopping_callback.early_stop, "Early stopping did not trigger as expected."

test_early_stopping()
# Run the test
if __name__ == "__main__":
pytest.main([__file__])
5 changes: 4 additions & 1 deletion trainers/base_trainer.py
@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
import time
from typing import Tuple

from callbacks.early_stopping import EarlyStopping

class BaseTrainer(ABC):
"""
@@ -133,6 +133,9 @@ def train(self, train_loader, num_epochs, valid_loader=None, callbacks=None) ->
logs['val_metrics'] = {}

for callback in callbacks:
if isinstance(callback, EarlyStopping):
callback.set_model_and_optimizer(self.model, self.optimizer)

callback.on_epoch_end(epoch, logs=logs)

epoch_time = time.time() - epoch_start_time
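
The hunk above shows only the injection of the model and optimizer into EarlyStopping before on_epoch_end is dispatched; the surrounding loop is collapsed in this diff. Below is a hedged sketch of the overall wiring, where everything except the EarlyStopping calls is an illustrative stand-in, not taken from base_trainer.py.

import torch
from callbacks.early_stopping import EarlyStopping

# Illustrative stand-ins for what a trainer would normally own.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
callbacks = [EarlyStopping(monitor='val_loss', patience=1, verbose=True)]
fake_val_losses = [1.0, 0.9, 0.95, 0.97]                   # pretend per-epoch validation losses

for epoch, val_loss in enumerate(fake_val_losses):
    logs = {'val_loss': val_loss, 'val_metrics': {}}

    for callback in callbacks:
        # The injection this commit adds, done before the callback inspects the logs.
        if isinstance(callback, EarlyStopping):
            callback.set_model_and_optimizer(model, optimizer)
        callback.on_epoch_end(epoch, logs=logs)

    # One possible way to consume the callback's verdict (not shown in this excerpt).
    if any(isinstance(cb, EarlyStopping) and not cb.should_continue(logs) for cb in callbacks):
        break

# Let the callback restore the best snapshot at the end of training.
for callback in callbacks:
    if isinstance(callback, EarlyStopping):
        callback.on_train_end(logs)
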
