From 6b144eb0e2003adebb3146680961bb7868cb4e87 Mon Sep 17 00:00:00 2001 From: Pablo Olivares Date: Sun, 31 Mar 2024 11:19:21 +0200 Subject: [PATCH] Fixed evaluation bug closes #12 --- tests/test_fine_tuning_pipeline.py | 6 ++---- tests/test_training_pipeline.py | 6 ++---- trainers/base_trainer.py | 2 +- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/test_fine_tuning_pipeline.py b/tests/test_fine_tuning_pipeline.py index 8f592d9..d4d7e9a 100644 --- a/tests/test_fine_tuning_pipeline.py +++ b/tests/test_fine_tuning_pipeline.py @@ -76,7 +76,5 @@ def test_fine_tuning_loop(): verbose=False ) - assert metrics_results[0] >= 0, "Accuracy should be non-negative" - assert metrics_results[1] >= 0, "Precision should be non-negative" - assert metrics_results[2] >= 0, "Recall should be non-negative" - assert metrics_results[3] >= 0, "F1 Score should be non-negative" + assert len(metrics_results) == len(metrics) + assert all([v >= 0 for v in metrics_results.values()]) diff --git a/tests/test_training_pipeline.py b/tests/test_training_pipeline.py index 48f1ff8..1776632 100644 --- a/tests/test_training_pipeline.py +++ b/tests/test_training_pipeline.py @@ -57,7 +57,5 @@ def test_training_loop(): verbose=False ) - assert metrics_results[0] >= 0, "Accuracy should be non-negative" - assert metrics_results[1] >= 0, "Precision should be non-negative" - assert metrics_results[2] >= 0, "Recall should be non-negative" - assert metrics_results[3] >= 0, "F1 Score should be non-negative" + assert len(metrics_results) == len(metrics) + assert all([v >= 0 for v in metrics_results.values()]) diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index bc48e6a..e499127 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -94,7 +94,7 @@ def train(self, train_loader, num_epochs, log_path=None, plot_path=None, verbose def evaluate(self, test_loader, metrics=[], verbose=True) -> dict: """ Evaluate the model on the test set using provided metrics. """ - if len(metrics) == 0: + if len(metrics) > 0: self.metrics = metrics self.model.eval()