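"""
main_general.py

Trains the neuralNet0 ('ti unet') model on the selected training set for a
range of filter counts k, optionally weighting the loss for class imbalance,
and evaluates the predicted class distribution as a function of the
prediction threshold.
"""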
import numpy as np
import matplotlib.pyplot as plt

from data.conversion_tools import img2array, batch2img
from datasets.default_trainingsets import get_13botleftshuang, get_19SE_shuang
from methods.basic import Threshholding, local_thresholding
from methods.examples import neuralNet0
from neuralNetwork.optimization import find_learning_rate
from performance.testing import optimal_test_thresh, filter_non_zero, get_y_pred_thresh
from plotting import concurrent
from preprocessing.image import get_class_weights, get_class_imbalance, get_flow


def get_training_data(train_data):
    """Return the train/test inputs and targets from a training-set object
    (as returned by get_19SE_shuang / get_13botleftshuang)."""
    x = train_data.get_x_train()
    y_tr = train_data.get_y_train()
    x_te = train_data.get_x_test()
    y_te = train_data.get_y_test()

    return x, y_tr, x_te, y_te


if __name__ == '__main__':
    ### Settings
    dataset_name = '19_hand_SE'     # '19_hand_SE', or anything else to use the 13_botleft set
    mod = 5             # modality selection: 'clean', 'all', or 5 (everything except UVF)
    b_imbalance = True  # reintroduce class imbalance through the loss weights

    verbose = 0
    epochs = 40         # 40 seems to be right

    b_opt_lr = False    # TODO watch out for flag setting

    ## amount of filters
    k_lst = [1, 2, 4, 8, 16, 32]
    # k_lst = [1, 2, 4, 8, 16, 32, 64, 128]
    # k_lst = [16, 32, 64]  # 3x3 conv
    # k_lst = np.arange(16, 52, 2)
    k_lst = np.arange(1, 21, 1)     # NOTE: this overrides the list above
    # k_lst = [5]   # Test new class imbalance

    ### Data
    if 0:
        # Disabled exploratory path: it relies on helpers (get_19hand, annotations2y,
        # y2bool_annot, panel19withoutRightBot) that are not imported in this file.
        a = get_19hand()
        b = False
        if b:
            a.plot()

        ### Training/Validation data
        img_y = a.get('annot')
        y = annotations2y(img_y)
        y_annot = y2bool_annot(y)

        b = False
        if b:
            y_annot_tr, y_annot_te = panel19withoutRightBot(y_annot)

            concurrent([a.get('clean'), y_annot, y_annot_tr, y_annot_te],
                       ['clean', 'annotation', 'train annot', 'test annot'])

    if 0:
        # get_train19_topleft is likewise not imported here; branch kept disabled.
        train_data = get_train19_topleft(mod=mod)
    elif dataset_name == '19_hand_SE':
        train_data = get_19SE_shuang(mod=mod)
    else:
        train_data = get_13botleftshuang(mod=mod)

    # TODO normalise the inputs; this seems to be super important...
    # train_data.x = (1/255. * train_data.x).astype(np.float16)
    # train_data.x = (255. * train_data.x).astype(np.float16)

    x, y_tr, x_te, y_te = get_training_data(train_data)

    # A throw-away network instance, used only to obtain the required extension width w_ext.
    w_ext = neuralNet0(mod=mod, k=1, verbose=1).w_ext

    flow_tr = get_flow(x[0], y_tr[0],
                       w_patch=10,      # patch width
                       w_ext_in=w_ext
                       )
    flow_te = get_flow(x_te[0], y_te[0],
                       w_patch=10,      # patch width
                       w_ext_in=w_ext
                       )

    b = 1
    class_weight = (1, 1)
    if b:
        # Balance the data: scale the per-class weights by those estimated from the training flow.
        class_weight_tr = get_class_weights(flow_tr)
        class_weight = tuple(c_i * c_j for c_i, c_j in zip(class_weight, class_weight_tr))

    if b_imbalance:
        # Reintroduce class imbalance so that the network is trained knowing the classes are imbalanced.
        class_imbalance_te = get_class_imbalance(flow_te)

        b_geometric_mean = False
        if b_geometric_mean:
            # Act as if the class imbalance (n1/n0) is only (n1/n0)**.5,
            # i.e. (f1'/f0') is only (f1/f0)**.5.
            def f_i_geometric_mean(f_i):
                return 1 / ((1 / f_i - 1) ** .5 + 1)

            geometric_class_imbalance = tuple(map(f_i_geometric_mean, class_imbalance_te))
            class_weight_geometric = tuple(2. * f_i for f_i in geometric_class_imbalance)
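            # Worked example (hypothetical numbers): for a minority-class fraction f_1 = 0.1
            # the odds are n1/n0 = 1/9; the mapping gives f_1' = 1/(9**.5 + 1) = 0.25, whose
            # odds 0.25/0.75 = 1/3 are indeed (1/9)**.5.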
        else:
            # Introduce the full class imbalance through the weights.
            class_weight_geometric = tuple(2. * f_i for f_i in class_imbalance_te)

        class_weight = tuple(c_i * c_j for c_i, c_j in zip(class_weight, class_weight_geometric))

    print(f'final class_weight: {class_weight}')

    for k in k_lst:
        print(f'\n\tk = {k}')

        n = neuralNet0(mod=mod, k=k, verbose=verbose, class_weights=class_weight)

        if b_opt_lr:
            ### Finding the optimal learning rate
            lr_opt = find_learning_rate(n.get_model(), flow_tr, class_weight=class_weight, verbose=verbose)
        else:
            # Hard-coded learning rates from earlier experiments; only the last assignment is used.
            lr_opt = 1e-0   # fully connected
            lr_opt = 1e-1   # CNN + batchnorm
            lr_opt = 1e0    # ti unet
            lr_opt = 1e-3   # ti unet + NADAM
            lr_opt = 5e-3   # ti unet + NADAM (class imbalanced!)
        print(f'Optimal expected learning rate: {lr_opt}')

        # Rebuild the network with the chosen learning rate and train it.
        n = neuralNet0(mod=mod, lr=lr_opt, k=k, verbose=verbose, class_weights=class_weight)

        info = f'{dataset_name}/ti_unet_k{k}'
        if b_imbalance:
            info += '_imbalanced'
        n.train(flow_tr, flow_te, epochs=epochs, verbose=verbose, info=info)

    # Predict with the last trained network (k = k_lst[-1]); the classical thresholding
    # branch is kept disabled (it needs the exploratory image `a` from above).
    b = False
    if b:
        # Model
        t = Threshholding()
        t.method = local_thresholding

        o = t.predict(a.get('clean'))
    else:
        x_img = img2array(x)
        y_pred = n.predict(x_img)
        o = y_pred[..., 1]

    b = False
    if b:
        # Plotting the results (also needs `a` from the disabled exploratory block).
        concurrent([a.get('clean'), o], ['clean', 'prediction'])

    ### Evaluation
    if 1:
        test_thresh = optimal_test_thresh(y_pred, y_tr, y_te, verbose=verbose, d_thresh=.01)
    else:
        test_thresh = .96
    o2 = np.greater_equal(o, test_thresh)

    if 0:
        concurrent([a.get('clean'), o, o2], ['clean', 'prediction', f'thresh {test_thresh}'], verbose=verbose)

    ### Test the class imbalance
    y_pred_img, y_te_img = map(batch2img, (y_pred, y_te))
    y_te_filter, y_pred_filter = filter_non_zero(y_te_img, y_pred_img)

    def class_distribution(y):
        """Return the per-class fraction of a 2-D (samples, classes) array."""
        assert len(y.shape) == 2, y.shape

        n_01 = np.sum(y, axis=0)
        f_01 = n_01 / sum(n_01)
        return f_01
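
    # Worked example of class_distribution (hypothetical input):
    #   class_distribution(np.array([[1, 0], [1, 0], [0, 1], [1, 0]])) -> [0.75, 0.25]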
    f_01_te = class_distribution(y_te_filter)
    f_01_pred_50 = class_distribution(get_y_pred_thresh(y_pred_filter, thresh=.5))
    f_01_pred_85 = class_distribution(get_y_pred_thresh(y_pred_filter, thresh=.85))

    d_thresh = 0.01
    thresh_lst = np.arange(d_thresh, 1, d_thresh)
    f_01_lst = [class_distribution(get_y_pred_thresh(y_pred_filter, thresh=thresh_i)) for thresh_i in thresh_lst]
    f_1_lst = list(zip(*f_01_lst))[1]
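    # f_1_lst: the predicted class-1 fraction at each threshold, which can be compared
    # with the annotated test-set fraction f_01_te[1] computed above.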

    if 0:
        plt.figure()
        plt.plot(thresh_lst, f_1_lst)
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.title('Predicted class distribution versus prediction threshold')
        plt.show()

    print('Done')