diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py index 4f60a31a0..eb0678c04 100644 --- a/recbole/evaluator/metrics.py +++ b/recbole/evaluator/metrics.py @@ -27,7 +27,7 @@ import numpy as np from collections import Counter from sklearn.metrics import auc as sk_auc -from sklearn.metrics import mean_absolute_error, mean_squared_error +from sklearn.metrics import mean_absolute_error, mean_squared_error, average_precision_score, accuracy_score, f1_score, precision_score, recall_score from recbole.evaluator.utils import _binary_clf_curve from recbole.evaluator.base_metric import AbstractMetric, TopkMetric, LossMetric @@ -380,6 +380,61 @@ def metric_info(self, preds, trues): return result +class AP(LossMetric): + def __init__(self, config): + super().__init__(config) + + def calculate_metric(self, dataobject): + return self.output_metric("ap", dataobject) + + def metric_info(self, preds, trues): + return average_precision_score(trues, preds) + + +class ACC(LossMetric): + def __init__(self, config): + super().__init__(config) + + def calculate_metric(self, dataobject): + return self.output_metric("acc", dataobject) + + def metric_info(self, preds, trues): + return accuracy_score(trues, preds > 0.5) + + +class Preci(LossMetric): + def __init__(self, config): + super().__init__(config) + + def calculate_metric(self, dataobject): + return self.output_metric("precision", dataobject) + + def metric_info(self, preds, trues): + return precision_score(trues, preds > 0.5) + + +class Recal(LossMetric): + def __init__(self, config): + super().__init__(config) + + def calculate_metric(self, dataobject): + return self.output_metric("recall", dataobject) + + def metric_info(self, preds, trues): + return recall_score(trues, preds > 0.5, zero_division=0) + + +class F1(LossMetric): + def __init__(self, config): + super().__init__(config) + + def calculate_metric(self, dataobject): + return self.output_metric("f1_score", dataobject) + + def metric_info(self, preds, trues): + return f1_score(trues, preds > 0.5, zero_division=0) + + # Loss-based Metrics