Refactor detection and tracking evaluation #44

Open · wants to merge 2 commits into base: master
24 changes: 24 additions & 0 deletions bdd_data/class_correspondence.json
@@ -0,0 +1,24 @@
{
    "person": "person",
    "rider": "rider",
    "car": "car",
    "bus": "bus",
    "truck": "truck",
    "bike": "bike",
    "motor": "motor",
    "traffic light": "traffic light",
    "traffic sign": "traffic sign",
    "train": "train",
    "other": "other",
    "1": "person",
    "2": "rider",
    "3": "car",
    "4": "bus",
    "5": "truck",
    "6": "bike",
    "7": "motor",
    "8": "traffic light",
    "9": "traffic sign",
    "10": "train",
    "11": "other"
}
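The mapping accepts both category names and stringified 1-based class ids, so ground truth keyed by name and predictions keyed by numeric id resolve to the same canonical category. A minimal sketch of the lookup (the relative path is an assumption based on this PR's layout):

    import json

    # load the mapping added in this PR
    with open('bdd_data/class_correspondence.json') as f:
        correspondence = json.load(f)

    assert correspondence['car'] == 'car'  # names map to themselves
    assert correspondence['3'] == 'car'    # string ids map to names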
103 changes: 80 additions & 23 deletions bdd_data/evaluate.py
@@ -8,6 +8,7 @@
import numpy as np
from PIL import Image
from tqdm import tqdm
import pandas as pd


def parse_args():
@@ -18,8 +19,10 @@ def parse_args():
    parser.add_argument('--gt', '-g', help='path to ground truth')
    parser.add_argument('--result', '-r',
                        help='path to results to be evaluated')
    parser.add_argument('--categories', '-c', nargs='+',
                        help='categories to keep')
    parser.add_argument('--correspondence', '-m', default='class_correspondence.json',
                        help='class correspondence file')
    parser.add_argument('--out-path', '-o', default='results.csv',
                        help='output path')
    args = parser.parse_args()

    return args
@@ -180,34 +183,90 @@ def cat_pc(gt, predictions, thresholds):
    return recalls, precisions, ap


def evaluate_detection(gt_path, result_path):
def merge_and_delete_classes(data, class_correspondence):
    # Remap category keys through the correspondence table: entries whose
    # keys map to the same canonical name are merged, and keys without a
    # mapping are dropped.
    out = {}
    for k, v in data.items():
        k = str(k)

        if k not in class_correspondence:
            continue

        k_out = class_correspondence[k]

        if k_out in out:
            out[k_out] += v
        else:
            out[k_out] = v

    return out
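A quick illustration of the helper with a toy correspondence (the mapping below is hypothetical, not the shipped file): lists grouped under keys that share a target are concatenated, and unmapped keys disappear.

    grouped = {
        '6': [{'category': '6'}],        # numeric id as a string
        'bike': [{'category': 'bike'}],  # category name
        'lane': [{'category': 'lane'}],  # no mapping, so it is dropped
    }
    toy_correspondence = {'6': 'bike', 'bike': 'bike'}
    merged = merge_and_delete_classes(grouped, toy_correspondence)
    # merged == {'bike': [{'category': '6'}, {'category': 'bike'}]}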


def evaluate_detection(gt_path, result_path, class_correspondence_path):

    thresholds = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

    print('Loading json files...')
    gt = json.load(open(gt_path, 'r'))
    pred = json.load(open(result_path, 'r'))

    # if predictions carry 1-based integer labels, map them back to the
    # category names that occur in the ground truth
    if isinstance(pred[0]['name'], int):
        names = np.unique([g['name'] for g in gt])
        for i in range(len(pred)):
            pred[i]['name'] = names[pred[i]['name'] - 1]

    cat_gt = group_by_key(gt, 'category')
    cat_pred = group_by_key(pred, 'category')

    # load class correspondence and normalize keys/values to plain strings
    with open(class_correspondence_path) as f:
        class_correspondence = json.load(f)

    class_correspondence = {str(k): str(v)
                            for k, v in class_correspondence.items()}

    cat_gt = merge_and_delete_classes(cat_gt, class_correspondence)
    cat_pred = merge_and_delete_classes(cat_pred, class_correspondence)
    cat_list = sorted(cat_gt.keys())

    print('Evaluating...')

    aps = np.zeros((len(thresholds), len(cat_list)))
    for i, cat in enumerate(cat_list):
    for i, cat in enumerate(tqdm(cat_list)):
        if cat in cat_pred:
            r, p, ap = cat_pc(cat_gt[cat], cat_pred[cat], thresholds)
            aps[:, i] = ap
    aps *= 100
    m_ap = np.mean(aps)
    mean, breakdown = m_ap, aps.flatten().tolist()

    print('{:.2f}'.format(mean),
          ', '.join(['{:.2f}'.format(n) for n in breakdown]))

    # APs and mAP
    ap = aps.mean(axis=0)
    results = [{cat_list[i]: ap[i] for i in range(len(cat_list))}]
    indices = ['AP']

    # AP50, AP75
    results.append({cat_list[i]: aps[thresholds.index(0.5), i] for i in range(len(cat_list))})
    indices.append('AP50')
    results.append({cat_list[i]: aps[thresholds.index(0.75), i] for i in range(len(cat_list))})
    indices.append('AP75')

    df = pd.DataFrame(results, index=indices)
    print('mAP: {}'.format(m_ap))
    print(df)
    return df
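The returned DataFrame has one row per metric (AP averaged over all ten thresholds, plus AP50 and AP75) and one column per merged category. Written to CSV, it would look roughly like this; the numbers are invented for illustration:

    ,bike,bus,car,person,...
    AP,21.3,40.2,55.1,30.7,...
    AP50,35.8,58.7,74.9,49.2,...
    AP75,22.6,43.0,59.4,31.5,...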

def evaluate_det_tracking(gt_path, result_path, cats=[]):

def evaluate_det_tracking(gt_path, result_path, class_correspondence_path):

    import motmetrics as mm

    gt = sorted(json.load(open(gt_path)), key=lambda l1: l1['name'])
    pred = sorted(json.load(open(result_path)), key=lambda l2: l2['name'])
    assert len(gt) == len(pred)

    # load class correspondence
    with open(class_correspondence_path) as f:
        class_correspondence = json.load(f)

    acc_dict = {}

    print('Collecting IoU...')
@@ -225,17 +284,16 @@ def evaluate_det_tracking(gt_path, result_path, cats=[]):

        cat_gt = group_by_key(im_gt['labels'], 'category')
        cat_pred = group_by_key(im_pred['labels'], 'category')
        cat_gt = merge_and_delete_classes(cat_gt, class_correspondence)
        cat_pred = merge_and_delete_classes(cat_pred, class_correspondence)

        cat_list = cat_pred.keys()

        for cat in cat_list:

            if cat not in cat_gt:
                continue

            if len(cats) > 0:
                if cat not in cats:
                    continue

            # initialize accumulator for each category if needed
            if cat not in acc_dict:
                acc_dict[cat] = mm.MOTAccumulator(auto_id=True)
@@ -271,16 +329,13 @@ def evaluate_det_tracking(gt_path, result_path, cats=[]):
    mh = mm.metrics.create()

    summary = mh.compute_many([i[1] for i in acc_dict.items()],
                              metrics=mm.metrics.motchallenge_metrics,
                              metrics=['num_frames', 'mota', 'motp'],
                              names=[i[0] for i in acc_dict.items()])

    strsummary = mm.io.render_summary(
        summary,
        formatters=mh.formatters,
        namemap=mm.io.motchallenge_metric_names
    )
    summary['motp'] = (1 - summary['motp']) * 100
    print(summary)

    print(strsummary)
    return summary
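A note on the motp rescaling above: py-motmetrics reports MOTP as an average matching distance, which for IoU-based matching is 1 − IoU, so lower raw values are better. The `summary['motp'] = (1 - summary['motp']) * 100` line therefore converts it into a percent-overlap score where higher is better; a raw motp of 0.25, for instance, becomes (1 − 0.25) × 100 = 75.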


def main():
@@ -291,9 +346,11 @@ def main():
    elif args.task == 'seg':
        evaluate_segmentation(args.gt, args.result, 19, 17)
    elif args.task == 'det':
        evaluate_detection(args.gt, args.result)
        results = evaluate_detection(args.gt, args.result, args.correspondence)
    elif args.task == 'det_tracking':
        evaluate_det_tracking(args.gt, args.result, cats=args.categories)
    elif args.task == 'det-tracking':
        results = evaluate_det_tracking(args.gt, args.result, args.correspondence)

    results.to_csv(args.out_path)
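Based on the arguments wired up in parse_args above, a detection run might be invoked like this (file names are placeholders, and the positional task argument is assumed from the args.task dispatch in main):

    python bdd_data/evaluate.py det \
        --gt gt_labels.json \
        --result predictions.json \
        --correspondence bdd_data/class_correspondence.json \
        --out-path results.csv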


if __name__ == '__main__':