-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmodule.py
192 lines (149 loc) · 7.6 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
from transformers import GPT2LMHeadModel
import torch.nn as nn
import torch
import copy
class UIPrompt:
    """Mixin that prepends learned user/item prompt vectors to a causal LM's input.

    Designed to sit before a Hugging Face LM head model in the MRO (see
    ``ContinuousPromptLearning``). It relies on that base class for
    ``from_pretrained``, ``forward`` and the ``transformer.wte`` token embedding.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, nuser, nitem, freezeLM=True, **kwargs):
        """Load a pretrained LM and attach user/item prompt embeddings.

        Args:
            pretrained_model_name_or_path: Forwarded to the base ``from_pretrained``.
            nuser: Number of rows in the user embedding table.
            nitem: Number of rows in the item embedding table.
            freezeLM: If True, set ``requires_grad=False`` on every pretrained
                parameter. The prompt embeddings are created afterwards, so they
                remain trainable either way.
            **kwargs: Forwarded to the base ``from_pretrained``.

        Returns:
            The loaded model with ``init_prompt`` applied.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        # freeze pretrained model parameters
        if freezeLM:
            for param in model.parameters():
                param.requires_grad = False
        model.init_prompt(nuser, nitem)
        return model

    def init_prompt(self, nuser, nitem):
        """Create user and item embedding tables sized like the LM's token embeddings.

        Weights are drawn uniformly from [-0.1, 0.1]. The tables are then placed
        on the same device as the LM's token embedding, so a model loaded onto an
        accelerator (e.g. through a ``device_map``) works without a manual move.
        """
        self.src_len = 2  # number of prompt positions: one user + one item vector
        wte_weight = self.transformer.wte.weight
        emsize = wte_weight.size(1)  # LM hidden size, e.g. 768 for base GPT-2
        self.user_embeddings = nn.Embedding(nuser, emsize)
        self.item_embeddings = nn.Embedding(nitem, emsize)
        initrange = 0.1
        self.user_embeddings.weight.data.uniform_(-initrange, initrange)
        self.item_embeddings.weight.data.uniform_(-initrange, initrange)
        # Initialise on the default device, then move next to the LM.
        # nn.Module.to() works in place, so the submodule registration is kept.
        self.user_embeddings.to(wte_weight.device)
        self.item_embeddings.to(wte_weight.device)

    def forward(self, user, item, text, mask, ignore_index=-100):
        """Run the LM on ``[user prompt, item prompt, text tokens]``.

        Args:
            user: User indices, shape (batch_size,).
            item: Item indices, shape (batch_size,).
            text: Token ids, shape (batch_size, tgt_len).
            mask: ``None`` for auto-regressive generation, in which case no
                attention mask or labels are passed. Otherwise a
                (batch_size, tgt_len) mask in which 1 marks real tokens; every
                other position is treated as padding.
            ignore_index: Label value for positions excluded from the LM loss.

        Returns:
            Whatever the base class ``forward`` returns.
        """
        device = user.device
        batch_size = user.size(0)
        # embeddings
        u_src = self.user_embeddings(user)  # (batch_size, emsize)
        i_src = self.item_embeddings(item)  # (batch_size, emsize)
        w_src = self.transformer.wte(text)  # (batch_size, tgt_len, emsize)
        src = torch.cat([u_src.unsqueeze(1), i_src.unsqueeze(1), w_src], 1)  # (batch_size, total_len, emsize)
        if mask is None:
            # auto-regressive generation
            return super().forward(inputs_embeds=src)

        # Training: the prompt positions are always attended to...
        pad_left = torch.ones((batch_size, self.src_len), dtype=torch.int64, device=device)
        pad_input = torch.cat([pad_left, mask], 1)  # (batch_size, total_len)
        # ...but never predicted, so they get ignore_index labels.
        pred_left = torch.full(
            (batch_size, self.src_len), ignore_index, dtype=torch.int64, device=device
        )  # (batch_size, src_len)
        pred_right = text.masked_fill(mask != 1, ignore_index)  # replace <pad> with ignore_index
        prediction = torch.cat([pred_left, pred_right], 1)  # (batch_size, total_len)
        return super().forward(attention_mask=pad_input, inputs_embeds=src, labels=prediction)
class ContinuousPromptLearning(UIPrompt, GPT2LMHeadModel):
    """GPT-2 LM head model conditioned on learned user/item prompt vectors.

    ``UIPrompt`` precedes ``GPT2LMHeadModel`` in the MRO. Its ``from_pretrained``
    and ``forward`` therefore wrap the GPT-2 implementations via ``super()``.
    """

    def __init__(self, config):
        # UIPrompt defines no __init__, so this resolves to GPT2LMHeadModel.__init__.
        super(ContinuousPromptLearning, self).__init__(config)
class FeaturePrompt:
    """Mixin that feeds ``[context tokens, explanation tokens]`` to a causal LM.

    Both parts are embedded with the LM's own token embedding
    (``transformer.wte``). Only the explanation tokens contribute to the
    training loss. Meant to precede a Hugging Face LM head model in the MRO
    (see ``DiscretePromptLearning``).
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Delegate to the base class ``from_pretrained`` (keyword options only)."""
        loaded = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        return loaded

    def forward(self, context, explanation, exp_mask, ignore_index=-100):
        """Run the LM on the concatenated context and explanation tokens.

        Args:
            context: Token ids of the prompt part, shape (batch_size, src_len).
                Every context position is attended to and excluded from the loss.
            explanation: Token ids to be generated, shape (batch_size, tgt_len).
            exp_mask: ``None`` for auto-regressive generation. Otherwise a mask
                over ``explanation`` in which 1 marks real tokens; other positions
                are treated as padding.
            ignore_index: Label value for positions excluded from the LM loss.

        Returns:
            Whatever the base class ``forward`` returns.
        """
        tokens = torch.cat([context, explanation], 1)  # (batch_size, total_len)
        embeds = self.transformer.wte(tokens)  # (batch_size, total_len, emsize)
        if exp_mask is None:
            # auto-regressive generation
            return super().forward(inputs_embeds=embeds)

        # Training: attend to every context token plus the unpadded explanation.
        # *_like factories already place the result on context's device.
        context_attention = torch.ones_like(context, dtype=torch.int64)
        attention = torch.cat([context_attention, exp_mask], 1)  # (batch_size, total_len)
        # Only real explanation tokens are prediction targets.
        context_labels = torch.full_like(context, ignore_index, dtype=torch.int64)  # (batch_size, src_len)
        explanation_labels = explanation.masked_fill(exp_mask != 1, ignore_index)  # <pad> -> ignore_index
        labels = torch.cat([context_labels, explanation_labels], 1)  # (batch_size, total_len)
        return super().forward(attention_mask=attention, inputs_embeds=embeds, labels=labels)
class DiscretePromptLearning(FeaturePrompt, GPT2LMHeadModel):
    """GPT-2 LM head model whose input is context tokens followed by explanation tokens.

    ``FeaturePrompt`` precedes ``GPT2LMHeadModel`` in the MRO. Its
    ``from_pretrained`` and ``forward`` therefore wrap the GPT-2 versions.
    """

    def __init__(self, config):
        # FeaturePrompt defines no __init__, so this resolves to GPT2LMHeadModel.__init__.
        super(DiscretePromptLearning, self).__init__(config)
class MF(nn.Module):
    """Matrix-factorisation rating head: the dot product of user and item vectors."""

    def __init__(self):
        super().__init__()

    def forward(self, user, item):
        """Return the row-wise dot product of ``user`` and ``item``.

        Args:
            user: (batch_size, emsize) user vectors.
            item: (batch_size, emsize) item vectors.

        Returns:
            (batch_size,) tensor of predicted ratings.
        """
        elementwise = user * item
        return elementwise.sum(dim=1)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class MLP(nn.Module):
    """Multi-layer perceptron that scores a (user, item) embedding pair.

    Architecture: ``concat(user, item)`` -> Linear(2 * emsize, hidden_size) -> sigmoid,
    then ``num_layers`` x [Linear(hidden_size, hidden_size) -> sigmoid],
    then Linear(hidden_size, 1).

    Args:
        emsize: Size of each input embedding.
        hidden_size: Width of the hidden layers.
        num_layers: Number of hidden->hidden layers between the first and last layer.
    """

    def __init__(self, emsize, hidden_size=400, num_layers=2):
        super(MLP, self).__init__()
        self.first_layer = nn.Linear(emsize * 2, hidden_size)
        self.last_layer = nn.Linear(hidden_size, 1)
        layer = nn.Linear(hidden_size, hidden_size)
        # Deep copies of one layer; init_weights() re-draws each copy's weights.
        self.layers = _get_clones(layer, num_layers)
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Draw every weight uniformly from [-0.1, 0.1] and zero every bias."""
        initrange = 0.1
        self.first_layer.weight.data.uniform_(-initrange, initrange)
        self.first_layer.bias.data.zero_()
        self.last_layer.weight.data.uniform_(-initrange, initrange)
        self.last_layer.bias.data.zero_()
        for layer in self.layers:
            layer.weight.data.uniform_(-initrange, initrange)
            layer.bias.data.zero_()

    def forward(self, user, item):
        """Predict one rating per (user, item) row.

        Args:
            user: (batch_size, emsize) user vectors.
            item: (batch_size, emsize) item vectors.

        Returns:
            (batch_size,) tensor of ratings. The batch dimension is kept even
            when batch_size == 1.
        """
        ui_cat = torch.cat([user, item], 1)  # (batch_size, emsize * 2)
        hidden = self.sigmoid(self.first_layer(ui_cat))  # (batch_size, hidden_size)
        for layer in self.layers:
            hidden = self.sigmoid(layer(hidden))  # (batch_size, hidden_size)
        # Drop only the trailing size-1 output dim. A bare squeeze() would also
        # collapse the batch dim for batch_size == 1 and return a 0-d tensor.
        rating = self.last_layer(hidden).squeeze(-1)  # (batch_size,)
        return rating
class UIPromptWithReg:
    """Mixin: continuous user/item prompts for an LM plus a rating-prediction head.

    One learned user vector and one learned item vector are prepended to the
    token embeddings, and the same two vectors are fed to a recommender head
    (``MF`` or ``MLP``) to predict a rating. Meant to precede a Hugging Face LM
    head model in the MRO (see ``RecReg``).
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, nuser, nitem, use_mf=True, **kwargs):
        """Load a pretrained LM and attach prompt embeddings and a rating head.

        Pretrained parameters are left trainable (no freezing happens here).

        Args:
            pretrained_model_name_or_path: Forwarded to the base ``from_pretrained``.
            nuser: Number of rows in the user embedding table.
            nitem: Number of rows in the item embedding table.
            use_mf: Use the ``MF`` dot-product head if True, else an ``MLP`` head.
            **kwargs: Forwarded to the base ``from_pretrained``.

        Returns:
            The loaded model with ``init_prompt`` applied.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        model.init_prompt(nuser, nitem, use_mf)
        return model

    def init_prompt(self, nuser, nitem, use_mf):
        """Create user/item embedding tables and the rating head.

        Embedding weights are drawn uniformly from [-0.1, 0.1]. The new modules
        are then placed on the same device as the LM's token embedding.
        """
        self.src_len = 2  # number of prompt positions: one user + one item vector
        wte_weight = self.transformer.wte.weight
        emsize = wte_weight.size(1)  # LM hidden size, e.g. 768 for base GPT-2
        self.user_embeddings = nn.Embedding(nuser, emsize)
        self.item_embeddings = nn.Embedding(nitem, emsize)
        if use_mf:
            self.rec = MF()
        else:
            self.rec = MLP(emsize)
        initrange = 0.1
        self.user_embeddings.weight.data.uniform_(-initrange, initrange)
        self.item_embeddings.weight.data.uniform_(-initrange, initrange)
        # Initialise on the default device, then move next to the LM.
        # nn.Module.to() works in place, so the submodule registration is kept.
        for module in (self.user_embeddings, self.item_embeddings, self.rec):
            module.to(wte_weight.device)

    def forward(self, user, item, text, mask, rating_prediction=True, ignore_index=-100):
        """Run the LM on ``[user prompt, item prompt, text tokens]`` and optionally rate.

        Args:
            user: User indices, shape (batch_size,).
            item: Item indices, shape (batch_size,).
            text: Token ids, shape (batch_size, tgt_len).
            mask: ``None`` for auto-regressive generation, in which case no
                attention mask or labels are passed. Otherwise a
                (batch_size, tgt_len) mask in which 1 marks real tokens; every
                other position is treated as padding.
            rating_prediction: If False, skip the rating head and return None
                for the rating.
            ignore_index: Label value for positions excluded from the LM loss.

        Returns:
            Tuple ``(lm_output, rating)``. ``lm_output`` is whatever the base
            class ``forward`` returns; ``rating`` has shape (batch_size,) or is
            None.
        """
        device = user.device
        batch_size = user.size(0)
        # embeddings
        u_src = self.user_embeddings(user)  # (batch_size, emsize)
        i_src = self.item_embeddings(item)  # (batch_size, emsize)
        w_src = self.transformer.wte(text)  # (batch_size, tgt_len, emsize)
        src = torch.cat([u_src.unsqueeze(1), i_src.unsqueeze(1), w_src], 1)  # (batch_size, total_len, emsize)
        # The rating head reads the same prompt vectors the LM sees.
        rating = self.rec(u_src, i_src) if rating_prediction else None  # (batch_size,)
        if mask is None:
            # auto-regressive generation
            return super().forward(inputs_embeds=src), rating

        # Training: the prompt positions are always attended to...
        pad_left = torch.ones((batch_size, self.src_len), dtype=torch.int64, device=device)
        pad_input = torch.cat([pad_left, mask], 1)  # (batch_size, total_len)
        # ...but never predicted, so they get ignore_index labels.
        pred_left = torch.full(
            (batch_size, self.src_len), ignore_index, dtype=torch.int64, device=device
        )  # (batch_size, src_len)
        pred_right = text.masked_fill(mask != 1, ignore_index)  # replace <pad> with ignore_index
        prediction = torch.cat([pred_left, pred_right], 1)  # (batch_size, total_len)
        return super().forward(attention_mask=pad_input, inputs_embeds=src, labels=prediction), rating
class RecReg(UIPromptWithReg, GPT2LMHeadModel):
    """GPT-2 LM head model with user/item prompts and a rating-regularisation head.

    ``UIPromptWithReg`` precedes ``GPT2LMHeadModel`` in the MRO. Its ``forward``
    wraps the GPT-2 forward and returns ``(lm_output, rating)``.
    """

    def __init__(self, config):
        # UIPromptWithReg defines no __init__, so this resolves to GPT2LMHeadModel.__init__.
        super(RecReg, self).__init__(config)