From 53c8d30f1bbc98ee1631b28f40189061e3d83df5 Mon Sep 17 00:00:00 2001 From: alexchf Date: Thu, 1 Aug 2019 13:56:55 -0700 Subject: [PATCH 1/3] fix detection evaluation --- bdd_data/evaluate.py | 110 ++++++++++++++++++++++++++++++++++--------- bdd_data/results.csv | 4 ++ 2 files changed, 93 insertions(+), 21 deletions(-) create mode 100644 bdd_data/results.csv diff --git a/bdd_data/evaluate.py b/bdd_data/evaluate.py index ae0c3c5..4817a2f 100644 --- a/bdd_data/evaluate.py +++ b/bdd_data/evaluate.py @@ -1,3 +1,4 @@ + import argparse import json import os @@ -8,6 +9,7 @@ import numpy as np from PIL import Image from tqdm import tqdm +import pandas as pd def parse_args(): @@ -20,6 +22,8 @@ def parse_args(): help='path to results to be evaluated') parser.add_argument('--categories', '-c', nargs='+', help='categories to keep') + parser.add_argument('--out-path', '-o', default='results.csv', + help='output path') args = parser.parse_args() return args @@ -170,37 +174,99 @@ def cat_pc(gt, predictions, thresholds): fp = np.cumsum(fp, axis=0) tp = np.cumsum(tp, axis=0) recalls = tp / float(num_gts) + # avoid divide by zero in case the first detection matches a difficult # ground truth precisions = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) - ap = np.zeros(len(thresholds)) + for t in range(len(thresholds)): - ap[t] = get_ap(recalls[:, t], precisions[:, t]) - + for i in range(len(precisions) - 1, -1, -1): + if precisions[i, t] > precisions[i - 1, t]: + precisions[i - 1, t] = precisions[i, t] + + recall_thresholds = np.linspace(.0, 1.00, np.round((1.00 - .0) / .01) + 1, endpoint=True) + + q = np.zeros((len(recall_thresholds), len(thresholds))) + for t in range(len(thresholds)): + inds = np.searchsorted(recalls[:, t], recall_thresholds, side='left') + try: + for ri, pi in enumerate(inds): + q[ri, t] = precisions[pi, t] + if ri > 0 and q[ri-1, t] < q[ri, t]: + q[ri-1, t] = q[ri, t] + except: + pass + + ap = np.mean(q, axis=0) return recalls, precisions, ap def evaluate_detection(gt_path, result_path): + + thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] + + category_dict = { + 1: 'person', + 2: 'rider', + 3: 'car', + 4: 'bus', + 5: 'truck', + 6: 'bike', + 7: 'motor', + 8: 'traffic light', + 9: 'traffic sign', + 10: 'train' + } + + print('Loading json files...') gt = json.load(open(gt_path, 'r')) pred = json.load(open(result_path, 'r')) + for i in pred: + i['category'] = category_dict[i['category']] + + if type(pred[0]['name']) == int: + names = np.unique([g['name'] for g in gt]) + for i in range(len(pred)): + pred[i]['name'] = names[pred[i]['name'] - 1] + cat_gt = group_by_key(gt, 'category') cat_pred = group_by_key(pred, 'category') + cat_list = sorted(cat_gt.keys()) - thresholds = [0.75] - aps = np.zeros((len(thresholds), len(cat_list))) - for i, cat in enumerate(cat_list): + + print('Evaluating...') + + aps = -np.ones((len(thresholds), len(cat_list))) + for i, cat in enumerate(tqdm(cat_list)): if cat in cat_pred: r, p, ap = cat_pc(cat_gt[cat], cat_pred[cat], thresholds) aps[:, i] = ap - aps *= 100 - m_ap = np.mean(aps) - mean, breakdown = m_ap, aps.flatten().tolist() + else: + aps[:, i] = 0 + aps[aps > -1] *= 100 + m_ap = np.mean(aps[aps > -1]) - print('{:.2f}'.format(mean), - ', '.join(['{:.2f}'.format(n) for n in breakdown])) + # APs and mAP + ap = aps.mean(axis=0) + results = [{cat_list[i]: ap[i] for i in range(len(cat_list))}] + indices = ['AP'] + # AP50, AP75 + results.append({cat_list[i]: aps[thresholds.index(0.5), i] for i in range(len(cat_list))}) + indices.append('AP50') + results.append({cat_list[i]: aps[thresholds.index(0.75), i] for i in range(len(cat_list))}) + indices.append('AP75') -def evaluate_det_tracking(gt_path, result_path, cats=[]): + for r in results: + r['MEAN'] = np.mean([i for i in r.values() if i != -1]) + + df = pd.DataFrame(results, index=indices) + print('mAP: {}'.format(m_ap)) + print(df) + return df + + +def evaluate_det_tracking(gt_path, result_path): import motmetrics as mm @@ -232,10 +298,6 @@ def evaluate_det_tracking(gt_path, result_path, cats=[]): if not (cat in cat_gt.keys()): continue - if len(cats) > 0: - if not (cat in cats): - continue - # initialize accumulator for each category if needed if cat not in acc_dict.keys(): acc_dict[cat] = mm.MOTAccumulator(auto_id=True) @@ -281,20 +343,26 @@ def evaluate_det_tracking(gt_path, result_path, cats=[]): ) print(strsummary) + return summary + def main(): args = parse_args() - + if args.task == 'drivable': evaluate_drivable(args.gt, args.result) elif args.task == 'seg': evaluate_segmentation(args.gt, args.result, 19, 17) elif args.task == 'det': - evaluate_detection(args.gt, args.result) - elif args.task == 'det_tracking': - evaluate_det_tracking(args.gt, args.result, cats=args.categories) - + results = evaluate_detection(args.gt, args.result) + if args.out_path: + results.to_csv(args.out_path) + elif args.task == 'det-tracking': + results = evaluate_det_tracking(args.gt, args.result) + if args.out_path: + results.to_csv(args.out_path) + if __name__ == '__main__': main() diff --git a/bdd_data/results.csv b/bdd_data/results.csv new file mode 100644 index 0000000..e854a67 --- /dev/null +++ b/bdd_data/results.csv @@ -0,0 +1,4 @@ +,MEAN,car,motor,rider,traffic light,traffic sign +AP,21.28735486860323,48.84044328133856,0.0,0.0,11.896189618961904,45.70014144271569 +AP50,33.66663006491897,80.91297973610743,0.0,0.0,27.872787278727895,59.54738330975952 +AP75,20.6968773800457,52.989337395278,0.0,0.0,0.0,50.495049504950494 From 7d7a844acea985b8fab85623cd1748854ec76a0e Mon Sep 17 00:00:00 2001 From: alexchf Date: Thu, 1 Aug 2019 14:03:19 -0700 Subject: [PATCH 2/3] fix off-by-1 error --- .gitignore | 2 ++ bdd_data/evaluate.py | 12 ++++++------ bdd_data/results.csv | 4 ---- 3 files changed, 8 insertions(+), 10 deletions(-) delete mode 100644 bdd_data/results.csv diff --git a/.gitignore b/.gitignore index 3adf195..c5cf71b 100644 --- a/.gitignore +++ b/.gitignore @@ -103,3 +103,5 @@ ENV/ # mypy .mypy_cache/ + +*.csv diff --git a/bdd_data/evaluate.py b/bdd_data/evaluate.py index 4817a2f..9b559a2 100644 --- a/bdd_data/evaluate.py +++ b/bdd_data/evaluate.py @@ -147,14 +147,14 @@ def cat_pc(gt, predictions, thresholds): iymin = np.maximum(gt_boxes[:, 1], box[1]) ixmax = np.minimum(gt_boxes[:, 2], box[2]) iymax = np.minimum(gt_boxes[:, 3], box[3]) - iw = np.maximum(ixmax - ixmin + 1., 0.) - ih = np.maximum(iymax - iymin + 1., 0.) + iw = np.maximum(ixmax - ixmin, 0.) + ih = np.maximum(iymax - iymin, 0.) inters = iw * ih # union - uni = ((box[2] - box[0] + 1.) * (box[3] - box[1] + 1.) + - (gt_boxes[:, 2] - gt_boxes[:, 0] + 1.) * - (gt_boxes[:, 3] - gt_boxes[:, 1] + 1.) - inters) + uni = ((box[2] - box[0]) * (box[3] - box[1]) + + (gt_boxes[:, 2] - gt_boxes[:, 0]) * + (gt_boxes[:, 3] - gt_boxes[:, 1]) - inters) overlaps = inters / uni ovmax = np.max(overlaps) @@ -349,7 +349,7 @@ def evaluate_det_tracking(gt_path, result_path): def main(): args = parse_args() - + if args.task == 'drivable': evaluate_drivable(args.gt, args.result) elif args.task == 'seg': diff --git a/bdd_data/results.csv b/bdd_data/results.csv deleted file mode 100644 index e854a67..0000000 --- a/bdd_data/results.csv +++ /dev/null @@ -1,4 +0,0 @@ -,MEAN,car,motor,rider,traffic light,traffic sign -AP,21.28735486860323,48.84044328133856,0.0,0.0,11.896189618961904,45.70014144271569 -AP50,33.66663006491897,80.91297973610743,0.0,0.0,27.872787278727895,59.54738330975952 -AP75,20.6968773800457,52.989337395278,0.0,0.0,0.0,50.495049504950494 From 4839dbbcfbc7210f59d08df808957aee4d5ac157 Mon Sep 17 00:00:00 2001 From: Haofeng Chen Date: Thu, 1 Aug 2019 16:32:58 -0700 Subject: [PATCH 3/3] Update evaluate.py --- bdd_data/evaluate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bdd_data/evaluate.py b/bdd_data/evaluate.py index 9b559a2..33a7958 100644 --- a/bdd_data/evaluate.py +++ b/bdd_data/evaluate.py @@ -115,6 +115,7 @@ def group_by_key(detections, key): def cat_pc(gt, predictions, thresholds): """ Implementation refers to https://github.com/rbgirshick/py-faster-rcnn + and https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py """ num_gts = len(gt) image_gts = group_by_key(gt, 'name')