Skip to content

Commit

Permalink
Adapted tests to work with callbacks advances #17
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Apr 22, 2024
1 parent 3058a04 commit 7c336f4
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 8 deletions.
5 changes: 2 additions & 3 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from trainers import get_trainer
from callbacks.early_stopping import EarlyStopping
from callbacks import EarlyStopping
from utils.metrics import Accuracy
from datasets.transformations import get_transforms
from datasets.dataset import get_dataset
Expand Down Expand Up @@ -53,10 +53,9 @@ def test_early_stopping():
early_stopping_callback = EarlyStopping(patience=2, verbose=True, monitor='val_loss', delta=0.1)
trainer.train(
train_loader=train_loader,
num_epochs=4, # Intentionally, one more epoch than patience as early stopping should trigger
num_epochs=3, # Intentionally, one more epoch than patience as early stopping should trigger
valid_loader=test_loader,
callbacks=[early_stopping_callback],
verbose=False
)

assert early_stopping_callback.early_stop, "Early stopping did not trigger as expected."
Expand Down
2 changes: 0 additions & 2 deletions tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_checkpoint_functionality():
train_loader=train_loader,
num_epochs=6,
checkpoint_dir=checkpoint_dir,
verbose=False
)

checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint_epoch_5.pth')
Expand All @@ -68,7 +67,6 @@ def test_checkpoint_functionality():
train_loader=train_loader,
num_epochs=2,
checkpoint_dir=checkpoint_dir,
verbose=False
)

_, metrics_results = trainer.evaluate(test_loader, verbose=False)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_fine_tuning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ def test_fine_tuning_loop():
freeze_until_layer=CONFIG_TEST['training'].get('freeze_until_layer'),
metrics=metrics
)

trainer.train(
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=CONFIG_TEST['training']['num_epochs'],
verbose=False
)

trainer.unfreeze_all_layers()
Expand All @@ -99,7 +99,6 @@ def test_fine_tuning_loop():
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=CONFIG_TEST['training']['num_epochs'],
verbose=False
)

_, metrics_results = trainer.evaluate(
Expand Down
1 change: 0 additions & 1 deletion tests/test_training_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def test_training_loop():
train_loader=train_loader,
valid_loader=valid_loader,
num_epochs=CONFIG_TEST['training']['num_epochs'],
verbose=False
)
_, metrics_results = trainer.evaluate(
data_loader=test_loader,
Expand Down

0 comments on commit 7c336f4

Please sign in to comment.