-
Notifications
You must be signed in to change notification settings - Fork 10
/
DALIDataLoader.py
273 lines (231 loc) · 12.5 KB
/
DALIDataLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
'''
A data loader built on NVIDIA DALI that speeds up data loading in PyTorch.
Ref: https://github.com/d-li14/mobilenetv2.pytorch/blob/master/utils/dataloaders.py
https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py
'''
import os
import torch
import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from math import ceil
try:
from nvidia.dali.plugin.pytorch import DALIClassificationIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
except ImportError:
print("Please install DALI from https://www.github.com/NVIDIA/DALI to run DataLoader.")
class TinyImageNetHybridTrainPipe(Pipeline):
    '''DALI training pipeline used by the Tiny-ImageNet train loader.

    Files are read from ``data_dir`` (one shard per distributed rank when
    ``torch.distributed`` is initialized), decoded to RGB, passed through a
    random resized crop, and finally cropped / mirrored / normalized on the
    GPU into NCHW float output.

    Args:
        batch_size, num_threads, device_id, seed: forwarded positionally to
            ``Pipeline.__init__``.
        data_dir: used as ``file_root`` of the file reader.
        crop: output size of the random resized crop and side of the
            ``(crop, crop)`` window given to ``CropMirrorNormalize``.
        dali_cpu: when true, decoding and the random resized crop run on
            ``'cpu'`` instead of ``'mixed'`` / ``'gpu'``; normalization is
            always placed on ``'gpu'``.
    '''

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
        super(TinyImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed)

        # One reader shard per rank; a single shard when not running distributed.
        distributed = torch.distributed.is_initialized()
        local_rank = torch.distributed.get_rank() if distributed else 0
        world_size = torch.distributed.get_world_size() if distributed else 1

        reader_options = dict(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
            shuffle_after_epoch=True,
        )
        self.input = ops.FileReader(**reader_options)

        # Pick the devices for decoding and for the resize/crop stage.
        if dali_cpu:
            dali_device, decoder_device = 'cpu', 'cpu'
        else:
            dali_device, decoder_device = 'gpu', 'mixed'

        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        self.res = ops.RandomResizedCrop(
            device=dali_device,
            size=crop,
            random_aspect_ratio=[0.75, 4./3],
            random_area=[0.08, 1.0],
            num_attempts=100,
            interp_type=types.INTERP_TRIANGULAR,
        )
        # Per-channel statistics are written on a 0-1 scale and multiplied by
        # 255 (presumably the usual ImageNet mean/std -- confirm if changed).
        self.cmnp = ops.CropMirrorNormalize(
            device='gpu',
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=[channel * 255 for channel in (0.485, 0.456, 0.406)],
            std=[channel * 255 for channel in (0.229, 0.224, 0.225)],
        )
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        '''Wire the operators together and return ``[images, labels]``.'''
        mirror_flags = self.coin()
        self.jpegs, self.labels = self.input(name='Reader')
        decoded = self.decode(self.jpegs)
        cropped = self.res(decoded)
        # .gpu() hands the data to the GPU normalization stage even when the
        # earlier stages were placed on the CPU.
        normalized = self.cmnp(cropped.gpu(), mirror=mirror_flags)
        return [normalized, self.labels]
class TinyImageNetHybridValPipe(Pipeline):
    '''DALI validation pipeline used by the Tiny-ImageNet val loader.

    Files are read from ``data_dir`` without shuffling (one shard per
    distributed rank when ``torch.distributed`` is initialized), decoded to
    RGB with the ``'mixed'`` decoder, then cropped to ``(crop, crop)`` and
    normalized on the GPU into NCHW float output. No mirror argument is
    passed, so no coin-flip augmentation is applied here.

    Args:
        batch_size, num_threads, device_id, seed: forwarded positionally to
            ``Pipeline.__init__``.
        data_dir: used as ``file_root`` of the file reader.
        crop: side of the ``(crop, crop)`` window given to
            ``CropMirrorNormalize``.
    '''

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed):
        super(TinyImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed)

        # One reader shard per rank; a single shard when not running distributed.
        distributed = torch.distributed.is_initialized()
        local_rank = torch.distributed.get_rank() if distributed else 0
        world_size = torch.distributed.get_world_size() if distributed else 1

        reader_options = dict(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
        )
        self.input = ops.FileReader(**reader_options)
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        # Per-channel statistics are written on a 0-1 scale and multiplied by
        # 255 (presumably the usual ImageNet mean/std -- confirm if changed).
        self.cmnp = ops.CropMirrorNormalize(
            device='gpu',
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=[channel * 255 for channel in (0.485, 0.456, 0.406)],
            std=[channel * 255 for channel in (0.229, 0.224, 0.225)],
        )

    def define_graph(self):
        '''Wire the operators together and return ``[images, labels]``.'''
        self.jpegs, self.labels = self.input(name='Reader')
        decoded = self.decode(self.jpegs)
        normalized = self.cmnp(decoded)
        return [normalized, self.labels]
class ImageNetHybridTrainPipe(Pipeline):
    '''DALI training pipeline used by the ImageNet train loader.

    Files are read from ``data_dir`` (one shard per distributed rank when
    ``torch.distributed`` is initialized), decoded to RGB, passed through a
    random resized crop, and finally cropped / mirrored / normalized on the
    GPU into NCHW float output.

    Args:
        batch_size, num_threads, device_id: forwarded positionally to
            ``Pipeline.__init__``.
        data_dir: used as ``file_root`` of the file reader.
        crop: output size of the random resized crop and side of the
            ``(crop, crop)`` window given to ``CropMirrorNormalize``.
        seed: forwarded to ``Pipeline.__init__`` as the ``seed`` keyword.
        dali_cpu: when true, decoding and the random resized crop run on
            ``'cpu'`` instead of ``'mixed'`` / ``'gpu'``; normalization is
            always placed on ``'gpu'``.
    '''

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed, dali_cpu=False):
        super(ImageNetHybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed)

        # One reader shard per rank; a single shard when not running distributed.
        distributed = torch.distributed.is_initialized()
        local_rank = torch.distributed.get_rank() if distributed else 0
        world_size = torch.distributed.get_world_size() if distributed else 1

        reader_options = dict(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
            shuffle_after_epoch=True,
        )
        self.input = ops.FileReader(**reader_options)

        # Pick the devices for decoding and for the resize/crop stage.
        if dali_cpu:
            dali_device, decoder_device = 'cpu', 'cpu'
        else:
            dali_device, decoder_device = 'gpu', 'mixed'

        # This padding sets the size of the internal nvJPEG buffers so they can
        # handle all images from full-sized ImageNet without additional
        # reallocations. It is only requested for the 'mixed' decoder.
        if decoder_device == 'mixed':
            device_memory_padding, host_memory_padding = 211025920, 140544512
        else:
            device_memory_padding = host_memory_padding = 0

        # An alternative that fuses decoding and random cropping
        # (ImageDecoderRandomCrop followed by Resize) was sketched here in the
        # past; the pipeline uses a plain decode + RandomResizedCrop instead.
        self.decode = ops.ImageDecoder(
            device=decoder_device,
            output_type=types.RGB,
            device_memory_padding=device_memory_padding,
            host_memory_padding=host_memory_padding,
        )
        self.res = ops.RandomResizedCrop(
            device=dali_device,
            size=crop,
            random_aspect_ratio=[0.75, 4./3],
            random_area=[0.08, 1.0],
            num_attempts=100,
            interp_type=types.INTERP_TRIANGULAR,
        )
        # Per-channel statistics are written on a 0-1 scale and multiplied by
        # 255 (presumably the usual ImageNet mean/std -- confirm if changed).
        self.cmnp = ops.CropMirrorNormalize(
            device='gpu',
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=[channel * 255 for channel in (0.485, 0.456, 0.406)],
            std=[channel * 255 for channel in (0.229, 0.224, 0.225)],
        )
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        '''Wire the operators together and return ``[images, labels]``.'''
        mirror_flags = self.coin()
        self.jpegs, self.labels = self.input(name='Reader')
        decoded = self.decode(self.jpegs)
        cropped = self.res(decoded)
        # .gpu() hands the data to the GPU normalization stage even when the
        # earlier stages were placed on the CPU.
        normalized = self.cmnp(cropped.gpu(), mirror=mirror_flags)
        return [normalized, self.labels]
class ImageNetHybridValPipe(Pipeline):
    '''DALI validation pipeline used by the ImageNet val loader.

    Files are read from ``data_dir`` without shuffling (one shard per
    distributed rank when ``torch.distributed`` is initialized), decoded to
    RGB with the ``'mixed'`` decoder, resized so the shorter side equals
    ``size``, then cropped to ``(crop, crop)`` and normalized on the GPU into
    NCHW float output.

    Args:
        batch_size, num_threads, device_id: forwarded positionally to
            ``Pipeline.__init__``.
        data_dir: used as ``file_root`` of the file reader.
        crop: side of the ``(crop, crop)`` window given to
            ``CropMirrorNormalize``.
        size: passed as ``resize_shorter`` to the resize operator.
        seed: forwarded to ``Pipeline.__init__`` as the ``seed`` keyword.
    '''

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed):
        super(ImageNetHybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed)

        # One reader shard per rank; a single shard when not running distributed.
        distributed = torch.distributed.is_initialized()
        local_rank = torch.distributed.get_rank() if distributed else 0
        world_size = torch.distributed.get_world_size() if distributed else 1

        reader_options = dict(
            file_root=data_dir,
            shard_id=local_rank,
            num_shards=world_size,
            pad_last_batch=True,
            random_shuffle=False,
        )
        self.input = ops.FileReader(**reader_options)
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        self.res = ops.Resize(
            device='gpu',
            resize_shorter=size,
            interp_type=types.INTERP_TRIANGULAR,
        )
        # Per-channel statistics are written on a 0-1 scale and multiplied by
        # 255 (presumably the usual ImageNet mean/std -- confirm if changed).
        self.cmnp = ops.CropMirrorNormalize(
            device='gpu',
            output_dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            image_type=types.RGB,
            mean=[channel * 255 for channel in (0.485, 0.456, 0.406)],
            std=[channel * 255 for channel in (0.229, 0.224, 0.225)],
        )

    def define_graph(self):
        '''Wire the operators together and return ``[images, labels]``.'''
        self.jpegs, self.labels = self.input(name='Reader')
        decoded = self.decode(self.jpegs)
        resized = self.res(decoded)
        normalized = self.cmnp(resized)
        return [normalized, self.labels]
class DALIWrapper(object):
    '''Iterable adapter that turns DALI iterator output into ``(input, target)`` pairs.

    The wrapped object is iterated afresh on every ``iter()`` call; each item
    it yields is indexed as ``item[0]['data']`` and ``item[0]['label']``.

    Args:
        dali_pipeline: any iterable producing such items (in this file, a
            ``DALIClassificationIterator``).
    '''

    @staticmethod
    def gen_wrapper(dali_pipeline):
        '''Yield ``(images, target)`` for every item of ``dali_pipeline``.

        ``target`` is the label tensor with its trailing singleton axis
        removed, moved to the device of ``images`` and cast to int64.

        Declared as a static method so it can be called both as
        ``DALIWrapper.gen_wrapper(it)`` and on an instance.
        '''
        for data in dali_pipeline:
            images = data[0]['data']
            # squeeze(-1) instead of a bare squeeze(): a bare squeeze would
            # also drop the batch axis when the batch holds a single sample,
            # yielding a 0-dim target (assumes labels arrive shaped
            # (batch, 1) -- TODO confirm against the DALI version in use).
            # The label follows the images' device rather than the current
            # CUDA device, so both always live on the same device.
            target = data[0]['label'].squeeze(-1).to(images.device).long()
            yield images, target

    def __init__(self, dali_pipeline):
        self.dali_pipeline = dali_pipeline

    def __iter__(self):
        return DALIWrapper.gen_wrapper(self.dali_pipeline)
def get_dali_tinyImageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
    '''Build the DALI-backed Tiny-ImageNet training loader.

    Args:
        data_path: dataset root; images are read from its ``train`` subfolder.
        batch_size, num_threads, seed, dali_cpu: forwarded to
            ``TinyImageNetHybridTrainPipe``.

    Returns:
        ``(loader, steps)``: a ``DALIWrapper`` around the DALI iterator and
        ``ceil(epoch_size / (world_size * batch_size))``.
    '''
    distributed = torch.distributed.is_initialized()
    local_rank = torch.distributed.get_rank() if distributed else 0
    world_size = torch.distributed.get_world_size() if distributed else 1

    # NOTE(review): device_id is the value of get_rank(); on multi-node runs
    # this may not match a local GPU index -- confirm with the launcher.
    pipe = TinyImageNetHybridTrainPipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        data_dir=os.path.join(data_path, 'train'),
        crop=56,
        seed=seed,
        dali_cpu=dali_cpu,
    )
    pipe.build()

    samples_per_rank = int(pipe.epoch_size('Reader') / world_size)
    iterator = DALIClassificationIterator(
        pipe,
        size=samples_per_rank,
        fill_last_batch=False,
        last_batch_padded=True,
        auto_reset=True,
    )
    steps_per_epoch = ceil(pipe.epoch_size('Reader') / (world_size * batch_size))
    return DALIWrapper(iterator), steps_per_epoch
def get_dali_tinyImageNet_val_loader(data_path, batch_size, seed, num_threads=4):
    '''Build the DALI-backed Tiny-ImageNet validation loader.

    Args:
        data_path: dataset root; images are read from its ``val`` subfolder.
        batch_size, num_threads, seed: forwarded to
            ``TinyImageNetHybridValPipe``.

    Returns:
        ``(loader, steps)``: a ``DALIWrapper`` around the DALI iterator and
        ``ceil(epoch_size / (world_size * batch_size))``.
    '''
    distributed = torch.distributed.is_initialized()
    local_rank = torch.distributed.get_rank() if distributed else 0
    world_size = torch.distributed.get_world_size() if distributed else 1

    # NOTE(review): device_id is the value of get_rank(); on multi-node runs
    # this may not match a local GPU index -- confirm with the launcher.
    pipe = TinyImageNetHybridValPipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        data_dir=os.path.join(data_path, 'val'),
        crop=56,
        seed=seed,
    )
    pipe.build()

    samples_per_rank = int(pipe.epoch_size('Reader') / world_size)
    iterator = DALIClassificationIterator(
        pipe,
        size=samples_per_rank,
        fill_last_batch=False,
        last_batch_padded=True,
        auto_reset=True,
    )
    steps_per_epoch = ceil(pipe.epoch_size('Reader') / (world_size * batch_size))
    return DALIWrapper(iterator), steps_per_epoch
def get_dali_imageNet_train_loader(data_path, batch_size, seed, num_threads=4, dali_cpu=False):
    '''Build the DALI-backed ImageNet training loader.

    Args:
        data_path: dataset root; images are read from its
            ``ILSVRC2012_img_train`` subfolder.
        batch_size, num_threads, seed, dali_cpu: forwarded to
            ``ImageNetHybridTrainPipe``.

    Returns:
        ``(loader, steps)``: a ``DALIWrapper`` around the DALI iterator and
        ``ceil(epoch_size / (world_size * batch_size))``.
    '''
    distributed = torch.distributed.is_initialized()
    local_rank = torch.distributed.get_rank() if distributed else 0
    world_size = torch.distributed.get_world_size() if distributed else 1

    # NOTE(review): device_id is the value of get_rank(); on multi-node runs
    # this may not match a local GPU index -- confirm with the launcher.
    pipe = ImageNetHybridTrainPipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        data_dir=os.path.join(data_path, 'ILSVRC2012_img_train'),
        crop=224,
        seed=seed,
        dali_cpu=dali_cpu,
    )
    pipe.build()

    samples_per_rank = int(pipe.epoch_size('Reader') / world_size)
    iterator = DALIClassificationIterator(
        pipe,
        size=samples_per_rank,
        fill_last_batch=False,
        last_batch_padded=True,
        auto_reset=True,
    )
    steps_per_epoch = ceil(pipe.epoch_size('Reader') / (world_size * batch_size))
    return DALIWrapper(iterator), steps_per_epoch
def get_dali_imageNet_val_loader(data_path, batch_size, seed, num_threads=4):
    '''Build the DALI-backed ImageNet validation loader.

    Args:
        data_path: dataset root; images are read from its
            ``ILSVRC2012_img_val`` subfolder.
        batch_size, num_threads, seed: forwarded to
            ``ImageNetHybridValPipe`` (with ``crop=224`` and ``size=256``).

    Returns:
        ``(loader, steps)``: a ``DALIWrapper`` around the DALI iterator and
        ``ceil(epoch_size / (world_size * batch_size))``.
    '''
    distributed = torch.distributed.is_initialized()
    local_rank = torch.distributed.get_rank() if distributed else 0
    world_size = torch.distributed.get_world_size() if distributed else 1

    # NOTE(review): device_id is the value of get_rank(); on multi-node runs
    # this may not match a local GPU index -- confirm with the launcher.
    pipe = ImageNetHybridValPipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        data_dir=os.path.join(data_path, 'ILSVRC2012_img_val'),
        crop=224,
        size=256,
        seed=seed,
    )
    pipe.build()

    samples_per_rank = int(pipe.epoch_size('Reader') / world_size)
    iterator = DALIClassificationIterator(
        pipe,
        size=samples_per_rank,
        fill_last_batch=False,
        last_batch_padded=True,
        auto_reset=True,
    )
    steps_per_epoch = ceil(pipe.epoch_size('Reader') / (world_size * batch_size))
    return DALIWrapper(iterator), steps_per_epoch