import os
import shutil

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
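
# Detach hidden states from their history by wrapping them in fresh
# Variables, so gradients do not flow back into earlier batches.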
def repackage_hidden(h):
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden(v) for v in h)
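

# Trim the dataset to a whole number of batches of size bsz and reshape it
# so that each of the bsz columns is one contiguous stream of tokens.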
def batchify(data, bsz, args):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    print(data.size())
    if args.cuda:
        data = data.cuda()
    return data
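

# Slice out a (data, target) pair of length up to args.bptt, with the
# target shifted one token ahead of the input.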
def get_batch(source, i, args, seq_len=None, evaluation=False):
    seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
    data = Variable(source[i:i+seq_len], volatile=evaluation)
    target = Variable(source[i+1:i+1+seq_len])
    return data, target
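

# Create the experiment directory and optionally snapshot the given scripts
# into a 'scripts' subfolder for reproducibility.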
def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)
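

# Save the model, optimizer state, the given `state` object, and the epoch
# counter under `path`; finetune checkpoints use separate file names.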
def save_checkpoint(model, optimizer, epoch, path, state, finetune=False):
    if finetune:
        torch.save(model, os.path.join(path, 'finetune_model.pt'))
        torch.save(optimizer.state_dict(), os.path.join(path, 'finetune_optimizer.pt'))
    else:
        torch.save(model.state_dict(), os.path.join(path, 'model.pt'))
        torch.save(model, os.path.join(path, 'model_xs.pt'))
        torch.save(state, os.path.join(path, 'state.pt'))
        torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pt'))
    torch.save({'epoch': epoch+1}, os.path.join(path, 'misc.pt'))
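

# Lighter variant of save_checkpoint that stores only the full model, the
# optimizer state, and the epoch counter.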
def save_checkpoint_t(model, optimizer, epoch, path, finetune=False):
    if finetune:
        torch.save(model, os.path.join(path, 'finetune_model.pt'))
        torch.save(optimizer.state_dict(), os.path.join(path, 'finetune_optimizer.pt'))
    else:
        torch.save(model, os.path.join(path, 'model.pt'))
        torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer.pt'))
    torch.save({'epoch': epoch+1}, os.path.join(path, 'misc.pt'))
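

# Dropout on the embedding matrix itself: zero out whole word rows with
# probability `dropout` and rescale the rest, so a dropped word is dropped
# at every position in the batch.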
def embedded_dropout(embed, words, dropout=0.1, scale=None):
    if dropout:
        mask = embed.weight.data.new().resize_((embed.weight.size(0), 1)).bernoulli_(1 - dropout).expand_as(embed.weight) / (1 - dropout)
        mask = Variable(mask)
        masked_embed_weight = mask * embed.weight
    else:
        masked_embed_weight = embed.weight
    if scale:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    padding_idx = embed.padding_idx
    if padding_idx is None:
        padding_idx = -1
    X = nn.functional.embedding(words, masked_embed_weight,
                                padding_idx, embed.max_norm, embed.norm_type,
                                embed.scale_grad_by_freq, embed.sparse)
    return X
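

# Variational (locked) dropout: sample one mask of shape (1, batch, features)
# and reuse it across the first (time) dimension instead of resampling per step.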
class LockedDropout(nn.Module):
    def __init__(self):
        super(LockedDropout, self).__init__()

    def forward(self, x, dropout=0.5):
        if not self.training or not dropout:
            return x
        m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        mask = Variable(m.div_(1 - dropout), requires_grad=False)
        mask = mask.expand_as(x)
        return mask * x
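

# Sample a (B, D) Bernoulli keep-mask scaled by 1/keep_prob (inverted dropout).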
def mask2d(B, D, keep_prob, cuda=True):
    m = torch.floor(torch.rand(B, D) + keep_prob) / keep_prob
    m = Variable(m, requires_grad=False)
    if cuda:
        m = m.cuda()
    return m
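

# Load a saved state dict into `model`, dropping stale batch-norm buffers.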
def load(model, model_path):
    # The checkpoint may be saved as a .pth file or a .pt file.
    state_dict = torch.load(model_path)
    # Build a new OrderedDict, skipping the batch-norm running statistics
    # that the target model does not carry.
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k == 'rnns.0.bn.running_mean' or k == 'rnns.0.bn.running_var':
            pass
        else:
            new_state_dict[k] = v  # copy every remaining key/value pair one-to-one
    # Reload the model from the filtered state dict; strict=False tolerates
    # any leftover key mismatches.
    model.load_state_dict(new_state_dict, strict=False)
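

# Load a checkpoint onto the CPU (useful when it was saved from a GPU run).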
def load_cpu(model, model_path):
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
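

# Find the largest entry of `ii` and append the word it maps to via
# `word_2_id`; the original function's intended return value is unclear.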
def word2list(ii, word_2_id):
    output = []
    for i in ii:
        output.append(i)
    # Locate the position of the maximum score and look up its word.
    max_i = sorted(output)[-1]
    index = output.index(max_i)
    index_word = word_2_id[index]
    output.append(index_word)
    # The original had no return statement; returning the augmented list is
    # an assumption based on the function name.
    return output