Skip to content

Commit

Permalink
Fixed evaluation bug closes #12
Browse files Browse the repository at this point in the history
  • Loading branch information
pab1s committed Mar 31, 2024
1 parent c4f7e5e commit 6b144eb
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 9 deletions.
6 changes: 2 additions & 4 deletions tests/test_fine_tuning_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,5 @@ def test_fine_tuning_loop():
verbose=False
)

assert metrics_results[0] >= 0, "Accuracy should be non-negative"
assert metrics_results[1] >= 0, "Precision should be non-negative"
assert metrics_results[2] >= 0, "Recall should be non-negative"
assert metrics_results[3] >= 0, "F1 Score should be non-negative"
assert len(metrics_results) == len(metrics)
assert all([v >= 0 for v in metrics_results.values()])
6 changes: 2 additions & 4 deletions tests/test_training_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,5 @@ def test_training_loop():
verbose=False
)

assert metrics_results[0] >= 0, "Accuracy should be non-negative"
assert metrics_results[1] >= 0, "Precision should be non-negative"
assert metrics_results[2] >= 0, "Recall should be non-negative"
assert metrics_results[3] >= 0, "F1 Score should be non-negative"
assert len(metrics_results) == len(metrics)
assert all([v >= 0 for v in metrics_results.values()])
2 changes: 1 addition & 1 deletion trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def train(self, train_loader, num_epochs, log_path=None, plot_path=None, verbose

def evaluate(self, test_loader, metrics=[], verbose=True) -> dict:
""" Evaluate the model on the test set using provided metrics. """
if len(metrics) == 0:
if len(metrics) > 0:
self.metrics = metrics

self.model.eval()
Expand Down

0 comments on commit 6b144eb

Please sign in to comment.