-
Notifications
You must be signed in to change notification settings - Fork 477
/
Copy pathmerge_prune.py
103 lines (81 loc) · 3.97 KB
/
merge_prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import print_function
import os
import numpy as np
# from six.moves.urllib.request import urlretrieve
# from six.moves import cPickle as pickle
import cPickle as pickle
from not_mnist.img_pickle import maybe_pickle, save_obj
# Width and height, in pixels, of every image; make_arrays allocates square
# (image_size x image_size) slots per sample.
image_size = 28 # Pixel width and height.
def make_arrays(nb_rows, img_size):
    """Allocate uninitialised storage for `nb_rows` square images and labels.

    Args:
        nb_rows: number of samples to allocate. A falsy value (0 or None)
            means "no dataset" and yields ``(None, None)``.
        img_size: width and height, in pixels, of each image.

    Returns:
        ``(dataset, labels)``: a float32 array of shape
        ``(nb_rows, img_size, img_size)`` and an int32 array of shape
        ``(nb_rows,)``. Neither array is initialised; callers must fill them.
    """
    if not nb_rows:
        return None, None
    images = np.empty((nb_rows, img_size, img_size), dtype=np.float32)
    targets = np.empty(nb_rows, dtype=np.int32)
    return images, targets
def merge_datasets(pickle_files, train_size, valid_size=0):
    """Merge per-class pickled image arrays into labelled train/validation sets.

    Each file in `pickle_files` is unpickled, shuffled in place, and split:
    its first rows go to the validation set and the following rows to the
    training set. A sample's label is the index of its file in `pickle_files`.

    Requested sizes are spread as evenly as possible over the classes. When a
    size is not divisible by the number of classes, the first
    ``size % num_classes`` classes contribute one extra row, so every row of
    the returned arrays is actually filled with data.

    Args:
        pickle_files: paths of pickle files, one per class. Each is assumed to
            hold an array of shape (n, image_size, image_size); the slicing
            below requires rank 3 (TODO confirm against the pickling code).
        train_size: total number of training rows to produce.
        valid_size: total number of validation rows; 0 means no validation set.

    Returns:
        ``(valid_dataset, valid_labels, train_dataset, train_labels)``. The
        validation pair is ``(None, None)`` when `valid_size` is 0.

    Raises:
        ValueError: if `pickle_files` is empty, or a class file holds fewer
            rows than it must contribute.
        Any error raised while processing a file is printed, then re-raised.
    """
    num_classes = len(pickle_files)
    if num_classes == 0:
        raise ValueError('merge_datasets() needs at least one pickle file')
    valid_dataset, valid_labels = make_arrays(valid_size, image_size)
    train_dataset, train_labels = make_arrays(train_size, image_size)
    # Even per-class share, plus how many leading classes take one extra row.
    v_base, v_extra = divmod(valid_size, num_classes)
    t_base, t_extra = divmod(train_size, num_classes)
    start_v, start_t = 0, 0
    for label, pickle_file in enumerate(pickle_files):
        n_valid = v_base + (1 if label < v_extra else 0)
        n_train = t_base + (1 if label < t_extra else 0)
        try:
            with open(pickle_file, 'rb') as f:
                letter_set = pickle.load(f)
            if len(letter_set) < n_valid + n_train:
                raise ValueError(
                    'class %d has %d samples but %d are required'
                    % (label, len(letter_set), n_valid + n_train))
            # Shuffle so the validation/training split is random within each class.
            np.random.shuffle(letter_set)
            if valid_dataset is not None:  # None when valid_size == 0 (e.g. test set)
                valid_dataset[start_v:start_v + n_valid, :, :] = letter_set[:n_valid, :, :]
                valid_labels[start_v:start_v + n_valid] = label
                start_v += n_valid
            if train_dataset is not None:  # None when train_size == 0
                train_dataset[start_t:start_t + n_train, :, :] = \
                    letter_set[n_valid:n_valid + n_train, :, :]
                train_labels[start_t:start_t + n_train] = label
                start_t += n_train
        except Exception as e:
            print('Unable to process data from', pickle_file, ':', e)
            raise
    return valid_dataset, valid_labels, train_dataset, train_labels
def randomize(dataset, labels):
    """Shuffle `dataset` and `labels` together with one random permutation.

    Works for a dataset of any rank: only the first axis (samples) is
    permuted, so row ``i`` of the returned dataset still matches entry ``i``
    of the returned labels.

    Args:
        dataset: array whose first axis indexes samples.
        labels: 1-D array with one entry per sample.

    Returns:
        ``(shuffled_dataset, shuffled_labels)`` as new arrays.

    Raises:
        ValueError: if `dataset` and `labels` have different sample counts.
    """
    if dataset.shape[0] != labels.shape[0]:
        raise ValueError(
            'dataset has %d samples but labels has %d'
            % (dataset.shape[0], labels.shape[0]))
    permutation = np.random.permutation(labels.shape[0])
    # Integer-array indexing on axis 0 permutes samples, whatever the rank.
    return dataset[permutation], labels[permutation]
def merge_prune(train_folders, test_folders, train_size=200000,
                valid_size=10000, test_size=10000,
                pickle_file='notMNIST.pickle'):
    """Build shuffled train/validation/test splits and save them to one pickle.

    The training and validation sets are both drawn from `train_folders`; the
    test set comes from `test_folders`. Each folder list is passed through
    `maybe_pickle`, merged with `merge_datasets`, shuffled with `randomize`,
    and the six resulting arrays are written with `save_obj`.

    Args:
        train_folders: one folder per class for training/validation data.
        test_folders: one folder per class for test data, in the same class
            order as `train_folders` (labels are positions in the list).
        train_size: total number of training rows (default 200000).
        valid_size: total number of validation rows (default 10000).
        test_size: total number of test rows (default 10000).
        pickle_file: output path handed to `save_obj`
            (default 'notMNIST.pickle').
    """
    # NOTE(review): maybe_pickle presumably returns one pickle path per class,
    # and its second argument looks like a minimum image count per class.
    # Confirm both against not_mnist.img_pickle.
    train_datasets = maybe_pickle(train_folders, 45000)
    test_datasets = maybe_pickle(test_folders, 1800)

    valid_dataset, valid_labels, train_dataset, train_labels = merge_datasets(
        train_datasets, train_size, valid_size)
    # No validation split for the test data, so the first pair is discarded.
    _, _, test_dataset, test_labels = merge_datasets(test_datasets, test_size)

    print('Training:', train_dataset.shape, train_labels.shape)
    print('Validation:', valid_dataset.shape, valid_labels.shape)
    print('Testing:', test_dataset.shape, test_labels.shape)

    # merge_datasets emits samples grouped by class; shuffle across classes.
    train_dataset, train_labels = randomize(train_dataset, train_labels)
    test_dataset, test_labels = randomize(test_dataset, test_labels)
    valid_dataset, valid_labels = randomize(valid_dataset, valid_labels)

    save = {
        'train_dataset': train_dataset,
        'train_labels': train_labels,
        'valid_dataset': valid_dataset,
        'valid_labels': valid_labels,
        'test_dataset': test_dataset,
        'test_labels': test_labels,
    }
    save_obj(pickle_file, save)
if __name__ == "__main__":
    # One folder per class letter A..J; the list order defines the labels.
    class_letters = 'ABCDEFGHIJ'
    train_folders = ['notMNIST_large/' + letter for letter in class_letters]
    test_folders = ['notMNIST_small/' + letter for letter in class_letters]
    merge_prune(train_folders, test_folders)