Lookahead.py
import torch
from torch.optim import Optimizer
from collections import defaultdict


class Lookahead(Optimizer):
    r'''Implements the Lookahead optimizer.

    Proposed in the paper "Lookahead Optimizer: k steps forward, 1 step back"
    (https://arxiv.org/pdf/1907.08610.pdf).

    Args:
        optimizer: The inner optimizer used in the fast loop for the fast
            weight updates.
        alpha: The step size for the slow weight update.
            Default: 0.5
        k: Number of fast weight updates between consecutive slow weight
            updates.
            Default: 5

    Example:
        >>> optim = Lookahead(optimizer)
        >>> optim = Lookahead(optimizer, alpha=0.6, k=10)
    '''

    def __init__(self, optimizer, alpha=0.5, k=5):
        assert 0.0 <= alpha <= 1.0
        assert k >= 1
        self.optimizer = optimizer
        self.alpha = alpha
        self.k = k
        self.k_counter = 0
        # Expose the inner optimizer's param groups and keep a detached copy
        # of the parameters as the slow weights.
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.slow_weights = [[param.clone().detach() for param in group['params']]
                             for group in self.param_groups]

    def step(self, closure=None):
        # Fast weight update via the inner optimizer.
        loss = self.optimizer.step(closure)
        self.k_counter += 1
        if self.k_counter >= self.k:
            # Slow weight update: slow <- slow + alpha * (fast - slow),
            # then reset the fast weights to the new slow weights.
            for group, slow_weights in zip(self.param_groups, self.slow_weights):
                for param, weight in zip(group['params'], slow_weights):
                    weight.data.add_(param.data - weight.data, alpha=self.alpha)
                    param.data.copy_(weight.data)
            self.k_counter = 0
        return loss

    def __getstate__(self):
        # Include the slow weights so pickling preserves the full Lookahead
        # state, not only the inner optimizer.
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'alpha': self.alpha,
            'k': self.k,
            'k_counter': self.k_counter,
            'slow_weights': self.slow_weights,
        }

    def state_dict(self):
        # Delegate to the inner optimizer; the slow weights themselves are
        # not part of this state dict.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
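
# A minimal usage sketch (not part of the original file), assuming a toy
# nn.Linear model and plain SGD as the inner optimizer purely to illustrate
# how Lookahead wraps an existing optimizer; every k-th call to step() also
# performs the slow weight update.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)
    inner = torch.optim.SGD(model.parameters(), lr=0.1)
    optim = Lookahead(inner, alpha=0.5, k=5)

    criterion = nn.MSELoss()
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)

    for _ in range(20):
        # Zero the gradients through the inner optimizer, since the wrapper
        # above does not override zero_grad().
        inner.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optim.step()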