-
Notifications
You must be signed in to change notification settings - Fork 477
/
extract.py
39 lines (32 loc) · 1.21 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import print_function
import os
import sys
import tarfile
import numpy as np
# from six.moves.urllib.request import urlretrieve
# from six.moves import cPickle as pickle
# Number of class folders each extracted archive is expected to contain;
# maybe_extract() raises if the count it finds differs from this value.
num_classes = 10
# Seed NumPy's global random generator with a fixed value so that any later
# np.random calls produce repeatable results.
np.random.seed(133)
def maybe_extract(filename, force=False, expected_classes=None):
    """Extract a ``.tar.gz`` archive and return its per-class sub-folders.

    The archive is expected to unpack into a single directory named after
    the archive itself (``foo.tar.gz`` -> ``foo``) located next to it, and
    that directory is expected to hold one sub-directory per class.

    Args:
        filename: Path to the ``.tar.gz`` archive.
        force: When true, extract even if the target directory already
            exists. Otherwise an existing directory is reused as-is.
        expected_classes: Number of class folders that must be present.
            Defaults to the module-level ``num_classes``.

    Returns:
        Sorted list of paths (``<root>/<name>``) of the directories found
        directly inside the extracted root directory.

    Raises:
        Exception: If the number of folders found differs from the
            expected class count.
    """
    if expected_classes is None:
        expected_classes = num_classes

    # Strip both extensions: 'x.tar.gz' -> 'x.tar' -> 'x'.
    root = os.path.splitext(os.path.splitext(filename)[0])[0]
    if os.path.isdir(root) and not force:
        # You may override by setting force=True.
        print('%s already present - Skipping extraction of %s.' % (root, filename))
    else:
        print('Extracting data for %s. This may take a while. Please wait.' % root)
        # Push the message out before the long-running extraction starts.
        sys.stdout.flush()
        # Unpack next to the archive so that `root` exists afterwards even
        # when `filename` carries a directory component; a bare file name
        # still extracts into the current directory as before.
        extract_dir = os.path.dirname(root) or os.curdir
        # Archive member paths are untrusted input. Where the running Python
        # supports extraction filters, use the 'data' filter to reject
        # absolute paths, '..' escapes and special files.
        extract_kwargs = {}
        if hasattr(tarfile, 'data_filter'):
            extract_kwargs['filter'] = 'data'
        # The context manager closes the archive even if extraction fails.
        with tarfile.open(filename) as tar:
            tar.extractall(extract_dir, **extract_kwargs)

    data_folders = [
        os.path.join(root, d) for d in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, d))]
    if len(data_folders) != expected_classes:
        raise Exception(
            'Expected %d folders, one per class. Found %d instead.' % (
                expected_classes, len(data_folders)))
    print(data_folders)
    return data_folders
if __name__ == '__main__':
    # Script entry point: unpack the large (training) archive first, then
    # the small (test) archive, keeping each call's list of class folders.
    train_folders, test_folders = map(
        maybe_extract, ('notMNIST_large.tar.gz', 'notMNIST_small.tar.gz'))