-
Notifications
You must be signed in to change notification settings - Fork 4
/
score.py
56 lines (51 loc) · 2.15 KB
/
score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import division
import caffe
import numpy as np
import os
import sys
from datetime import datetime
from PIL import Image
def fast_hist(a, b, n):
    """Build an n-by-n confusion matrix from two parallel label arrays.

    Args:
        a: 1-D array-like of reference labels. Positions whose value falls
            outside ``[0, n)`` are excluded from the count.
        b: 1-D array-like of predicted labels, the same length as ``a``.
            Values are expected to lie in ``[0, n)``; they are not masked.
        n: Number of classes.

    Returns:
        An ``(n, n)`` integer array where entry ``[i, j]`` is the number of
        kept positions with ``a == i`` and ``b == j``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    k = (a >= 0) & (a < n)
    # np.bincount only accepts integer input, so cast both operands; label
    # blobs are frequently stored as floats. Fractional values are truncated.
    # Each (a, b) pair is encoded as the flat index a * n + b.
    flat_index = n * a[k].astype(int) + b[k].astype(int)
    return np.bincount(flat_index, minlength=n**2).reshape(n, n)
def compute_hist(net, save_dir, dataset, layer='score', gt='label'):
    """Run the net once per dataset entry and accumulate a confusion matrix.

    Args:
        net: Network object exposing ``forward()`` and a ``blobs`` mapping
            that contains ``layer``, ``gt`` and ``'loss'``.
        save_dir: Directory for predicted label images, or a falsy value to
            skip saving. It is created (with parents) if it does not exist.
        dataset: Sized iterable of entry names. Each name is only used to
            build the saved file name ``<name>.png``, so names must be
            strings when ``save_dir`` is set.
            NOTE(review): presumably the net's own data layer steps through
            the same entries in the same order -- confirm against the caller.
        layer: Name of the score blob; the prediction is the argmax over its
            first axis for item 0 of the batch.
        gt: Name of the ground-truth blob; ``data[0, 0]`` is used as labels.

    Returns:
        Tuple ``(hist, mean_loss)``: the ``(n_cl, n_cl)`` confusion matrix
        summed over all entries and the loss averaged over ``len(dataset)``.

    Raises:
        ZeroDivisionError: If ``dataset`` is empty.
    """
    n_cl = net.blobs[layer].channels
    # os.mkdir failed when the directory was left over from an earlier run or
    # when its parent was missing; only create what is not there yet.
    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    hist = np.zeros((n_cl, n_cl))
    loss = 0
    for idx in dataset:
        net.forward()
        # Compute the prediction once; it feeds both the histogram and the
        # optional saved image.
        pred = net.blobs[layer].data[0].argmax(0)
        hist += fast_hist(net.blobs[gt].data[0, 0].flatten(),
                          pred.flatten(),
                          n_cl)
        if save_dir:
            im = Image.fromarray(pred.astype(np.uint8), mode='P')
            im.save(os.path.join(save_dir, idx + '.png'))
        # compute the loss as well
        loss += net.blobs['loss'].data.flat[0]
    return hist, loss / len(dataset)
def seg_tests(solver, save_format, dataset, layer='score', gt='label'):
    """Score the solver's first test net using the current training weights.

    Args:
        solver: Solver object exposing ``net``, ``test_nets`` and ``iter``.
        save_format: Passed through to ``do_seg_tests``.
        dataset: Passed through to ``do_seg_tests``.
        layer: Name of the score blob, passed through to ``do_seg_tests``.
        gt: Name of the ground-truth blob, passed through to ``do_seg_tests``.
    """
    # A single pre-formatted argument prints the same text under Python 2 and
    # Python 3, unlike the print statement.
    print('>>> %s Begin seg tests' % (datetime.now(),))
    # Make the test net use the training net's weights before evaluating.
    solver.test_nets[0].share_with(solver.net)
    do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt)
def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'):
    """Evaluate ``net`` on ``dataset`` and print segmentation metrics.

    Prints, in order: mean loss, overall accuracy, mean per-class accuracy,
    mean IU and frequency-weighted IU ('fwavacc').

    Args:
        net: Network passed to ``compute_hist``.
        iter: Iteration number, used in the printed lines and substituted
            into ``save_format``.
        save_format: Format string for the output directory; it is expanded
            with ``save_format.format(iter)``. A falsy value disables saving.
        dataset: Dataset passed to ``compute_hist``.
        layer: Name of the score blob.
        gt: Name of the ground-truth blob.

    Returns:
        The confusion matrix returned by ``compute_hist``.
    """
    if save_format:
        save_format = save_format.format(iter)
    hist, loss = compute_hist(net, save_format, dataset, layer, gt)

    def report(label, value):
        # One pre-formatted argument keeps the output identical on Python 2
        # and Python 3; %s applies str() just as the print statement did.
        print('>>> %s Iteration %s %s %s' % (datetime.now(), iter, label, value))

    # mean loss
    report('loss', loss)

    diag = np.diag(hist)
    row_sum = hist.sum(1)
    col_sum = hist.sum(0)
    total = hist.sum()
    # Classes with no pixels give 0/0 here; the resulting NaNs are meant to be
    # skipped by nanmean below, so silence the floating-point warnings only.
    with np.errstate(divide='ignore', invalid='ignore'):
        overall_acc = diag.sum() / total
        class_acc = diag / row_sum
        iu = diag / (row_sum + col_sum - diag)
        freq = row_sum / total

    # overall accuracy
    report('overall accuracy', overall_acc)
    # per-class accuracy
    report('mean accuracy', np.nanmean(class_acc))
    # per-class IU
    report('mean IU', np.nanmean(iu))
    # IU weighted by how often each class occurs in the reference labels
    present = freq > 0
    report('fwavacc', (freq[present] * iu[present]).sum())
    return hist