# DenseNet_Efficient.py
# This implementation is a new efficient implementation of Densenet-BC,
# as described in "Memory-Efficient Implementation of DenseNets"
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
from operator import mul
from collections import OrderedDict
from torch.autograd import Variable, Function
from torch._thnn import type2backend
from torch.backends import cudnn
# I'm throwing all the gross code at the end of the file :)
# Let's start with the nice (and interesting) stuff


class _SharedAllocation(object):
    """
    A helper class which maintains a shared memory allocation.
    Used for concatenation and batch normalization.
    """
    def __init__(self, storage):
        self.storage = storage

    def type(self, t):
        self.storage = self.storage.type(t)

    def type_as(self, obj):
        if isinstance(obj, Variable):
            self.storage = self.storage.type(obj.data.storage().type())
        elif isinstance(obj, torch._TensorBase):
            self.storage = self.storage.type(obj.storage().type())
        else:
            self.storage = self.storage.type(obj.type())

    def resize_(self, size):
        if self.storage.size() < size:
            self.storage.resize_(size)
        return self
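

# Note (illustrative): _SharedAllocation simply wraps a raw torch.Storage.
# _DenseBlock below creates two of these, retypes them to match the incoming
# tensor with type_as(), and grows them with resize_() so that every layer in
# the block writes its concatenation / batch-norm output into the same buffer.
# A hypothetical usage would look like:
#
#     alloc = _SharedAllocation(torch.Storage(1024))
#     alloc.type_as(some_cuda_variable)   # move the storage to the input's type/device
#     alloc.resize_(2 ** 20)              # grows (never shrinks) the backing storage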


class _EfficientDensenetBottleneck(nn.Module):
    """
    An optimized layer which encapsulates the batch normalization, ReLU, and
    convolution operations within the bottleneck of a DenseNet layer.

    This layer uses shared memory allocations to store the outputs of the
    concatenation and batch normalization features. Because the shared memory
    is not permanent, these features are recomputed during the backward pass.
    """
    def __init__(self, shared_allocation_1, shared_allocation_2, num_input_channels, num_output_channels):
        super(_EfficientDensenetBottleneck, self).__init__()
        self.shared_allocation_1 = shared_allocation_1
        self.shared_allocation_2 = shared_allocation_2
        self.num_input_channels = num_input_channels

        self.norm_weight = nn.Parameter(torch.Tensor(num_input_channels))
        self.norm_bias = nn.Parameter(torch.Tensor(num_input_channels))
        self.register_buffer('norm_running_mean', torch.zeros(num_input_channels))
        self.register_buffer('norm_running_var', torch.ones(num_input_channels))
        self.conv_weight = nn.Parameter(torch.Tensor(num_output_channels, num_input_channels, 1, 1))
        self._reset_parameters()

    def _reset_parameters(self):
        self.norm_weight.data.uniform_()
        self.norm_bias.data.zero_()
        self.norm_running_mean.zero_()
        self.norm_running_var.fill_(1)
        stdv = 1. / math.sqrt(self.num_input_channels)
        self.conv_weight.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        if isinstance(inputs, Variable):
            inputs = [inputs]

        fn = _EfficientDensenetBottleneckFn(self.shared_allocation_1, self.shared_allocation_2,
                                            self.norm_running_mean, self.norm_running_var,
                                            stride=1, padding=0, dilation=1, groups=1,
                                            training=self.training, momentum=0.1, eps=1e-5)
        return fn(self.norm_weight, self.norm_bias, self.conv_weight, *inputs)
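

# For scale: with the defaults used below (growth_rate=12, bn_size=4), this
# bottleneck maps however many channels a layer has accumulated (216 for the
# 17th layer of a block that starts at 24 channels) down to
# bn_size * growth_rate = 48 channels before the 3x3 convolution in _DenseLayer.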


class _DenseLayer(nn.Sequential):
    def __init__(self, shared_allocation_1, shared_allocation_2, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.shared_allocation_1 = shared_allocation_1
        self.shared_allocation_2 = shared_allocation_2
        self.drop_rate = drop_rate

        self.add_module('bn', _EfficientDensenetBottleneck(shared_allocation_1, shared_allocation_2,
                                                           num_input_features, bn_size * growth_rate))
        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu.2', nn.ReLU(inplace=True))
        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                            kernel_size=3, stride=1, padding=1, bias=False))

    def forward(self, x):
        if isinstance(x, Variable):
            prev_features = [x]
        else:
            prev_features = x
        new_features = super(_DenseLayer, self).forward(prev_features)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class _DenseBlock(nn.Container):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, storage_size=1024):
        input_storage_1 = torch.Storage(storage_size)
        input_storage_2 = torch.Storage(storage_size)
        self.final_num_features = num_input_features + (growth_rate * num_layers)
        self.shared_allocation_1 = _SharedAllocation(input_storage_1)
        self.shared_allocation_2 = _SharedAllocation(input_storage_2)

        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(self.shared_allocation_1, self.shared_allocation_2,
                                num_input_features + i * growth_rate,
                                growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, x):
        # Update storage type
        self.shared_allocation_1.type_as(x)
        self.shared_allocation_2.type_as(x)

        # Resize storage
        final_size = list(x.size())
        final_size[1] = self.final_num_features
        final_storage_size = reduce(mul, final_size, 1)
        self.shared_allocation_1.resize_(final_storage_size)
        self.shared_allocation_2.resize_(final_storage_size)

        outputs = [x]
        for module in self.children():
            outputs.append(module.forward(outputs))
        return torch.cat(outputs, dim=1)
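

# Sizing example: a block with num_input_features=24, growth_rate=12 and
# num_layers=16 ends with final_num_features = 24 + 12 * 16 = 216 channels.
# For a batch of 64 CIFAR-sized feature maps (64 x 216 x 32 x 32), each shared
# storage is therefore resized to 64 * 216 * 32 * 32 = 14,155,776 elements and
# then reused by every layer in the block.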


class DenseNetEfficient(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    This model uses shared memory allocations for the outputs of batch norm and
    concat operations, as described in `"Memory-Efficient Implementation of DenseNets"`.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of ints) - how many layers in each pooling block
        compression (float) - factor by which transition layers reduce the number of features
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for the number of bottleneck features
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        cifar (bool) - if True, use the small-input (CIFAR) stem; otherwise use the ImageNet stem
    """
    def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5,
                 num_init_features=24, bn_size=4, drop_rate=0,
                 num_classes=10, cifar=True):
        super(DenseNetEfficient, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'
        self.avgpool_size = 8 if cifar else 7

        # First convolution
        if cifar:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)),
            ]))
        else:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ]))
            self.features.add_module('norm0', nn.BatchNorm2d(num_init_features))
            self.features.add_module('relu0', nn.ReLU(inplace=True))
            self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                                           ceil_mode=False))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view(
            features.size(0), -1)
        out = self.classifier(out)
        return out


# Begin gross code :/
# Here's where we define the internals of the efficient bottleneck layer


class _EfficientDensenetBottleneckFn(Function):
    """
    The autograd function which performs the efficient bottleneck operations.
    Each of the sub-operations -- concatenation, batch normalization, ReLU,
    and convolution -- is abstracted into its own class.
    """
    def __init__(self, shared_allocation_1, shared_allocation_2,
                 running_mean, running_var,
                 stride=1, padding=0, dilation=1, groups=1,
                 training=False, momentum=0.1, eps=1e-5):
        self.efficient_cat = _EfficientCat(shared_allocation_1.storage)
        self.efficient_batch_norm = _EfficientBatchNorm(shared_allocation_2.storage, running_mean, running_var,
                                                        training, momentum, eps)
        self.efficient_relu = _EfficientReLU()
        self.efficient_conv = _EfficientConv2d(stride, padding, dilation, groups)

        # Buffers to store old versions of bn statistics
        self.prev_running_mean = self.efficient_batch_norm.running_mean.new()
        self.prev_running_mean.resize_as_(self.efficient_batch_norm.running_mean)
        self.prev_running_var = self.efficient_batch_norm.running_var.new()
        self.prev_running_var.resize_as_(self.efficient_batch_norm.running_var)
        self.curr_running_mean = self.efficient_batch_norm.running_mean.new()
        self.curr_running_mean.resize_as_(self.efficient_batch_norm.running_mean)
        self.curr_running_var = self.efficient_batch_norm.running_var.new()
        self.curr_running_var.resize_as_(self.efficient_batch_norm.running_var)

    def forward(self, bn_weight, bn_bias, conv_weight, *inputs):
        self.prev_running_mean.copy_(self.efficient_batch_norm.running_mean)
        self.prev_running_var.copy_(self.efficient_batch_norm.running_var)

        bn_input = self.efficient_cat.forward(*inputs)
        bn_output = self.efficient_batch_norm.forward(bn_weight, bn_bias, bn_input)
        relu_output = self.efficient_relu.forward(bn_output)
        conv_output = self.efficient_conv.forward(conv_weight, None, relu_output)

        self.bn_weight = bn_weight
        self.bn_bias = bn_bias
        self.conv_weight = conv_weight
        self.inputs = inputs
        return conv_output

    def backward(self, grad_output):
        # Turn off bn training status, and temporarily reset statistics
        training = self.efficient_batch_norm.training
        self.curr_running_mean.copy_(self.efficient_batch_norm.running_mean)
        self.curr_running_var.copy_(self.efficient_batch_norm.running_var)
        # self.efficient_batch_norm.training = False
        self.efficient_batch_norm.running_mean.copy_(self.prev_running_mean)
        self.efficient_batch_norm.running_var.copy_(self.prev_running_var)

        # Recompute concat and BN
        cat_output = self.efficient_cat.forward(*self.inputs)
        bn_output = self.efficient_batch_norm.forward(self.bn_weight, self.bn_bias, cat_output)
        relu_output = self.efficient_relu.forward(bn_output)

        # Conv backward
        conv_weight_grad, _, conv_grad_output = self.efficient_conv.backward(
            self.conv_weight, None, relu_output, grad_output)

        # ReLU backward
        relu_grad_output = self.efficient_relu.backward(bn_output, conv_grad_output)

        # BN backward
        self.efficient_batch_norm.running_mean.copy_(self.curr_running_mean)
        self.efficient_batch_norm.running_var.copy_(self.curr_running_var)
        bn_weight_grad, bn_bias_grad, bn_grad_output = self.efficient_batch_norm.backward(
            self.bn_weight, self.bn_bias, cat_output, relu_grad_output)

        # Input backward
        grad_inputs = self.efficient_cat.backward(bn_grad_output)

        # Reset bn training status and statistics
        self.efficient_batch_norm.training = training
        self.efficient_batch_norm.running_mean.copy_(self.curr_running_mean)
        self.efficient_batch_norm.running_var.copy_(self.curr_running_var)

        return tuple([bn_weight_grad, bn_bias_grad, conv_weight_grad] + list(grad_inputs))


# The following helper classes are written similarly to PyTorch autograd
# functions. However, they are designed to work on tensors, not Variables,
# and therefore are not autograd Functions.


class _EfficientBatchNorm(object):
    def __init__(self, storage, running_mean, running_var,
                 training=False, momentum=0.1, eps=1e-5):
        self.storage = storage
        self.running_mean = running_mean
        self.running_var = running_var
        self.training = training
        self.momentum = momentum
        self.eps = eps

    def forward(self, weight, bias, input):
        # Assert we're using cudnn
        for i in [weight, bias, input]:
            if i is not None and not cudnn.is_acceptable(i):
                raise Exception('You must be using CUDNN to use _EfficientBatchNorm')

        # Create save variables
        self.save_mean = self.running_mean.new()
        self.save_mean.resize_as_(self.running_mean)
        self.save_var = self.running_var.new()
        self.save_var.resize_as_(self.running_var)

        # Do forward pass - store in input variable
        res = type(input)(self.storage)
        res.resize_as_(input)
        torch._C._cudnn_batch_norm_forward(
            input, res, weight, bias, self.running_mean, self.running_var,
            self.save_mean, self.save_var, self.training, self.momentum, self.eps
        )
        return res

    def recompute_forward(self, weight, bias, input):
        # Do forward pass - store in input variable
        res = type(input)(self.storage)
        res.resize_as_(input)
        torch._C._cudnn_batch_norm_forward(
            input, res, weight, bias, self.running_mean, self.running_var,
            self.save_mean, self.save_var, self.training, self.momentum, self.eps
        )
        return res

    def backward(self, weight, bias, input, grad_output):
        # Create grad variables
        grad_weight = weight.new()
        grad_weight.resize_as_(weight)
        grad_bias = bias.new()
        grad_bias.resize_as_(bias)

        # Run backwards pass - result stored in grad_output
        grad_input = grad_output
        torch._C._cudnn_batch_norm_backward(
            input, grad_output, grad_input, grad_weight, grad_bias,
            weight, self.running_mean, self.running_var, self.save_mean,
            self.save_var, self.training, self.eps
        )

        # Unpack grad_output
        res = tuple([grad_weight, grad_bias, grad_input])
        return res


class _EfficientCat(object):
    def __init__(self, storage):
        self.storage = storage

    def forward(self, *inputs):
        # Get size of new variable
        self.all_num_channels = [input.size(1) for input in inputs]
        size = list(inputs[0].size())
        for num_channels in self.all_num_channels[1:]:
            size[1] += num_channels

        # Create variable, using existing storage
        res = type(inputs[0])(self.storage).resize_(size)
        torch.cat(inputs, dim=1, out=res)
        return res

    def backward(self, grad_output):
        # Return a table of tensors pointing to same storage
        res = []
        index = 0
        for num_channels in self.all_num_channels:
            new_index = num_channels + index
            res.append(grad_output[:, index:new_index])
            index = new_index
        return tuple(res)


class _EfficientReLU(object):
    def __init__(self):
        pass

    def forward(self, input):
        backend = type2backend[type(input)]
        output = input
        backend.Threshold_updateOutput(backend.library_state, input, output, 0, 0, True)
        return output

    def backward(self, input, grad_output):
        grad_input = grad_output
        grad_input.masked_fill_(input <= 0, 0)
        return grad_input


class _EfficientConv2d(object):
    def __init__(self, stride=1, padding=0, dilation=1, groups=1):
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

    def _output_size(self, input, weight):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = self.padding
            kernel = self.dilation * (weight.size(d + 2) - 1) + 1
            stride = self.stride
            output_size += ((in_size + (2 * pad) - kernel) // stride + 1,)
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError("convolution input is too small (output would be {})".format(
                'x'.join(map(str, output_size))))
        return output_size
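
    # Worked example (from the values used in this file): the class is only
    # instantiated for the 1x1 bottleneck convolution (stride=1, padding=0,
    # dilation=1), so for a 32x32 input each spatial dimension gives
    # kernel = 1 * (1 - 1) + 1 = 1 and (32 + 0 - 1) // 1 + 1 = 32, meaning the
    # spatial size is preserved and only the channel count (weight.size(0)) changes.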

    def forward(self, weight, bias, input):
        # Assert we're using cudnn
        for i in [weight, bias, input]:
            if i is not None and not cudnn.is_acceptable(i):
                raise Exception('You must be using CUDNN to use _EfficientConv2d')

        res = input.new(*self._output_size(input, weight))
        self._cudnn_info = torch._C._cudnn_convolution_full_forward(
            input, weight, bias, res,
            (self.padding, self.padding),
            (self.stride, self.stride),
            (self.dilation, self.dilation),
            self.groups, cudnn.benchmark
        )
        return res

    def backward(self, weight, bias, input, grad_output):
        grad_input = input.new()
        grad_input.resize_as_(input)
        torch._C._cudnn_convolution_backward_data(
            grad_output, grad_input, weight, self._cudnn_info,
            cudnn.benchmark)

        grad_weight = weight.new().resize_as_(weight)
        torch._C._cudnn_convolution_backward_filter(grad_output, input, grad_weight, self._cudnn_info,
                                                    cudnn.benchmark)

        if bias is not None:
            grad_bias = bias.new().resize_as_(bias)
            torch._C._cudnn_convolution_backward_bias(grad_output, grad_bias, self._cudnn_info)
        else:
            grad_bias = None

        return grad_weight, grad_bias, grad_input
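

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original implementation): builds the
# default CIFAR configuration defined above and runs one forward/backward pass.
# It assumes a CUDA device with cuDNN available, since _EfficientBatchNorm and
# _EfficientConv2d refuse to run without cuDNN, and it uses the legacy
# Variable-based PyTorch API that the rest of this file targets.
if __name__ == '__main__':
    model = DenseNetEfficient(growth_rate=12, block_config=(16, 16, 16),
                              num_init_features=24, num_classes=10, cifar=True).cuda()
    inputs = Variable(torch.randn(4, 3, 32, 32).cuda())          # CIFAR-sized batch
    targets = Variable(torch.LongTensor(4).random_(10).cuda())   # dummy labels
    outputs = model(inputs)                                      # (4, 10) class scores
    loss = F.cross_entropy(outputs, targets)
    loss.backward()
    print(outputs.size(), loss.data[0])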