forked from kevinzakka/tsne-viz
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vis_utils.py
89 lines (73 loc) · 2.58 KB
/
vis_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os

import numpy as np
import matplotlib.pyplot as plt
# Human-readable class names, indexed by integer class label.
# (These appear to be the Fashion-MNIST classes — confirm against the dataset.)
label_names = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandals',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boots',
]
def plot_9by9(images, cls_true, save_path=None):
    """Plot nine images in a 3x3 grid, labelling each with its class.

    Despite the name, the layout is 3 rows by 3 columns (nine images total).

    Parameters
    ----------
    images : sequence of 9 images
        Each element is passed directly to ``Axes.imshow`` with the reversed
        grey colormap.
    cls_true : sequence of 9 ints
        Class indices into the module-level ``label_names`` list.
    save_path : str, optional
        If given, the figure is written to this path as a PNG (dpi=1000)
        before being shown. Saving is opt-in; nothing is written by default.

    Raises
    ------
    ValueError
        If ``images`` and ``cls_true`` do not both contain exactly 9 items.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not len(images) == len(cls_true) == 9:
        raise ValueError(
            "plot_9by9 expects exactly 9 images and 9 labels, got "
            "{0} and {1}".format(len(images), len(cls_true))
        )
    _, axes = plt.subplots(3, 3)
    for ax, img, cls in zip(axes.flat, images, cls_true):
        ax.imshow(img, cmap="Greys_r")
        # Show the human-readable class name alongside the raw class index.
        ax.set_xlabel("{0} ({1})".format(label_names[cls], cls))
        ax.set_xticks([])
        ax.set_yticks([])
    if save_path is not None:
        plt.savefig(save_path, format='png', dpi=1000)
    plt.show()
def visualize_grid(Xs, ubound=255.0, padding=1):
    """
    Reshape a 4D tensor of image data to a grid for easy visualization.

    Each image is independently min-max rescaled to [0, ubound] and placed
    row-major into a square grid with ``ceil(sqrt(N))`` cells per side.
    Unused cells, the padding between cells, and constant images (which have
    no range to rescale) are left at 0.

    Inputs
    ------
    - Xs: Data of shape (N, H, W, C)
    - ubound: Output grid will have values scaled to the range [0, ubound]
    - padding: The number of blank pixels between elements of the grid

    Returns
    -------
    - grid: float array of shape
      (G*H + padding*(G-1), G*W + padding*(G-1), C) with G = ceil(sqrt(N));
      shape (0, 0, C) when N == 0.

    References
    ----------
    - Adapted from CS231n - http://cs231n.github.io/
    """
    (N, H, W, C) = Xs.shape
    grid_size = int(np.ceil(np.sqrt(N)))
    # max(..., 0) keeps the dimensions non-negative when N == 0 (grid_size == 0).
    gaps = padding * max(grid_size - 1, 0)
    grid = np.zeros((H * grid_size + gaps, W * grid_size + gaps, C))
    for idx in range(N):
        # Row-major placement: image idx goes to cell (row, col).
        row, col = divmod(idx, grid_size)
        y0 = row * (H + padding)
        x0 = col * (W + padding)
        img = Xs[idx]
        low, high = np.min(img), np.max(img)
        if high > low:
            grid[y0:y0 + H, x0:x0 + W] = ubound * (img - low) / (high - low)
        # Otherwise the image is constant: leave its cell at 0 rather than
        # dividing by zero (which would fill the cell with NaN/inf).
    return grid
def view_images(X, ubound=1.0, save=False, name='',
                save_dir='/Users/kevin/Desktop/'):
    """Quick helper function to view rgb or gray images.

    The images are tiled with ``visualize_grid`` (each rescaled to
    [0, ubound]) and displayed with matplotlib.

    Parameters
    ----------
    X : np.ndarray
        Either a stack of grayscale images of shape (N, H, W), or a batch of
        shape (N, H, W, C). Single-channel grids are shown with the reversed
        grey colormap.
    ubound : float
        Upper bound of the rescaled pixel range passed to ``visualize_grid``.
    save : bool
        If True, save the figure as a PNG (dpi=1000) before showing it.
    name : str
        File name used when ``save`` is True.
    save_dir : str
        Directory joined with ``name`` when ``save`` is True. The default is
        the previously hard-coded location, kept for backward compatibility.

    Raises
    ------
    ValueError
        If ``X`` is neither 3D nor 4D.
    """
    if X.ndim == 3:
        # (N, H, W) grayscale stack -> add a singleton channel axis so it
        # matches the (N, H, W, C) layout visualize_grid expects.
        X = X[..., np.newaxis]
    elif X.ndim != 4:
        raise ValueError(
            "view_images expects a 3D (N, H, W) or 4D (N, H, W, C) array, "
            "got shape {0}".format(X.shape)
        )
    grid = visualize_grid(X, ubound)
    if grid.shape[-1] == 1:
        # Drop the channel axis: some matplotlib versions reject (H, W, 1)
        # image arrays, and a 2D array lets the grey colormap apply.
        plt.imshow(grid[..., 0], cmap="Greys_r")
    else:
        plt.imshow(grid)
    if save:
        plt.savefig(os.path.join(save_dir, name), format='png', dpi=1000)
    plt.show()