diff --git a/tests/test_fine_tuning_pipeline.py b/tests/test_fine_tuning_pipeline.py
index 8f592d9..d4d7e9a 100644
--- a/tests/test_fine_tuning_pipeline.py
+++ b/tests/test_fine_tuning_pipeline.py
@@ -76,7 +76,5 @@ def test_fine_tuning_loop():
         verbose=False
     )
 
-    assert metrics_results[0] >= 0, "Accuracy should be non-negative"
-    assert metrics_results[1] >= 0, "Precision should be non-negative"
-    assert metrics_results[2] >= 0, "Recall should be non-negative"
-    assert metrics_results[3] >= 0, "F1 Score should be non-negative"
+    assert len(metrics_results) == len(metrics)
+    assert all([v >= 0 for v in metrics_results.values()])
diff --git a/tests/test_training_pipeline.py b/tests/test_training_pipeline.py
index 48f1ff8..1776632 100644
--- a/tests/test_training_pipeline.py
+++ b/tests/test_training_pipeline.py
@@ -57,7 +57,5 @@ def test_training_loop():
         verbose=False
     )
 
-    assert metrics_results[0] >= 0, "Accuracy should be non-negative"
-    assert metrics_results[1] >= 0, "Precision should be non-negative"
-    assert metrics_results[2] >= 0, "Recall should be non-negative"
-    assert metrics_results[3] >= 0, "F1 Score should be non-negative"
+    assert len(metrics_results) == len(metrics)
+    assert all([v >= 0 for v in metrics_results.values()])
diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py
index bc48e6a..e499127 100644
--- a/trainers/base_trainer.py
+++ b/trainers/base_trainer.py
@@ -94,7 +94,7 @@ def train(self, train_loader, num_epochs, log_path=None, plot_path=None, verbose
 
     def evaluate(self, test_loader, metrics=[], verbose=True) -> dict:
        """ Evaluate the model on the test set using provided metrics. """
-        if len(metrics) == 0:
+        if len(metrics) > 0:
             self.metrics = metrics
         self.model.eval()
 
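
The updated tests assume that `evaluate` now returns its metric results as a dict keyed by metric name, rather than an indexable sequence. Below is a minimal sketch of a return shape consistent with those assertions; the metric names and the `compute_metrics` helper are illustrative assumptions, not code taken from the repository.

```python
# Hypothetical sketch: assemble one non-negative score per requested metric,
# keyed by name, so that the new test assertions hold. Metric names and
# callables are illustrative assumptions only.
from typing import Callable, Dict, Sequence


def compute_metrics(
    metric_fns: Dict[str, Callable[[Sequence[int], Sequence[int]], float]],
    y_true: Sequence[int],
    y_pred: Sequence[int],
) -> Dict[str, float]:
    """Return a dict mapping each metric name to its score."""
    return {name: fn(y_true, y_pred) for name, fn in metric_fns.items()}


# Usage mirroring the assertions in the updated tests:
# metrics_results = compute_metrics(metrics, y_true, y_pred)
# assert len(metrics_results) == len(metrics)
# assert all([v >= 0 for v in metrics_results.values()])
```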