-
Notifications
You must be signed in to change notification settings - Fork 305
/
Copy pathpool.py
452 lines (381 loc) · 14.3 KB
/
pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
"""Implements downsampling and upsampling on sequences."""
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce
from src.models.sequence import SequenceModule
from src.models.nn import LinearActivation
"""The following pooling modules all share the same interface.
stride: Down/upsampling factor along the sequence (length) dimension.
expand: Expansion factor on the feature dimension (down modules multiply the feature count by it, up modules divide by it).
"""
class DownSample(SequenceModule):
    """Downsample by strided slicing along the length axis; expand features by repetition.

    Args:
        d_input: Number of input features.
        stride: Keep every ``stride``-th position along the length axis.
        expand: Repeat every feature channel ``expand`` times.
        transposed: If True the feature axis is dim 1 and length is the last axis
            (B, D, L); otherwise length precedes a trailing feature axis (B, L, D).
    """

    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        """Return ``(y, None)``; a ``None`` input is passed straight through as ``None``."""
        if x is None:
            return None

        if self.stride > 1:
            assert x.ndim == 3, "Downsampling with higher-dimensional inputs is currently not supported. It is recommended to use average or spectral pooling instead."
            # Length is the last axis when transposed, the second-to-last otherwise.
            x = x[..., ::self.stride] if self.transposed else x[..., ::self.stride, :]

        if self.expand > 1:
            pattern = 'b d ... -> b (d e) ...' if self.transposed else 'b ... d -> b ... (d e)'
            x = repeat(x, pattern, e=self.expand)

        return x, None

    def step(self, x, state, **kwargs):
        """Single-step recurrence; only the identity configuration (stride 1, expand 1) is supported."""
        is_identity = self.stride <= 1 and self.expand <= 1
        if not is_identity:
            raise NotImplementedError
        return x, state

    @property
    def d_output(self):
        # Feature count after channel repetition.
        return self.d_input * self.expand
class DownAvgPool(SequenceModule):
    """Average-pool along every length axis, optionally followed by a learned linear map.

    Args:
        d_input: Number of input features.
        stride: Window size and stride of the average pool on every length axis.
        expand: If not None, a ``LinearActivation`` maps ``d_input`` to
            ``d_input * expand`` features after pooling.
        transposed: If True inputs are (B, D, L...); otherwise (B, L..., D).
    """

    def __init__(self, d_input, stride=1, expand=None, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

        if self.expand is not None:
            self.linear = LinearActivation(
                d_input,
                d_input * expand,
                transposed=transposed,
            )

    def forward(self, x):
        """Return ``(y, None)`` with each length axis pooled by ``stride``."""
        if not self.transposed:
            x = rearrange(x, 'b ... d -> b d ...')

        if self.stride > 1:
            # einops appears slower than F
            if x.ndim == 3:
                x = F.avg_pool1d(x, self.stride, self.stride)
            elif x.ndim == 4:
                x = F.avg_pool2d(x, self.stride, self.stride)
            else:
                # Generic N-d reduction, e.g. "b d (l0 2) (l1 2) -> b d l0 l1"
                axes = range(x.ndim - 2)
                reduce_str = (
                    "b d " + " ".join(f"(l{i} {self.stride})" for i in axes)
                    + " -> b d " + " ".join(f"l{i}" for i in axes)
                )
                x = reduce(x, reduce_str, 'mean')

        if not self.transposed:
            x = rearrange(x, 'b d ... -> b ... d')
        if self.expand is not None:
            x = self.linear(x)
        return x, None

    def step(self, x, state, **kwargs):
        """Process a single step; only supported when ``stride == 1``.

        With ``stride == 1`` ``forward`` reduces to the optional linear map, so the
        step applies the same projection (``expand`` may be None, meaning identity).
        ``x`` is presumably one time step of shape (B, d_input) -- TODO confirm.

        Raises:
            NotImplementedError: if ``stride > 1``.
        """
        if self.stride > 1:
            raise NotImplementedError
        # `expand` may be None: never compare it numerically.
        if x is None or self.expand is None:
            return x, state
        if self.transposed:
            # The transposed linear layer expects a trailing length axis.
            x = self.linear(x.unsqueeze(-1)).squeeze(-1)
        else:
            x = self.linear(x)
        return x, state

    @property
    def d_output(self):
        if self.expand is None:
            return self.d_input
        else:
            return self.d_input * self.expand
class DownSpectralPool(SequenceModule):
    """Downsample every length axis by cropping its Fourier spectrum.

    Each length axis of size ``l`` (which must be divisible by ``stride``) is reduced to
    ``l // stride`` samples. Only the ``l // stride`` lowest frequencies are kept, and the
    result is transformed back. Features are then expanded by repetition.

    Args:
        d_input: Number of input features.
        stride: Downsampling factor applied to every length axis.
        expand: Repeat every feature channel ``expand`` times.
        transposed: If True inputs are (B, D, L...); otherwise (B, L..., D).
    """

    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        """
        x: (B, L..., D) if not transposed, else (B, D, L...).

        Returns ``(y, None)``: each length axis is divided by ``stride`` and the feature
        axis is multiplied by ``expand``.
        """
        if not self.transposed:
            x = rearrange(x, 'b ... d -> b d ...')
        shape = x.shape[2:]
        # Forward DFT over the length axes. norm="forward" puts the 1/N factor here and
        # none in the inverse, so a constant input keeps its value and a band-limited
        # input matches plain strided subsampling.
        x_f = torch.fft.fftn(x, s=shape, norm="forward")

        for axis, length in enumerate(shape):
            assert length % self.stride == 0, 'input length must be divisible by stride'
            new_l = length // self.stride
            # Number of negative frequencies kept. Parenthesized explicitly: `-new_l//2`
            # means `(-new_l)//2`, which keeps one bin too many when new_l is odd.
            n_neg = new_l // 2
            idx = torch.cat([
                torch.arange(0, new_l - n_neg),     # DC and positive frequencies
                length + torch.arange(-n_neg, 0),   # negative frequencies (end of spectrum)
            ]).to(x_f.device)
            x_f = torch.index_select(x_f, 2 + axis, idx)
        x = torch.fft.ifftn(x_f, s=[length // self.stride for length in shape], norm="forward")
        x = x.real

        if self.expand > 1:
            x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand)
        if not self.transposed:
            x = rearrange(x, 'b d ... -> b ... d')
        return x, None

    def step(self, x, state, **kwargs):
        if self.stride > 1 or self.expand > 1:
            raise NotImplementedError
        return x, state

    @property
    def d_output(self):
        return self.d_input * self.expand
class UpSample(SequenceModule):
    """Upsample by repeating positions along length; shrink features by group-averaging.

    Args:
        d_input: Number of input features.
        stride: Repeat every length position ``stride`` times.
        expand: Average each group of ``expand`` adjacent channels into one.
        transposed: If True length is the last axis; otherwise it precedes a trailing
            feature axis.
    """

    def __init__(self, d_input, stride=1, expand=1, transposed=True):
        super().__init__()
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.transposed = transposed

    def forward(self, x):
        """Return ``(y, None)``; a ``None`` input is passed straight through as ``None``."""
        if x is None:
            return None

        if self.expand > 1:
            # Collapse channel groups of size `expand` by their mean.
            shrink = '... (d e) l -> ... d l' if self.transposed else '... (d e) -> ... d'
            x = reduce(x, shrink, 'mean', e=self.expand)

        if self.stride > 1:
            # Duplicate each position `stride` times along the length axis.
            grow = '... l -> ... (l e)' if self.transposed else '... l d -> ... (l e) d'
            x = repeat(x, grow, e=self.stride)

        return x, None

    @property
    def d_output(self):
        # Feature count after channel-group averaging.
        return self.d_input // self.expand

    def step(self, x, state, **kwargs):
        """Single-step recurrence; only the identity configuration (stride 1, expand 1) is supported."""
        if not (self.stride <= 1 and self.expand <= 1):
            raise NotImplementedError
        return x, state
class UpAvgPool(SequenceModule):
    """Learned projection to fewer features, then upsample by repeating positions.

    Args:
        d_input: Number of input features (must be divisible by ``expand``).
        stride: Repeat every length position ``stride`` times.
        expand: Output has ``d_input // expand`` features.
        causal: If True (and ``stride > 1``), shift the sequence right by one step
            (zero-padding the front) before repeating.
        transposed: If True inputs are (..., D, L); otherwise (..., L, D).
    """

    def __init__(self, d_input, stride=1, expand=1, causal=False, transposed=True):
        super().__init__()
        assert d_input % expand == 0
        self.d_input = d_input
        self.stride = stride
        self.expand = expand
        self.causal = causal
        self.transposed = transposed
        self.linear = LinearActivation(
            d_input,
            d_input // expand,
            transposed=transposed,
        )

    def forward(self, x):
        # TODO only works for 1D right now
        if x is None: return None
        x = self.linear(x)
        if self.stride > 1:
            if self.transposed:
                if self.causal:
                    x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality
                x = repeat(x, '... l -> ... (l e)', e=self.stride)
            else:
                if self.causal:
                    x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality
                x = repeat(x, '... l d -> ... (l e) d', e=self.stride)
        return x, None

    @property
    def d_output(self):
        return self.d_input // self.expand

    def step(self, x, state, **kwargs):
        """Process a single step; only supported when ``stride == 1``.

        With ``stride == 1`` ``forward`` reduces to ``self.linear(x)`` (no repeat and no
        causal shift), so the step applies the same projection for any ``expand``.
        ``x`` is presumably one time step of shape (B, d_input) -- TODO confirm.

        Raises:
            NotImplementedError: if ``stride > 1``.
        """
        if self.stride > 1:
            raise NotImplementedError
        if x is None:
            return None, state
        if self.transposed:
            # The transposed linear layer expects a trailing length axis.
            x = self.linear(x.unsqueeze(-1)).squeeze(-1)
        else:
            x = self.linear(x)
        return x, state
class DownLinearPool(SequenceModule):
    """Learned downsampling: fold ``stride`` consecutive steps into features, then project.

    Args:
        d_model: Number of input features.
        stride: Number of consecutive steps merged into one output step.
        expand: Output has ``d_model * expand`` features.
        causal: Accepted for interface compatibility; not used by this module.
        transposed: If True inputs are (..., H, L); otherwise (..., L, H).
    """

    def __init__(self, d_model, stride=1, expand=1, causal=False, transposed=True):
        super().__init__()
        self.d_model = d_model
        self.stride = stride
        self.expand = expand
        self.transposed = transposed
        self.linear = LinearActivation(
            d_model * stride,
            d_model * expand,
            transposed=transposed,
        )

    def forward(self, x):
        """Return ``(y, None)`` with length divided by ``stride``."""
        if self.transposed:
            x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.stride)
        else:
            x = rearrange(x, '... (l s) h -> ... l (h s)', s=self.stride)
        x = self.linear(x)
        return x, None

    def step(self, x, state, **kwargs):
        """Buffer one step; emit a projected output once ``stride`` steps are buffered.

        ``state`` is a list that is appended to in place. Returns ``(None, state)`` while
        still buffering, and ``(y, [])`` once ``len(state) == stride``.
        """
        if x is None: return None, state
        state.append(x)
        if len(state) == self.stride:
            x = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)')
            if self.transposed: x = x.unsqueeze(-1)
            x = self.linear(x)
            if self.transposed: x = x.squeeze(-1)
            return x, []
        else:
            return None, state

    def default_state(self, *batch_shape, device=None):
        return []

    @property
    def d_output(self):
        # The input width is stored as `d_model`; `d_input` is never assigned on this class.
        return self.d_model * self.expand
class UpLinearPool(SequenceModule):
    """Learned upsampling: project each step to ``d_output * stride`` features and
    unfold them into ``stride`` consecutive output steps.

    Args:
        d: Number of input features (must be divisible by ``expand``).
        stride: Number of output steps produced per input step.
        expand: Output has ``d // expand`` features.
        causal: If True, shift the projected sequence right by one input step
            (zero-padding the front) before unfolding.
        transposed: If True inputs are (..., H, L); otherwise (..., L, H).
    """

    def __init__(self, d, stride=1, expand=1, causal=False, transposed=True):
        super().__init__()
        assert d % expand == 0
        self.d_model = d
        self.d_output = d // expand
        self.stride = stride
        self.causal = causal
        self.transposed = transposed
        self.linear = LinearActivation(
            self.d_model,
            self.d_output * stride,
            transposed=transposed,
        )

    def forward(self, x, skip=None):
        """Return ``(y, None)``, where ``y`` has length multiplied by ``stride``; ``skip`` is added if given."""
        y = self.linear(x)
        if self.causal:
            # Drop the final step and prepend zeros so step t only sees inputs before t.
            if self.transposed:
                y = F.pad(y[..., :-1], (1, 0))
            else:
                y = F.pad(y[..., :-1, :], (0, 0, 1, 0))
        unfold = '... (h s) l -> ... h (l s)' if self.transposed else '... l (h s) -> ... (l s) h'
        y = rearrange(y, unfold, s=self.stride)
        return (y if skip is None else y + skip), None

    def step(self, x, state, **kwargs):
        """
        x: (..., H)

        ``state`` is a non-empty list of pending outputs. Each call pops and returns the
        first one. When that empties the list, ``x`` must be given and is projected into
        ``stride`` new pending outputs. Otherwise ``x`` must be None.
        """
        assert len(state) > 0
        out, pending = state[0], state[1:]
        if pending:
            assert x is None
            return out, pending

        assert x is not None
        if self.transposed: x = x.unsqueeze(-1)
        x = self.linear(x)
        if self.transposed: x = x.squeeze(-1)
        grouped = rearrange(x, '... (h s) -> ... h s', s=self.stride)
        return out, list(torch.unbind(grouped, dim=-1))

    def default_state(self, *batch_shape, device=None):
        # `stride` zero outputs of shape (*batch_shape, d_output).
        zeros = torch.zeros(batch_shape + (self.d_output, self.stride), device=device)
        return list(torch.unbind(zeros, dim=-1))
"""Pooling functions with trainable parameters."""
class DownPool2d(SequenceModule):
    """Average-pool over two spatial axes, then map features with a learned linear layer.

    Args:
        d_input: Number of input features.
        d_output: Number of output features.
        stride: Kernel size and stride of the 2D average pool.
        transposed: If True inputs are (B, D, H, W); otherwise (B, H, W, D).
        weight_norm: Forwarded to ``LinearActivation``.
    """

    def __init__(self, d_input, d_output, stride=1, transposed=True, weight_norm=True):
        # TODO make d_output expand instead
        super().__init__()
        self.d_input = d_input
        self._d_output = d_output
        self.stride = stride
        self.transposed = transposed
        self.linear = LinearActivation(
            d_input,
            d_output,
            transposed=transposed,
            weight_norm=weight_norm,
        )
        # No trailing comma: a tuple would be neither callable nor registered as a submodule.
        self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)

    def forward(self, x):
        """Return ``(y, None)`` with both spatial axes pooled by ``stride``."""
        # nn.AvgPool2d expects channels before the spatial axes.
        if not self.transposed:
            x = rearrange(x, 'b ... d -> b d ...')
        x = self.pool(x)
        if not self.transposed:
            x = rearrange(x, 'b d ... -> b ... d')
        # NOTE(review): assumes LinearActivation(transposed=True) accepts (B, D, H, W) -- confirm.
        x = self.linear(x)
        return x, None

    @property
    def d_output(self):
        return self._d_output
# TODO DownPool/UpPool are currently used by unet/sashimi backbones
# DownLinearPool is used by the registry (for isotropic backbone)
# DownPool is essentially the same as DownLinearPool. These should be consolidated
class DownPool(SequenceModule):
    """Learned downsampling: fold ``stride`` consecutive steps into features, then project.

    Exactly one of ``d_output`` and ``expand`` must be given. With ``expand`` the output
    width is ``d_input * expand``. ``activation`` (if not None) is applied by the linear
    layer. Inputs are (..., H, L) when ``transposed`` else (..., L, H).
    """

    def __init__(self, d_input, d_output=None, expand=None, stride=1, transposed=True, weight_norm=True, initializer=None, activation=None):
        super().__init__()
        assert (d_output is None) ^ (expand is None)
        self.d_output = d_input * expand if d_output is None else d_output
        self.stride = stride
        self.transposed = transposed
        self.linear = LinearActivation(
            d_input * stride,
            self.d_output,
            transposed=transposed,
            initializer=initializer,
            weight_norm=weight_norm,
            activation=activation,
            activate=activation is not None,
        )

    def forward(self, x):
        """Return ``(y, None)`` with length divided by ``stride``."""
        fold = '... h (l s) -> ... (h s) l' if self.transposed else '... (l s) h -> ... l (h s)'
        return self.linear(rearrange(x, fold, s=self.stride)), None

    def step(self, x, state, **kwargs):
        """
        x: (..., H)

        Appends ``x`` to the ``state`` list in place. Returns ``(None, state)`` until
        exactly ``stride`` steps are buffered, then ``(y, [])``.
        """
        if x is None:
            return None, state
        state.append(x)
        if len(state) != self.stride:
            return None, state

        merged = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)')
        if self.transposed:
            # The transposed linear layer expects a trailing length axis.
            y = self.linear(merged.unsqueeze(-1)).squeeze(-1)
        else:
            y = self.linear(merged)
        return y, []

    def default_state(self, *batch_shape, device=None):
        return []
class UpPool(SequenceModule):
    """Causal learned upsampling: project each step to ``d_output * stride`` features and
    unfold them into ``stride`` consecutive output steps.

    The projection is always shifted right by one input step (zero-padding the front)
    before unfolding, so outputs derived from input step t land after step t's window.
    Inputs are (..., H, L) when ``transposed`` else (..., L, H).
    """

    def __init__(self, d_input, d_output, stride, transposed=True, weight_norm=True, initializer=None, activation=None):
        super().__init__()
        self.d_input = d_input
        self._d_output = d_output
        self.stride = stride
        self.transposed = transposed
        self.linear = LinearActivation(
            d_input,
            d_output * stride,
            transposed=transposed,
            initializer=initializer,
            weight_norm=weight_norm,
            activation=activation,
            activate=activation is not None,
        )

    def forward(self, x, skip=None):
        """Return ``(y, None)``, where ``y`` has length multiplied by ``stride``; ``skip`` is added if given."""
        y = self.linear(x)
        if self.transposed:
            shifted = F.pad(y[..., :-1], (1, 0))  # Shift to ensure causality
            y = rearrange(shifted, '... (h s) l -> ... h (l s)', s=self.stride)
        else:
            shifted = F.pad(y[..., :-1, :], (0, 0, 1, 0))  # Shift to ensure causality
            y = rearrange(shifted, '... l (h s) -> ... (l s) h', s=self.stride)
        return (y if skip is None else y + skip), None

    def step(self, x, state, **kwargs):
        """
        x: (..., H)

        ``state`` is a non-empty list of pending outputs. Each call pops and returns the
        first one. When that empties the list, ``x`` must be given and is projected into
        ``stride`` new pending outputs. Otherwise ``x`` must be None.
        """
        assert len(state) > 0
        out, pending = state[0], state[1:]
        if pending:
            assert x is None
            return out, pending

        assert x is not None
        if self.transposed:
            projected = self.linear(x.unsqueeze(-1)).squeeze(-1)
        else:
            projected = self.linear(x)
        grouped = rearrange(projected, '... (h s) -> ... h s', s=self.stride)
        return out, list(torch.unbind(grouped, dim=-1))

    def default_state(self, *batch_shape, device=None):
        # `stride` zero outputs of shape (*batch_shape, d_output).
        zeros = torch.zeros(batch_shape + (self.d_output, self.stride), device=device)
        return list(torch.unbind(zeros, dim=-1))

    @property
    def d_output(self):
        return self._d_output
# Name -> downsampling module class.
registry = dict(
    sample=DownSample,
    pool=DownAvgPool,
    avg=DownAvgPool,  # alias of 'pool'
    linear=DownLinearPool,
    spectral=DownSpectralPool,
)
# Name -> upsampling module class, keyed like `registry`.
up_registry = dict(
    # sample=UpSample,
    pool=UpAvgPool,
    avg=UpAvgPool,  # alias of 'pool'
    linear=UpLinearPool,
    # spectral=UpSpectralPool,  # Not implemented and no way to make this causal
)