-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathviz_utils.py
273 lines (191 loc) · 10.3 KB
/
viz_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import os
import numpy as np
import matplotlib.pyplot as plt
import skimage, skimage.io
import tarfile
def plot_data(data, relevance, query = None, retrieved = None, ax = None):
    """ Plots a 2-dimensional dataset.

    Samples are coloured by relevance and retrieval status: relevant samples are blue ('b') and
    irrelevant ones gray, unless they are retrieved, in which case they are cyan ('c') and orange,
    respectively. Query points are drawn in red.

    # Arguments:
    - data: n-by-2 data array (or array-like) of n 2-dimensional samples.
    - relevance: vector of length n specifying the relevance of the samples
                 (entries equal to 1 are considered relevant).
    - query: optionally, a query vector with 2 elements, or an m-by-2 array of m queries.
    - retrieved: optionally, a list of indices of retrieved samples.
    - ax: the Axis instance to draw the plot on. If `None` (or the pyplot module itself), a new
          pyplot figure is created and shown at the end.
    """
    if ax is None:
        ax = plt
    # Identity check, not `==`: an Axes (or an accidental array of axes) must never be compared by
    # value against a module.
    standalone = ax is plt

    data = np.asarray(data)
    if retrieved is None:
        retrieved = []
    not_retrieved = np.setdiff1d(np.arange(len(data)), retrieved)

    is_relevant = np.asarray(relevance) == 1
    colors = np.where(is_relevant, 'b', 'gray')
    colors_ret = np.where(is_relevant, 'c', 'orange')

    if standalone:
        plt.figure(figsize = (6.25, 5))
    ax.scatter(data[not_retrieved, 0], data[not_retrieved, 1], c = colors[not_retrieved], s = 15)
    if len(retrieved) > 0:
        ax.scatter(data[retrieved, 0], data[retrieved, 1], c = colors_ret[retrieved], s = 15)
    if query is not None:
        # Promote a single query vector to a 1-by-2 array so both forms are plotted the same way.
        query = np.atleast_2d(query)
        ax.scatter(query[:, 0], query[:, 1], c = 'r', s = 15)
    if standalone:
        plt.show()
def plot_distribution(data, prob, query = None, ax = None):
    """ Plots the estimated relevance scores of a 2-dimensional dataset.

    Each sample is coloured with the viridis colormap according to its score, min-max scaled
    to [0,1] over the whole dataset. Query points are drawn in red.

    # Arguments:
    - data: n-by-2 data array (or array-like) of n 2-dimensional samples.
    - prob: vector of length n containing the estimated relevance scores of the samples.
    - query: optionally, a query vector with 2 elements, or an m-by-2 array of m queries.
    - ax: the Axis instance to draw the plot on. If `None` (or the pyplot module itself), a new
          pyplot figure is created and shown at the end.
    """
    if ax is None:
        ax = plt
    # Identity check, not `==`, against the pyplot module.
    standalone = ax is plt
    if standalone:
        plt.figure(figsize = (6.25, 5))

    data = np.asarray(data)
    prob = np.asarray(prob, dtype = float)
    prob_min, prob_max = prob.min(), prob.max()
    if prob_max > prob_min:
        scaled = (prob - prob_min) / (prob_max - prob_min)
    else:
        # All scores are equal: scaling would divide by zero and produce NaNs, which the colormap
        # draws with its "bad" colour (transparent by default), hiding every point.
        scaled = np.full_like(prob, 0.5)
    ax.scatter(data[:, 0], data[:, 1], c = plt.cm.viridis(scaled), s = 15)

    if query is not None:
        # Promote a single query vector to a 1-by-2 array so both forms are plotted the same way.
        query = np.atleast_2d(query)
        ax.scatter(query[:, 0], query[:, 1], c = 'r', s = 15)
    if standalone:
        plt.show()
def plot_dist_and_topk(data, relevance, prob, query = None, k = 25):
    """ Shows, side by side, the relevance-score distribution and the top-k samples of a 2-d dataset.

    The left panel colours every sample by its score (see `plot_distribution`); the right panel
    highlights the k highest-scoring samples (see `plot_data`).

    # Arguments:
    - data: n-by-2 data array of n 2-dimensional samples.
    - relevance: vector of length n with the ground-truth relevance of the samples
                 (entries equal to 1 are considered relevant).
    - prob: vector of length n containing the estimated relevance scores of the samples.
    - query: optionally, a query vector with 2 elements.
    - k: how many of the highest-scoring samples to highlight.
    """
    # Ascending argsort reversed gives indices in order of decreasing score.
    ranking = np.argsort(prob)[::-1]
    top_k = ranking[:k]

    _, (dist_ax, data_ax) = plt.subplots(1, 2, figsize = (14, 5))
    plot_distribution(data, prob, query, ax = dist_ax)
    plot_data(data, relevance, query, top_k, ax = data_ax)
    plt.show()
def _draw_image_cell(ax, img):
    """ Draws `img` (any input accepted by `canonicalize_image`) on `ax`, titled with its file name. """
    ax.imshow(canonicalize_image(img), interpolation = 'bicubic', cmap = plt.cm.gray)
    ax.set_title(canonicalize_img_name(img), fontsize = 6)


def _plot_learning_step_images(dataset, queries, learner, ret, fb):
    """ Image-grid view of an active learning step: queries, current batch, top test retrievals. """
    cols = max([10, len(queries), len(ret)])
    fig, axes = plt.subplots(6, cols, figsize = (cols, 6))

    # Row 0: the query image(s).
    for query, ax in zip(queries, axes[0]):
        _draw_image_cell(ax, dataset.imgs_train[query])

    # Row 1: the current batch; the frame colour encodes the feedback
    # (black: none, green: relevant, red: irrelevant).
    for r, f, ax in zip(ret, fb, axes[1]):
        _draw_image_cell(ax, dataset.imgs_train[r])
        frame_color = 'k' if f == 0 else ('g' if f > 0 else 'r')
        for side in ('bottom', 'top', 'left', 'right'):
            ax.spines[side].set_color(frame_color)

    # Rows 2-5: test images ranked by decreasing `learner.gp.predict` score, one per remaining cell.
    ret_axes = axes[2:].ravel()
    top_ret = np.argsort(learner.gp.predict(dataset.X_test_norm))[::-1][:len(ret_axes)]
    for r, ax in zip(top_ret, ret_axes):
        _draw_image_cell(ax, dataset.imgs_test[r])

    # Booleans, not 'off': a non-empty string is truthy to current matplotlib and would leave the
    # ticks visible.
    for ax in axes.ravel():
        ax.tick_params(left = False, bottom = False, labelleft = False, labelbottom = False)

    # Hide the cells that received no image (with several queries, the first len(queries) cells
    # of row 0 are in use).
    for ax in axes[0, len(queries):]:
        ax.axis('off')
    for ax in axes[1, len(ret):]:
        ax.axis('off')
    for ax in ret_axes[len(top_ret):]:
        ax.axis('off')

    fig.tight_layout()
    plt.show()


def _plot_learning_step_2d(dataset, queries, relevance, learner, ret, fb):
    """ Four-panel scatter view of an active learning step for 2-dimensional data. """
    fig, axes = plt.subplots(2, 2, figsize = (10, 7))
    axes[0,0].set_title('Active Learning Batch')
    axes[0,1].set_title('Labelled Examples')
    axes[1,0].set_title('Relevance Distribution')
    axes[1,1].set_title('Retrieval')

    query_points = dataset.X_train[queries]
    # Batch samples that actually received feedback (0 means "no feedback").
    labelled = [r for r, f in zip(ret, fb) if f != 0]
    # Retrieve as many samples as there are relevant ones in the ground truth.
    num_relevant = int(np.sum(np.asarray(relevance) > 0))
    top_ret = np.argsort(learner.rel_mean)[::-1][:num_relevant]

    plot_data(dataset.X_train, relevance, query_points, ret, axes[0,0])
    plot_data(dataset.X_train, relevance, query_points, labelled, axes[0,1])
    plot_distribution(dataset.X_train, learner.rel_mean, query_points, axes[1,0])
    plot_data(dataset.X_train, relevance, query_points, top_ret, axes[1,1])
    fig.tight_layout()
    plt.show()


def plot_learning_step(dataset, queries, relevance, learner, ret, fb):
    """ Plots and shows a single active learning step.

    The output of this function differs depending on the type of data:
    - If the dataset provides a non-`None` `imgs_train` attribute, this function will plot the query image(s),
      the images in the current active learning batch (framed black/green/red for no/relevant/irrelevant
      feedback), and the top images of `dataset.imgs_test` ranked by `learner.gp.predict(dataset.X_test_norm)`.
    - If the dataset otherwise contains 2-dimensional data, this will show the 2-d plots of the
      current active learning batch, the samples annotated by the user, the estimated relevance
      scores (`learner.rel_mean`) of the entire dataset, and the top samples retrieved using those
      scores (as many as there are relevant samples).

    # Arguments:
    - dataset: a datasets.Dataset instance.
    - queries: the index of the query image in dataset.X_train (may also be a list of query indices).
    - relevance: the ground-truth relevance labels of all samples in dataset.X_train.
    - learner: an ital.retrieval_base.ActiveRetrievalBase instance.
    - ret: the indices of the samples selected for the current active learning batch.
    - fb: a list of feedback provided for each sample in the current batch. Possible feedback values
          are -1 (irrelevant), 1 (relevant) or 0 (no feedback).

    # Raises:
    - RuntimeError: if the dataset has neither images nor 2-dimensional samples.
    """
    # np.ndim() == 0 also catches NumPy integer scalars, which are not instances of `int`.
    if np.ndim(queries) == 0:
        queries = [queries]
    # getattr: the image attribute is optional per the contract above.
    if getattr(dataset, 'imgs_train', None) is not None:
        _plot_learning_step_images(dataset, queries, learner, ret, fb)
    elif dataset.X_train.shape[1] == 2:
        _plot_learning_step_2d(dataset, queries, relevance, learner, ret, fb)
    else:
        raise RuntimeError("Don't know how to plot this dataset.")
def plot_regression_step(dataset, init, learner, ret, fb):
    """ Plots and shows a single active regression step for 2-dimensional data.

    Three panels are drawn side by side:
    1. 'Active Learning Batch': all training samples, with those in `ret` highlighted and the
       initial samples `init` marked in red.
    2. 'Labelled Examples': training samples coloured by `dataset.y_train`, with the batch samples
       whose feedback is not `None` marked in red.
    3. 'Relevance Distribution': training samples coloured by `learner.mean`.

    # Arguments:
    - dataset: a datasets.RegressionDataset instance.
    - init: index or list of indices of the initial training samples in dataset.X_train.
    - learner: an ital.regression_base.ActiveRegressionBase instance.
    - ret: the indices of the samples selected for the current active learning batch.
    - fb: a list of feedback provided for each sample in the current batch (`None` = no feedback).

    # Raises:
    - RuntimeError: if the samples in `dataset.X_train` are not 2-dimensional.
    """
    if isinstance(init, int):
        init = [init]
    if dataset.X_train.shape[1] != 2:
        raise RuntimeError("Don't know how to plot this dataset.")

    X = dataset.X_train
    fig, (batch_ax, labelled_ax, model_ax) = plt.subplots(1, 3, figsize = (12, 4))
    batch_ax.set_title('Active Learning Batch')
    labelled_ax.set_title('Labelled Examples')
    model_ax.set_title('Relevance Distribution')

    # Every sample gets relevance 0, so colours only separate batch members from the rest.
    plot_data(X, [0] * len(X), X[init], ret, batch_ax)
    annotated = [sample for pos, sample in enumerate(ret) if fb[pos] is not None]
    plot_distribution(X, dataset.y_train, X[annotated], labelled_ax)
    # A (0, 2) query array: no query markers on the model panel.
    plot_distribution(X, learner.mean, np.zeros((0, 2)), model_ax)
    fig.tight_layout()
    plt.show()
def canonicalize_image(img, color = True, channels_first = False):
    """ Converts an image to the canonical format, i.e., a float32 `numpy.ndarray` with shape HxWx3,
    where the last axis represents RGB tuples with values in [0,1].

    If `color` is set to `False`, the last axis will be of size 1 (the mean over the RGB channels
    when the input has colour).
    If `channels_first` is set to True, the channel axis will be the first instead of the last axis.

    `img` can be one of the following:
    - a `HxW` `numpy.ndarray` giving a grayscale image,
    - a `HxWx1` `numpy.ndarray` giving a grayscale image with an explicit channel axis,
    - a `HxWx3` `numpy.ndarray` giving a color image,
    - a `HxWx4` `numpy.ndarray` giving a color image with alpha channel (alpha is dropped),
    - a string giving the filename of the image,
    - a tuple (or list) consisting of the path to a tarfile and either the name of the member or a
      corresponding `tarfile.TarInfo` object.

    Files are read with `skimage.io.imread`, requesting only the first frame (`img_num = 0`).

    # Raises:
    - ValueError: if a tar member is not a regular file.
    """
    # `as_gray` (scikit-image >= 0.14) replaces the removed `as_grey` spelling.
    if isinstance(img, str):
        img = skimage.io.imread(img, as_gray = not color, img_num = 0)
    elif isinstance(img, (tuple, list)) and (len(img) == 2):
        with tarfile.open(img[0]) as tf:
            member = tf.extractfile(img[1])
            # extractfile() returns None for directories, links, etc.
            if member is None:
                raise ValueError('Tar member {!r} in {!r} is not a regular file.'.format(img[1], img[0]))
            img = skimage.io.imread(member, as_gray = not color, img_num = 0)

    img = skimage.img_as_float(img).astype(np.float32, copy = False)

    # Bring every input to HxWxC with C in {1, 3}.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    elif img.shape[2] == 4:
        img = img[:, :, :3]

    # Adjust the channel count to the requested mode; a single-channel input is replicated
    # to RGB whether it arrived as HxW or HxWx1.
    if color and (img.shape[2] == 1):
        img = np.tile(img, (1, 1, 3))
    elif (not color) and (img.shape[2] == 3):
        img = np.mean(img, axis = 2, keepdims = True)

    if channels_first:
        img = np.transpose(img, (2, 0, 1))
    return img
def canonicalize_img_name(img):
    """ Returns the filename of an image without directory path and file extension.

    # Arguments:
    - img: Either a path to a file, or a tuple/list of length 2 (tarfile path, member) whose second
           element is the member's path inside the archive or its `tarfile.TarInfo` object.

    # Returns:
        the filename without directory path and file extension, or `None` if `img` is of any
        other form (e.g., a `numpy.ndarray` holding pixel data).
    """
    if isinstance(img, (tuple, list)) and (len(img) == 2):
        # Only the member part names the image; the archive path is irrelevant here.
        img = img[1]
        if isinstance(img, tarfile.TarInfo):
            img = img.name
    if isinstance(img, str):
        # basename also strips directories inside tar archives ('/'-separated member names).
        return os.path.splitext(os.path.basename(img))[0]
    return None