-
Notifications
You must be signed in to change notification settings - Fork 3
/
auto_scale_topk.py
37 lines (35 loc) · 1.25 KB
/
auto_scale_topk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import powersgd_grad_original as psgd_original
from timer import Timer
# Scale value that run_auto_scale_gng selects at epoch 0, when the previous
# norm is zero, or when the relative norm change is at least 0.5.
auto_scale_high = 0.99
# Scale value that run_auto_scale_gng selects when the relative norm change
# is below 0.5.
auto_scale_low = 0.1
# NOTE(review): `alpha` is not referenced anywhere in the visible part of this
# file; confirm whether another module imports it before removing.
alpha = 1.1
def metric(*args, **kwargs):
    """Accept a metric report and discard it.

    This function is passed to ``Timer`` as its ``log_fn``. Forwarding to
    ``log_metric`` is deliberately switched off, so every call returns
    ``None`` and has no side effects.

    Args:
        *args: Positional arguments that would be forwarded to ``log_metric``.
        **kwargs: Keyword arguments that would be forwarded to ``log_metric``.
    """
    # The original guard was the constant-false expression `True == 0`;
    # an explicit flag states the same intent without the puzzle.
    # NOTE(review): `log_metric` is not defined or imported in the visible
    # part of this file; it must be brought into scope before enabling this.
    forwarding_enabled = False
    if forwarding_enabled:
        log_metric(*args, **kwargs)
# Instead of the per-layer version, use the one we normally use.
# NOTE(review): the original comment was cut off mid-sentence and it is not
# clear from this file which "per-layer version" it refers to; confirm.
# Module-level Timer constructed with `metric` as its log_fn.
timer = Timer(verbosity_level=2, log_fn=metric)
def run_auto_scale_gng(current_grad_norms, old_grad_norms, current_epoch):
    """Pick a scale value per norm from its relative change since last time.

    The two norm collections are walked in lockstep (``zip`` stops at the
    shorter one). For each pair the relative change
    ``abs(prev_norm - new_norm) / prev_norm`` is computed; a change below 0.5
    selects ``auto_scale_low`` and anything else selects ``auto_scale_high``.

    Args:
        current_grad_norms: Norms for the current epoch. Must support
            ``len()`` when ``old_grad_norms`` is ``None``.
        old_grad_norms: Norms from the previous measurement, or ``None`` when
            there is no history yet. Individual entries may also be ``None``.
        current_epoch: Epoch index. At epoch ``0`` every entry gets
            ``auto_scale_high`` and no ratios are recorded.

    Returns:
        A tuple ``(auto_scale_candidate, ratio_list)``. The first list holds
        one scale value per pair. The second holds the matching relative
        changes and is empty when ``current_epoch`` is ``0``.
    """
    if old_grad_norms is None:
        old_grad_norms = [None] * len(current_grad_norms)

    change_threshold = 0.5
    # Sentinel ratio used when there is no usable previous norm; it is above
    # the threshold, so such entries always receive the high scale.
    no_baseline_ratio = 9

    auto_scale_candidate = []
    ratio_list = []
    for new_norm, prev_norm in zip(current_grad_norms, old_grad_norms):
        if current_epoch == 0:
            auto_scale_candidate.append(auto_scale_high)
            continue

        # A missing previous norm (None) is handled like a zero one. Before,
        # None reached the subtraction below and raised TypeError.
        if prev_norm is None or prev_norm == 0.0:
            ratio = no_baseline_ratio
        else:
            ratio = abs(prev_norm - new_norm) / prev_norm
        ratio_list.append(ratio)

        if ratio < change_threshold:
            auto_scale_candidate.append(auto_scale_low)
        else:
            auto_scale_candidate.append(auto_scale_high)
    return (auto_scale_candidate, ratio_list)