-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
29 lines (22 loc) · 770 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import json
def get_lr(optimizer):
    """Return the learning rate of the optimizer's last parameter group.

    Args:
        optimizer: An object exposing ``param_groups``, an iterable of
            dicts that each contain an ``'lr'`` key.

    Returns:
        The ``'lr'`` value of the final parameter group.

    Raises:
        ValueError: If the optimizer has no parameter groups. Previously
            this case surfaced as an ``UnboundLocalError``.
    """
    param_groups = list(optimizer.param_groups)
    if not param_groups:
        raise ValueError('optimizer has no parameter groups')
    # With several groups (e.g. per-layer learning rates) the last group's
    # value is reported, matching the original loop-and-overwrite behaviour.
    return param_groups[-1]['lr']
def get_grad_norm(model):
    """Compute the global L2 norm over all available gradients of *model*.

    Parameters with ``requires_grad`` False, and parameters whose ``.grad``
    is ``None``, are skipped. The per-parameter L2 norms are combined as
    ``sqrt(sum(norm_i ** 2))``.

    Args:
        model: An object whose ``parameters()`` yields tensors that carry
            ``requires_grad`` and ``grad`` attributes.

    Returns:
        A Python float. It is 0.0 when no parameter contributes.
    """
    squared_sum = 0
    for param in model.parameters():
        # Frozen or not-yet-backpropagated parameters don't contribute.
        if not param.requires_grad or param.grad is None:
            continue
        squared_sum += param.grad.detach().norm(2).item() ** 2
    return squared_sum ** 0.5
def save_params(model_dir, params, name='params'):
    """Write the *params* dictionary to ``<model_dir>/<name>.json``.

    The JSON is indented by two spaces and keys are sorted, so the file
    diffs cleanly between runs. An existing file at that path is
    overwritten.

    Args:
        model_dir: Directory to write into. It is not created here.
        params: A JSON-serialisable dictionary of parameters.
        name: Base file name without the ``.json`` extension.
    """
    json_path = os.path.join(model_dir, '{}.json'.format(name))
    with open(json_path, 'w') as handle:
        json.dump(params, handle, indent=2, sort_keys=True)