-
Notifications
You must be signed in to change notification settings - Fork 314
/
models.py
executable file
·386 lines (316 loc) · 23 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
#!/usr/bin/env python
from collections import OrderedDict
import numpy as np
from scipy import ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt
import time
class reactive_net(nn.Module):
    """Push/grasp network evaluated over a set of input rotations (3-channel heads).

    Color and depth images go through four DenseNet-121 trunks pretrained on
    ImageNet: one color and one depth trunk each for pushing and grasping. For
    each primitive, the color and depth trunk features are concatenated along
    dim 1 and passed through a small 1x1-conv head that emits a 3-channel map.
    The meaning of the 3 output channels is defined by the training code, not
    here.
    """

    def __init__(self, use_cuda):  # , snapshot=None
        """Build the trunks and heads.

        Args:
            use_cuda: if true, forward() moves inputs and sampling grids to the GPU.
        """
        super(reactive_net, self).__init__()
        self.use_cuda = use_cuda

        # Initialize network trunks with DenseNet pre-trained on ImageNet
        self.push_color_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.push_depth_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.grasp_color_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.grasp_depth_trunk = torchvision.models.densenet.densenet121(pretrained=True)

        # Number of evenly spaced rotations evaluated by forward(is_volatile=True).
        self.num_rotations = 16

        # Construct network branches for pushing and grasping. The module names
        # ('push-norm0', ...) end up in state_dict keys, so keep them stable.
        # NOTE(review): 2048 input channels assumes each trunk's .features
        # yields 1024 channels (color + depth concatenated) - TODO confirm.
        self.pushnet = nn.Sequential(OrderedDict([
            ('push-norm0', nn.BatchNorm2d(2048)),
            ('push-relu0', nn.ReLU(inplace=True)),
            ('push-conv0', nn.Conv2d(2048, 64, kernel_size=1, stride=1, bias=False)),
            ('push-norm1', nn.BatchNorm2d(64)),
            ('push-relu1', nn.ReLU(inplace=True)),
            ('push-conv1', nn.Conv2d(64, 3, kernel_size=1, stride=1, bias=False))
            # ('push-upsample2', nn.Upsample(scale_factor=4, mode='bilinear'))
        ]))
        self.graspnet = nn.Sequential(OrderedDict([
            ('grasp-norm0', nn.BatchNorm2d(2048)),
            ('grasp-relu0', nn.ReLU(inplace=True)),
            ('grasp-conv0', nn.Conv2d(2048, 64, kernel_size=1, stride=1, bias=False)),
            ('grasp-norm1', nn.BatchNorm2d(64)),
            ('grasp-relu1', nn.ReLU(inplace=True)),
            ('grasp-conv1', nn.Conv2d(64, 3, kernel_size=1, stride=1, bias=False))
            # ('grasp-upsample2', nn.Upsample(scale_factor=4, mode='bilinear'))
        ]))

        # Initialize head weights. The 'push-'/'grasp-' filter matches only the
        # head modules named above (trunk attributes use '_', not '-'), so the
        # pretrained trunks keep their weights.
        for name, module in self.named_modules():
            if 'push-' in name or 'grasp-' in name:
                if isinstance(module, nn.Conv2d):
                    # kaiming_normal_ is the non-deprecated in-place initializer.
                    nn.init.kaiming_normal_(module.weight.data)
                elif isinstance(module, nn.BatchNorm2d):
                    module.weight.data.fill_(1)
                    module.bias.data.zero_()

        # Outputs of the most recent training-mode forward pass (for backprop)
        self.interm_feat = []
        self.output_prob = []

    def forward(self, input_color_data, input_depth_data, is_volatile=False, specific_rotation=-1):
        """Predict push and grasp maps for rotated copies of the inputs.

        Args:
            input_color_data: color image tensor. The rotation matrix is built
                with batch size 1, so a (1, C, H, W) tensor is assumed - TODO
                confirm; C is presumably 3 for the ImageNet DenseNet trunks.
            input_depth_data: depth image tensor, sampled with the grid built
                from the color image's size.
            is_volatile: if True, evaluate all ``num_rotations`` rotations under
                ``torch.no_grad()`` and return fresh lists without touching
                ``self.output_prob`` / ``self.interm_feat``. If False, evaluate
                only ``specific_rotation`` with autograd enabled and store the
                results on ``self``.
            specific_rotation: rotation index used when ``is_volatile`` is
                False; the angle is ``specific_rotation * 360 / num_rotations``
                degrees. The default -1 is evaluated as an index like any
                other, so callers presumably always pass one explicitly.

        Returns:
            ``(output_prob, interm_feat)``: lists with one entry per evaluated
            rotation. Each output entry is ``[push_map, grasp_map]`` and each
            feature entry is ``[interm_push_feat, interm_grasp_feat]``.
        """
        if is_volatile:
            output_prob = []
            interm_feat = []
            # Inference over every rotation needs no autograd graph. The legacy
            # Variable(volatile=True) flag is ignored since PyTorch 0.4, so
            # torch.no_grad() is what actually disables graph recording.
            with torch.no_grad():
                for rotate_idx in range(self.num_rotations):
                    rotation_prob, rotation_feat = self._forward_rotation(
                        input_color_data, input_depth_data, rotate_idx)
                    output_prob.append(rotation_prob)
                    interm_feat.append(rotation_feat)
            return output_prob, interm_feat

        # Training: evaluate a single rotation with autograd enabled and keep the
        # results on self so they stay available for backprop.
        self.output_prob = []
        self.interm_feat = []
        rotation_prob, rotation_feat = self._forward_rotation(
            input_color_data, input_depth_data, specific_rotation)
        self.interm_feat.append(rotation_feat)
        self.output_prob.append(rotation_prob)
        return self.output_prob, self.interm_feat

    def _forward_rotation(self, input_color_data, input_depth_data, rotate_idx):
        """Run trunks and heads on the inputs rotated by rotation index ``rotate_idx``.

        Returns:
            ``([push_prob, grasp_prob], [interm_push_feat, interm_grasp_feat])``.
        """
        # 360.0 keeps the angular step fractional even under Python 2 division.
        rotate_theta = np.radians(rotate_idx * (360.0 / self.num_rotations))

        # Sample the input images through a grid rotated by -theta.
        flow_grid_before = self._rotation_grid(-rotate_theta, input_color_data.size())
        rotate_color = F.grid_sample(self._prepare_input(input_color_data), flow_grid_before, mode='nearest')
        rotate_depth = F.grid_sample(self._prepare_input(input_depth_data), flow_grid_before, mode='nearest')

        # Compute intermediate features (color and depth concatenated on dim 1)
        interm_push_feat = torch.cat((self.push_color_trunk.features(rotate_color),
                                      self.push_depth_trunk.features(rotate_depth)), dim=1)
        interm_grasp_feat = torch.cat((self.grasp_color_trunk.features(rotate_color),
                                       self.grasp_depth_trunk.features(rotate_depth)), dim=1)

        # Rotate head outputs by +theta (the inverse of the input rotation) so
        # they align with the unrotated input, then upsample by 16.
        flow_grid_after = self._rotation_grid(rotate_theta, interm_push_feat.data.size())
        upsample = nn.Upsample(scale_factor=16, mode='bilinear')
        push_prob = upsample(F.grid_sample(self.pushnet(interm_push_feat), flow_grid_after, mode='nearest'))
        grasp_prob = upsample(F.grid_sample(self.graspnet(interm_grasp_feat), flow_grid_after, mode='nearest'))
        return [push_prob, grasp_prob], [interm_push_feat, interm_grasp_feat]

    def _prepare_input(self, data):
        """Wrap ``data`` without gradient tracking and move it to the GPU if ``use_cuda``."""
        data = Variable(data, requires_grad=False)
        if self.use_cuda:
            data = data.cuda()
        return data

    def _rotation_grid(self, theta, size):
        """Return an ``F.affine_grid`` sampling grid of ``size`` for a rotation by ``theta`` radians.

        The affine matrix has batch dimension 1, so ``size[0]`` is assumed to be 1.
        """
        affine_mat = np.asarray([[[np.cos(theta), np.sin(theta), 0],
                                  [-np.sin(theta), np.cos(theta), 0]]])
        affine_mat = Variable(torch.from_numpy(affine_mat).float(), requires_grad=False)
        if self.use_cuda:
            affine_mat = affine_mat.cuda()
        return F.affine_grid(affine_mat, size)
class reinforcement_net(nn.Module):
    """Push/grasp network evaluated over a set of input rotations (1-channel heads).

    Color and depth images go through four DenseNet-121 trunks pretrained on
    ImageNet: one color and one depth trunk each for pushing and grasping. For
    each primitive, the color and depth trunk features are concatenated along
    dim 1 and passed through a small 1x1-conv head that emits a single-channel
    map (presumably per-pixel value estimates - confirm against the trainer).
    """

    def __init__(self, use_cuda):  # , snapshot=None
        """Build the trunks and heads.

        Args:
            use_cuda: if true, forward() moves inputs and sampling grids to the GPU.
        """
        super(reinforcement_net, self).__init__()
        self.use_cuda = use_cuda

        # Initialize network trunks with DenseNet pre-trained on ImageNet
        self.push_color_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.push_depth_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.grasp_color_trunk = torchvision.models.densenet.densenet121(pretrained=True)
        self.grasp_depth_trunk = torchvision.models.densenet.densenet121(pretrained=True)

        # Number of evenly spaced rotations evaluated by forward(is_volatile=True).
        self.num_rotations = 16

        # Construct network branches for pushing and grasping. The module names
        # ('push-norm0', ...) end up in state_dict keys, so keep them stable.
        # NOTE(review): 2048 input channels assumes each trunk's .features
        # yields 1024 channels (color + depth concatenated) - TODO confirm.
        self.pushnet = nn.Sequential(OrderedDict([
            ('push-norm0', nn.BatchNorm2d(2048)),
            ('push-relu0', nn.ReLU(inplace=True)),
            ('push-conv0', nn.Conv2d(2048, 64, kernel_size=1, stride=1, bias=False)),
            ('push-norm1', nn.BatchNorm2d(64)),
            ('push-relu1', nn.ReLU(inplace=True)),
            ('push-conv1', nn.Conv2d(64, 1, kernel_size=1, stride=1, bias=False))
            # ('push-upsample2', nn.Upsample(scale_factor=4, mode='bilinear'))
        ]))
        self.graspnet = nn.Sequential(OrderedDict([
            ('grasp-norm0', nn.BatchNorm2d(2048)),
            ('grasp-relu0', nn.ReLU(inplace=True)),
            ('grasp-conv0', nn.Conv2d(2048, 64, kernel_size=1, stride=1, bias=False)),
            ('grasp-norm1', nn.BatchNorm2d(64)),
            ('grasp-relu1', nn.ReLU(inplace=True)),
            ('grasp-conv1', nn.Conv2d(64, 1, kernel_size=1, stride=1, bias=False))
            # ('grasp-upsample2', nn.Upsample(scale_factor=4, mode='bilinear'))
        ]))

        # Initialize head weights. The 'push-'/'grasp-' filter matches only the
        # head modules named above (trunk attributes use '_', not '-'), so the
        # pretrained trunks keep their weights.
        for name, module in self.named_modules():
            if 'push-' in name or 'grasp-' in name:
                if isinstance(module, nn.Conv2d):
                    # kaiming_normal_ is the non-deprecated in-place initializer.
                    nn.init.kaiming_normal_(module.weight.data)
                elif isinstance(module, nn.BatchNorm2d):
                    module.weight.data.fill_(1)
                    module.bias.data.zero_()

        # Outputs of the most recent training-mode forward pass (for backprop)
        self.interm_feat = []
        self.output_prob = []

    def forward(self, input_color_data, input_depth_data, is_volatile=False, specific_rotation=-1):
        """Predict push and grasp maps for rotated copies of the inputs.

        Args:
            input_color_data: color image tensor. The rotation matrix is built
                with batch size 1, so a (1, C, H, W) tensor is assumed - TODO
                confirm; C is presumably 3 for the ImageNet DenseNet trunks.
            input_depth_data: depth image tensor, sampled with the grid built
                from the color image's size.
            is_volatile: if True, evaluate all ``num_rotations`` rotations under
                ``torch.no_grad()`` and return fresh lists without touching
                ``self.output_prob`` / ``self.interm_feat``. If False, evaluate
                only ``specific_rotation`` with autograd enabled and store the
                results on ``self``.
            specific_rotation: rotation index used when ``is_volatile`` is
                False; the angle is ``specific_rotation * 360 / num_rotations``
                degrees. The default -1 is evaluated as an index like any
                other, so callers presumably always pass one explicitly.

        Returns:
            ``(output_prob, interm_feat)``: lists with one entry per evaluated
            rotation. Each output entry is ``[push_map, grasp_map]`` and each
            feature entry is ``[interm_push_feat, interm_grasp_feat]``.
        """
        if is_volatile:
            output_prob = []
            interm_feat = []
            # Inference over every rotation needs no autograd graph.
            with torch.no_grad():
                for rotate_idx in range(self.num_rotations):
                    rotation_prob, rotation_feat = self._forward_rotation(
                        input_color_data, input_depth_data, rotate_idx)
                    output_prob.append(rotation_prob)
                    interm_feat.append(rotation_feat)
            return output_prob, interm_feat

        # Training: evaluate a single rotation with autograd enabled and keep the
        # results on self so they stay available for backprop.
        self.output_prob = []
        self.interm_feat = []
        rotation_prob, rotation_feat = self._forward_rotation(
            input_color_data, input_depth_data, specific_rotation)
        self.interm_feat.append(rotation_feat)
        self.output_prob.append(rotation_prob)
        return self.output_prob, self.interm_feat

    def _forward_rotation(self, input_color_data, input_depth_data, rotate_idx):
        """Run trunks and heads on the inputs rotated by rotation index ``rotate_idx``.

        Returns:
            ``([push_prob, grasp_prob], [interm_push_feat, interm_grasp_feat])``.
        """
        # 360.0 keeps the angular step fractional even under Python 2 division.
        rotate_theta = np.radians(rotate_idx * (360.0 / self.num_rotations))

        # Sample the input images through a grid rotated by -theta.
        flow_grid_before = self._rotation_grid(-rotate_theta, input_color_data.size())
        rotate_color = F.grid_sample(self._prepare_input(input_color_data), flow_grid_before, mode='nearest')
        rotate_depth = F.grid_sample(self._prepare_input(input_depth_data), flow_grid_before, mode='nearest')

        # Compute intermediate features (color and depth concatenated on dim 1)
        interm_push_feat = torch.cat((self.push_color_trunk.features(rotate_color),
                                      self.push_depth_trunk.features(rotate_depth)), dim=1)
        interm_grasp_feat = torch.cat((self.grasp_color_trunk.features(rotate_color),
                                       self.grasp_depth_trunk.features(rotate_depth)), dim=1)

        # Rotate head outputs by +theta (the inverse of the input rotation) so
        # they align with the unrotated input, then upsample by 16.
        flow_grid_after = self._rotation_grid(rotate_theta, interm_push_feat.data.size())
        upsample = nn.Upsample(scale_factor=16, mode='bilinear')
        push_prob = upsample(F.grid_sample(self.pushnet(interm_push_feat), flow_grid_after, mode='nearest'))
        grasp_prob = upsample(F.grid_sample(self.graspnet(interm_grasp_feat), flow_grid_after, mode='nearest'))
        return [push_prob, grasp_prob], [interm_push_feat, interm_grasp_feat]

    def _prepare_input(self, data):
        """Wrap ``data`` without gradient tracking and move it to the GPU if ``use_cuda``."""
        data = Variable(data, requires_grad=False)
        if self.use_cuda:
            data = data.cuda()
        return data

    def _rotation_grid(self, theta, size):
        """Return an ``F.affine_grid`` sampling grid of ``size`` for a rotation by ``theta`` radians.

        The affine matrix has batch dimension 1, so ``size[0]`` is assumed to be 1.
        """
        affine_mat = np.asarray([[[np.cos(theta), np.sin(theta), 0],
                                  [-np.sin(theta), np.cos(theta), 0]]])
        affine_mat = Variable(torch.from_numpy(affine_mat).float(), requires_grad=False)
        if self.use_cuda:
            affine_mat = affine_mat.cuda()
        return F.affine_grid(affine_mat, size)