-
Notifications
You must be signed in to change notification settings - Fork 3
/
encode_dataset.py
49 lines (45 loc) · 1.36 KB
/
encode_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python
import os
from config import Config
from utils import py_utils
from argparse import ArgumentParser
from ops.data_to_tfrecords import data_to_tfrecords
def encode_dataset(dataset):
    """Encode the named dataset into a TFRecords file.

    Loads the dataset module ``dataset`` via ``py_utils.import_module``,
    instantiates its ``data_processing`` class, gathers the files/labels and
    encoding settings it exposes, and hands them to ``data_to_tfrecords``.
    The output is written under ``config.tf_records`` using the
    ``output_name`` attribute of the data_processing object.

    Args:
        dataset: Name of the dataset module to load (forwarded unchanged to
            ``py_utils.import_module``).

    The data_processing object must provide ``get_data()``, ``targets``,
    ``im_size``, ``preprocess`` and ``output_name``. These attributes are
    optional and default as shown when absent:
    ``label_size`` -> None, ``store_z`` -> False, ``normalize_im`` -> False.
    """
    config = Config()
    data_class = py_utils.import_module(dataset)
    data_proc = data_class.data_processing()
    files, labels = data_proc.get_data()
    targets = data_proc.targets
    im_size = data_proc.im_size
    preproc_list = data_proc.preprocess

    # Optional per-dataset settings. Each attribute is looked up under its own
    # name. Previously store_z was guarded by a check on 'label_size', which
    # crashed or silently ignored store_z.
    label_size = getattr(data_proc, 'label_size', None)
    store_z = getattr(data_proc, 'store_z', False)
    normalize_im = getattr(data_proc, 'normalize_im', False)

    ds_name = os.path.join(config.tf_records, data_proc.output_name)
    data_to_tfrecords(
        files=files,
        labels=labels,
        targets=targets,
        ds_name=ds_name,
        im_size=im_size,
        label_size=label_size,
        preprocess=preproc_list,
        store_z=store_z,
        normalize_im=normalize_im)
if __name__ == '__main__':
    # Command-line entry point: encode the dataset named by --dataset.
    parser = ArgumentParser()
    parser.add_argument(
        '--dataset',
        dest='dataset',
        # Without required=True an omitted --dataset defaults to None, which
        # would be passed straight on to py_utils.import_module.
        required=True,
        help='Name of the dataset.')
    args = parser.parse_args()
    encode_dataset(**vars(args))