models.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_bert import BertPreTrainedModel, BertModel  # this module path exists in transformers < 4.0


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.deterministic = config.deterministic
        self.ib_dim = config.ib_dim
        self.ib = config.ib
        self.activation = config.activation
        self.activations = {'tanh': nn.Tanh(), 'relu': nn.ReLU(), 'sigmoid': nn.Sigmoid()}
        if self.ib or self.deterministic:
            # Information-bottleneck head: compress the pooled BERT representation
            # through an MLP before (optionally stochastic) classification.
            self.kl_annealing = config.kl_annealing
            self.hidden_dim = config.hidden_dim
            intermediate_dim = (self.hidden_dim + config.hidden_size) // 2
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, intermediate_dim),
                self.activations[self.activation],
                nn.Linear(intermediate_dim, self.hidden_dim),
                self.activations[self.activation])
            self.beta = config.beta
            self.sample_size = config.sample_size
            # Heads mapping the MLP output to the mean and std of q(z|x),
            # plus learned parameters for the prior p(z).
            self.emb2mu = nn.Linear(self.hidden_dim, self.ib_dim)
            self.emb2std = nn.Linear(self.hidden_dim, self.ib_dim)
            self.mu_p = nn.Parameter(torch.randn(self.ib_dim))
            self.std_p = nn.Parameter(torch.randn(self.ib_dim))
            self.classifier = nn.Linear(self.ib_dim, self.config.num_labels)
        else:
            self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()
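
    # Overview of the bottleneck head: the pooled [CLS] vector goes through self.mlp,
    # estimate() maps it to the mean and std of q(z|x), reparameterize() draws
    # z = mu + std * eps with eps ~ N(0, I), and self.classifier maps z (or its mean)
    # to logits; the training loss adds a KL(q(z|x) || p(z)) term weighted by beta.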

    def estimate(self, emb, emb2mu, emb2std):
        """Estimates mu and std from the given input embeddings."""
        mean = emb2mu(emb)
        std = torch.nn.functional.softplus(emb2std(emb))
        return mean, std

    def kl_div(self, mu_q, std_q, mu_p, std_p):
        """Computes the KL divergence between the two given variational distributions.

        This computes KL(q||p), which is not symmetric; it quantifies how far the
        estimated distribution q is from the true distribution p."""
        k = mu_q.size(1)
        mu_diff = mu_p - mu_q
        mu_diff_sq = torch.mul(mu_diff, mu_diff)
        logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=1)
        logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=1)
        fs = torch.sum(torch.div(std_q ** 2, std_p ** 2), dim=1) + torch.sum(torch.div(mu_diff_sq, std_p ** 2), dim=1)
        kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
        return kl_divergence.mean()
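
    # For reference, the closed form computed above for diagonal Gaussians
    # q = N(mu_q, diag(std_q^2)) and p = N(mu_p, diag(std_p^2)) is
    #   KL(q || p) = 0.5 * sum_i [ std_q_i^2 / std_p_i^2
    #                              + (mu_p_i - mu_q_i)^2 / std_p_i^2
    #                              - 1 + log std_p_i^2 - log std_q_i^2 ],
    # averaged over the batch. Up to the clamp on very small std values, this
    # matches the library cross-check
    #   torch.distributions.kl_divergence(
    #       torch.distributions.Normal(mu_q, std_q),
    #       torch.distributions.Normal(mu_p, std_p)).sum(dim=1).mean()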

    def reparameterize(self, mu, std):
        """Reparameterization trick: draws sample_size samples z = mu + std * eps, eps ~ N(0, I)."""
        batch_size = mu.shape[0]
        z = torch.randn(self.sample_size, batch_size, mu.shape[1], device=mu.device)
        return mu + std * z

    def get_logits(self, z, mu, sampling_type):
        if sampling_type == "iid":
            # Classify every sampled z; the per-sample logits are permuted to
            # (batch_size, num_labels, sample_size) for the per-sample loss.
            logits = self.classifier(z)
            mean_logits = logits.mean(dim=0)
            logits = logits.permute(1, 2, 0)
        else:
            # "argmax" sampling: classify the mean of q(z|x) deterministically.
            mean_logits = self.classifier(mu)
            logits = mean_logits
        return logits, mean_logits

    def sampled_loss(self, logits, mean_logits, labels, sampling_type):
        if sampling_type == "iid":
            # During training, computes the loss with the sampled embeddings.
            if self.num_labels == 1:
                # We are doing regression.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.sample_size), labels[:, None].float().expand(-1, self.sample_size))
                loss = torch.mean(loss, dim=-1)
                loss = torch.mean(loss, dim=0)
            else:
                loss_fct = CrossEntropyLoss(reduction="none")
                loss = loss_fct(logits, labels[:, None].expand(-1, self.sample_size))
                loss = torch.mean(loss, dim=-1)
                loss = torch.mean(loss, dim=0)
        else:
            # At test time, uses the averaged logits for prediction.
            if self.num_labels == 1:
                loss_fct = MSELoss()
                loss = loss_fct(mean_logits.view(-1), labels.float().view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(mean_logits, labels)
        return loss
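
    # Shape conventions in the "iid" branch: get_logits() returns per-sample logits of
    # shape (batch_size, num_labels, sample_size) and labels are expanded to
    # (batch_size, sample_size), so CrossEntropyLoss(reduction="none") yields one loss
    # per example and per sample, which is then averaged over samples and over the batch.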

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        sampling_type="iid",
        epoch=1,
    ):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import BertTokenizer, BertForSequenceClassification
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, logits = outputs[:2]
"""
        final_outputs = {}
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        loss = {}
        if self.deterministic:
            # Deterministic bottleneck: classify the predicted mean directly, no sampling.
            pooled_output = self.mlp(pooled_output)
            mu, std = self.estimate(pooled_output, self.emb2mu, self.emb2std)
            final_outputs["z"] = mu
            sampled_logits, logits = self.get_logits(mu, mu, sampling_type='argmax')  # always deterministic
            if labels is not None:
                loss["loss"] = self.sampled_loss(sampled_logits, logits, labels.view(-1), sampling_type='argmax')
        elif self.ib:
            # Variational information bottleneck: sample z from q(z|x) and add a KL penalty.
            pooled_output = self.mlp(pooled_output)
            batch_size = pooled_output.shape[0]
            mu, std = self.estimate(pooled_output, self.emb2mu, self.emb2std)
            mu_p = self.mu_p.view(1, -1).expand(batch_size, -1)
            std_p = torch.nn.functional.softplus(self.std_p.view(1, -1).expand(batch_size, -1))
            kl_loss = self.kl_div(mu, std, mu_p, std_p)
            z = self.reparameterize(mu, std)
            final_outputs["z"] = mu

            if self.kl_annealing == "linear":
                beta = min(1.0, epoch * self.beta)

            sampled_logits, logits = self.get_logits(z, mu, sampling_type)
            if labels is not None:
                ce_loss = self.sampled_loss(sampled_logits, logits, labels.view(-1), sampling_type)
                loss["loss"] = ce_loss + (beta if self.kl_annealing == "linear" else self.beta) * kl_loss
        else:
            # Standard BERT classification head on the pooled output.
            final_outputs["z"] = pooled_output
            logits = self.classifier(pooled_output)
            if labels is not None:
                if self.num_labels == 1:
                    # We are doing regression.
                    loss_fct = MSELoss()
                    loss["loss"] = loss_fct(logits.view(-1), labels.float().view(-1))
                else:
                    loss_fct = CrossEntropyLoss()
                    loss["loss"] = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        final_outputs.update({"logits": logits, "loss": loss, "hidden_attention": outputs[2:]})
        return final_outputs
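

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a transformers 3.x environment (where
# transformers.modeling_bert still exists) and the extra config attributes
# that __init__ reads; the attribute values below are illustrative, not tuned
# settings.
if __name__ == "__main__":
    from transformers import BertConfig, BertTokenizer

    config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
    config.ib = True               # use the variational information bottleneck head
    config.deterministic = False   # set True instead for the deterministic variant
    config.ib_dim = 144            # dimensionality of the bottleneck z
    config.hidden_dim = 384        # output size of the MLP feeding emb2mu / emb2std
    config.beta = 1e-5             # weight of the KL term
    config.sample_size = 5         # number of z samples per example for "iid" sampling
    config.activation = "relu"
    config.kl_annealing = "linear"

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", config=config)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    labels = torch.tensor([1])
    outputs = model(**inputs, labels=labels, sampling_type="argmax", epoch=1)
    print(outputs["logits"], outputs["loss"]["loss"])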