diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py
index 2e87dcd..c8513b3 100644
--- a/tests/test_checkpoints.py
+++ b/tests/test_checkpoints.py
@@ -24,8 +24,7 @@ def test_checkpoint_functionality():
         transform=transforms
     )
 
-    # Split the data
-    train_size = int(0.64 * len(data))  # 80% for training, of which 80% is train and 20% is val
+    train_size = int(0.64 * len(data))
     test_size = len(data) - train_size
     data_train, data_test = torch.utils.data.random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(42))
 
@@ -44,7 +43,6 @@ def test_checkpoint_functionality():
     checkpoint_dir = "./outputs/checkpoints/"
     os.makedirs(checkpoint_dir, exist_ok=True)
 
-    # Build and train partially
     trainer.build(
         criterion=criterion,
         optimizer_class=optimizer,
@@ -53,7 +51,7 @@ def test_checkpoint_functionality():
     )
     trainer.train(
         train_loader=train_loader,
-        num_epochs=6,  # Train less epochs initially
+        num_epochs=6,
         checkpoint_dir=checkpoint_dir,
         verbose=False
     )
@@ -61,22 +59,18 @@ def test_checkpoint_functionality():
     checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth')
     assert os.path.exists(checkpoint_path), "Checkpoint file was not created."
 
-    # Modify the model (to simulate that loading is necessary)
     for param in model.parameters():
         param.data.zero_()
 
-    # Load the checkpoint
     trainer.load_checkpoint(checkpoint_path)
 
-    # Continue training
     trainer.train(
         train_loader=train_loader,
-        num_epochs=2,  # Complete the training
+        num_epochs=2,
         checkpoint_dir=checkpoint_dir,
         verbose=False
     )
 
-    # Evaluation to ensure no degradation or improvement
     _, metrics_results = trainer.evaluate(test_loader, verbose=False)
     assert all([v >= 0 for v in metrics_results.values()]), "Metrics after resuming are not valid."