
Commit

Checkpoint test closes #16
pab1s committed Apr 22, 2024
1 parent 6c8f728 commit ec81830
Showing 1 changed file with 83 additions and 0 deletions.
tests/test_checkpoints.py (83 additions, 0 deletions)
@@ -0,0 +1,83 @@
import pytest
import os
from trainers import get_trainer
from utils.metrics import Accuracy, Precision, Recall, F1Score
from datasets.transformations import get_transforms
from datasets.dataset import get_dataset
from models import get_model
import torch
import yaml

CONFIG_TEST = {}

with open("./config/config_test.yaml", 'r') as file:
    CONFIG_TEST = yaml.safe_load(file)

def test_checkpoint_functionality():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transforms = get_transforms(CONFIG_TEST)
    data = get_dataset(
        name=CONFIG_TEST['data']['name'],
        root_dir=CONFIG_TEST['data']['dataset_path'],
        train=True,
        transform=transforms
    )

    # Split the data
    train_size = int(0.64 * len(data))  # 64% for training: 80% of the data, of which 80% is train and 20% is val
    test_size = len(data) - train_size
    data_train, data_test = torch.utils.data.random_split(data, [train_size, test_size], generator=torch.Generator().manual_seed(42))

    train_loader = torch.utils.data.DataLoader(data_train, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=True)
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=CONFIG_TEST['training']['batch_size'], shuffle=False)

    model = get_model(CONFIG_TEST['model']['name'], CONFIG_TEST['model']['num_classes'], CONFIG_TEST['model']['pretrained']).to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam
    optimizer_params = {'lr': CONFIG_TEST['training']['learning_rate']}
    metrics = [Accuracy(), Precision(), Recall(), F1Score()]

    trainer = get_trainer(CONFIG_TEST['trainer'], model=model, device=device)

    checkpoint_dir = "./outputs/checkpoints/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build and train partially
    trainer.build(
        criterion=criterion,
        optimizer_class=optimizer,
        optimizer_params=optimizer_params,
        metrics=metrics
    )
    trainer.train(
        train_loader=train_loader,
        num_epochs=6,  # Train for fewer epochs initially
        checkpoint_dir=checkpoint_dir,
        verbose=False
    )

    # The 6-epoch run should have written this checkpoint; the file name follows the trainer's saving scheme
    checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth')
    assert os.path.exists(checkpoint_path), "Checkpoint file was not created."

    # Modify the model (to simulate that loading is necessary)
    for param in model.parameters():
        param.data.zero_()

    # Load the checkpoint
    trainer.load_checkpoint(checkpoint_path)

    # Continue training
    trainer.train(
        train_loader=train_loader,
        num_epochs=2,  # Complete the training
        checkpoint_dir=checkpoint_dir,
        verbose=False
    )

    # Evaluate after resuming to confirm the metrics are still valid
    _, metrics_results = trainer.evaluate(test_loader, verbose=False)
    assert all(v >= 0 for v in metrics_results.values()), "Metrics after resuming are not valid."

if __name__ == "__main__":
    test_checkpoint_functionality()
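
For context, the checkpoint save/load mechanics that this test exercises live in the trainer classes, which are not part of this diff. Below is a minimal sketch of the standard PyTorch pattern such a trainer typically follows; the function names and checkpoint dictionary keys are illustrative assumptions, not the repository's actual API.

import os
import torch

def save_checkpoint(model, optimizer, epoch, checkpoint_dir):
    # Persist model and optimizer state so training can resume later
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
    return path

def load_checkpoint(model, optimizer, path, device="cpu"):
    # Restore model and optimizer state; return the stored epoch
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]

The test itself can be run on its own with: pytest tests/test_checkpoints.py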
