diff --git a/callbacks/early_stopping.py b/callbacks/early_stopping.py
index d8659e0..37d0bac 100644
--- a/callbacks/early_stopping.py
+++ b/callbacks/early_stopping.py
@@ -1,3 +1,5 @@
+import torch
+from copy import deepcopy
 from callbacks.callback import Callback
 
 class EarlyStopping(Callback):
@@ -10,8 +12,19 @@ def __init__(self, monitor='val_loss', patience=5, verbose=False, delta=0):
         self.epochs_no_improve = 0
         self.stopped_epoch = 0
         self.early_stop = False
+        self.best_model_state = None
+        self.best_optimizer_state = None
+        self.model = None
+        self.optimizer = None
+
+    def set_model_and_optimizer(self, model, optimizer):
+        self.model = model
+        self.optimizer = optimizer
 
     def on_epoch_end(self, epoch, logs=None):
+        if self.model is None or self.optimizer is None:
+            raise ValueError("Model and optimizer must be set before calling on_epoch_end.")
+
         current = logs.get(self.monitor)
         if current is None:
             return
@@ -19,6 +32,7 @@ def on_epoch_end(self, epoch, logs=None):
         score = -current if 'loss' in self.monitor else current
         if self.best_score is None:
             self.best_score = score
+            self.save_checkpoint()
         elif score < self.best_score + self.delta:
             self.epochs_no_improve += 1
             if self.epochs_no_improve >= self.patience:
@@ -29,6 +43,25 @@ def on_epoch_end(self, epoch, logs=None):
         else:
             self.best_score = score
             self.epochs_no_improve = 0
+            self.save_checkpoint()
+
+    def save_checkpoint(self):
+        self.best_model_state = deepcopy(self.model.state_dict())
+        self.best_optimizer_state = deepcopy(self.optimizer.state_dict())
 
     def should_continue(self, logs=None):
         return not self.early_stop
+
+    def load_best_checkpoint(self):
+        if self.best_model_state is None or self.best_optimizer_state is None:
+            raise ValueError("No best checkpoint available to load.")
+
+        self.model.load_state_dict(self.best_model_state)
+        self.optimizer.load_state_dict(self.best_optimizer_state)
+        if self.verbose:
+            print("Loaded best checkpoint")
+
+    def on_train_end(self, logs=None):
+        self.load_best_checkpoint()
+        if self.verbose:
+            print("Training ended. Best model checkpoint has been loaded.")
diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py
index bf0a771..3e07f34 100644
--- a/tests/test_early_stopping.py
+++ b/tests/test_early_stopping.py
@@ -8,14 +8,17 @@
 import torch
 import yaml
 
-CONFIG_TEST = {}
-
+# Load test configuration
 with open("./config/config_test.yaml", 'r') as file:
     CONFIG_TEST = yaml.safe_load(file)
 
-def test_early_stopping():
+@pytest.mark.parametrize("patience,delta,num_epochs", [
+    (2, 0.1, 5),  # Patience of 2, delta of 0.1, and training for 5 epochs
+])
+def test_early_stopping(patience, delta, num_epochs):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+    # Data transformations and loading
     transforms = get_transforms(CONFIG_TEST['data']['transforms'])
     data = get_dataset(
         name=CONFIG_TEST['data']['name'],
@@ -24,9 +27,7 @@
         transform=transforms
     )
 
-    # Use a NaiveTrainer to test the early stopping
-    CONFIG_TEST['trainer'] = 'NaiveTrainer'
-
+    # Split data into training and testing sets
     train_size = int(0.64 * len(data))
     test_size = len(data) - train_size
     data_train, data_test = torch.utils.data.random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(42))
@@ -34,30 +35,39 @@
     train_loader = torch.utils.data.DataLoader(data_train, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=True)
     test_loader = torch.utils.data.DataLoader(data_test, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=False)
 
+    # Initialize model
     model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device)
 
+    # Initialize criterion and optimizer
     criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam
+    optimizer_class = torch.optim.Adam
     optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']}
     metrics = [Accuracy()]
 
+    # Get trainer and build it
     trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device)
     trainer.build(
         criterion=criterion,
-        optimizer_class=optimizer,
+        optimizer_class=optimizer_class,
         optimizer_params=optimizer_params,
         metrics=metrics
     )
 
-    early_stopping_callback = EarlyStopping(patience=2, verbose=True, monitor='val_loss', delta=0.1)
+    # Initialize EarlyStopping callback
+    early_stopping_callback = EarlyStopping(patience=patience, verbose=True, monitor='val_loss', delta=delta)
+
+    # Train the model
     trainer.train(
         train_loader=train_loader,
-        num_epochs=3, # Intentionally, one more epoch than patience as early stopping should trigger
+        num_epochs=num_epochs,
         valid_loader=test_loader,
         callbacks=[early_stopping_callback],
     )
 
+    # Assert that early stopping was triggered
     assert early_stopping_callback.early_stop, "Early stopping did not trigger as expected."
 
 
-test_early_stopping()
+# Run the test
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py
index 51b3e07..7724fdd 100644
--- a/trainers/base_trainer.py
+++ b/trainers/base_trainer.py
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 import time
 from typing import Tuple
-
+from callbacks.early_stopping import EarlyStopping
 
 class BaseTrainer(ABC):
     """
@@ -133,6 +133,9 @@ def train(self, train_loader, num_epochs, valid_loader=None, callbacks=None) ->
                 logs['val_metrics'] = {}
 
             for callback in callbacks:
+                if isinstance(callback, EarlyStopping):
+                    callback.set_model_and_optimizer(self.model, self.optimizer)
+
                 callback.on_epoch_end(epoch, logs=logs)
 
             epoch_time = time.time() - epoch_start_time
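
For context, a minimal usage sketch of the updated callback outside the full trainer. It is not part of the diff: DummyModel, the simulated validation losses, and the manual epoch loop below are hypothetical, and it assumes the trainer calls on_epoch_end after each validation pass (as base_trainer.py now does) and invokes on_train_end once training stops.

# Minimal usage sketch of the updated EarlyStopping callback.
# DummyModel and the simulated validation losses are hypothetical stand-ins.
import torch

from callbacks.early_stopping import EarlyStopping


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


model = DummyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

early_stopping = EarlyStopping(monitor='val_loss', patience=2, delta=0.1, verbose=True)
early_stopping.set_model_and_optimizer(model, optimizer)  # required before on_epoch_end

# Simulated validation losses: two improvements, then stagnation.
val_losses = [1.0, 0.8, 0.79, 0.81, 0.80]

for epoch, val_loss in enumerate(val_losses):
    # A real trainer would run its training and validation passes here.
    early_stopping.on_epoch_end(epoch, logs={'val_loss': val_loss})
    if early_stopping.early_stop:
        break

# Restores the model/optimizer state captured by save_checkpoint() at the best epoch.
early_stopping.on_train_end()

With patience=2 and delta=0.1, the simulated run improves on the first two epochs, then stalls for two consecutive epochs and stops, after which on_train_end reloads the state saved at the best epoch.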