# model.py
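"""Extractive question-answering model definitions.

Three interchangeable variants: the stock Hugging Face
AutoModelForQuestionAnswering, a custom span-prediction head on a bare
transformer body (AbhishekModel), and a TorchScript-friendly XLM-RoBERTa
wrapper (TorchModel). This summary is inferred from the code below.
"""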
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class ModelOutput:
    """Container for QA head outputs; loss is None when no labels are given."""
    start_logits: torch.Tensor
    end_logits: torch.Tensor
    loss: Optional[torch.Tensor] = None


class AbhishekModel(nn.Module):
    """QA model: bare transformer body plus a linear span-prediction head."""

    def __init__(self, model_name: str) -> None:
        super().__init__()
        hidden_dropout_prob: float = 0.0
        layer_norm_eps: float = 1e-7
        config = AutoConfig.from_pretrained(model_name)
        config.update(
            {
                "output_hidden_states": True,
                "hidden_dropout_prob": hidden_dropout_prob,
                "layer_norm_eps": layer_norm_eps,
                "add_pooling_layer": False,
            }
        )
        self.transformer = AutoModel.from_pretrained(model_name, config=config)
        # config.num_labels defaults to 2, matching the start/end logits of
        # extractive QA.
        self.output = nn.Linear(config.hidden_size, config.num_labels)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
    ) -> ModelOutput:
        transformer_out = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )
        sequence_output = transformer_out[0]  # last hidden state: (batch, seq, hidden)
        logits = self.output(sequence_output)
        # Split the 2-channel logits into per-token start and end scores.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self._loss(start_logits, end_logits, start_positions, end_positions)
        return ModelOutput(start_logits=start_logits, end_logits=end_logits, loss=loss)
    def _loss(
        self,
        start_logits: torch.Tensor,
        end_logits: torch.Tensor,
        start_positions: torch.Tensor,
        end_positions: torch.Tensor,
    ) -> torch.Tensor:
        # Squeeze away a trailing label dimension if one is present.
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1)
        # Positions outside the sequence are clamped to seq_len, then ignored.
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)
        loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        return (start_loss + end_loss) / 2
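
# A minimal usage sketch for AbhishekModel, not part of the original file.
# "xlm-roberta-base" is just an illustrative checkpoint name; downloading it
# requires network access:
#
#     model = AbhishekModel("xlm-roberta-base")
#     out = model(input_ids, attention_mask)          # inference: out.loss is None
#     out = model(input_ids, attention_mask,
#                 start_positions, end_positions)     # training: out.loss is set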


class TorchModel(nn.Module):
    """QA model on an XLM-RoBERTa-style body, optionally TorchScript-friendly."""

    def __init__(
        self,
        model_name: str,
        init_weights: bool = True,
        torchscript: bool = False,
    ) -> None:
        super().__init__()
        self.torchscript = torchscript
        self.config = AutoConfig.from_pretrained(model_name, torchscript=torchscript)
        self.xlm_roberta = AutoModel.from_pretrained(model_name, config=self.config)
        # Two outputs per token: start and end span logits.
        self.qa_outputs = nn.Linear(self.config.hidden_size, 2)
        if init_weights:
            self._init_weights(self.qa_outputs)

    def _init_weights(self, module: nn.Module) -> None:
        # Match the transformer's own initializer for the freshly added head.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
    ) -> Union[ModelOutput, Tuple[torch.Tensor, torch.Tensor]]:
        outputs = self.xlm_roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        qa_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = qa_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self._loss_fn(start_logits, end_logits, start_positions, end_positions)
        if self.torchscript:
            # TorchScript tracing needs plain tensor outputs, not a dataclass.
            return start_logits, end_logits
        return ModelOutput(start_logits=start_logits, end_logits=end_logits, loss=loss)
    def _loss_fn(
        self,
        start_preds: torch.Tensor,
        end_preds: torch.Tensor,
        start_labels: torch.Tensor,
        end_labels: torch.Tensor,
    ) -> torch.Tensor:
        # Labels of -1 mark padded/unanswerable examples and are ignored.
        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        start_loss = loss_fct(start_preds, start_labels)
        end_loss = loss_fct(end_preds, end_labels)
        return (start_loss + end_loss) / 2
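
# TorchScript export sketch, not part of the original file; an assumption
# about the intended use of the `torchscript` flag. With torchscript=True the
# forward returns a plain tuple, which torch.jit.trace can handle:
#
#     model = TorchModel("xlm-roberta-base", torchscript=True).eval()
#     traced = torch.jit.trace(model, (input_ids, attention_mask))
#     traced.save("model_traced.pt")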


def make_model(
    model_name: str,
    model_type: str = "hf",
    model_weights: Optional[str] = None,
    device: str = "cuda",
    torchscript: bool = False,
) -> nn.Module:
    """Build one of the three QA model variants and optionally load saved weights."""
    if model_type == "hf":
        model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            torchscript=torchscript,
        )
    elif model_type == "abhishek":
        model = AbhishekModel(model_name)
    elif model_type == "torch":
        model = TorchModel(model_name, torchscript=torchscript)
    else:
        raise ValueError(f"{model_type} is not a recognised model type.")
    if model_weights:
        print(f"Loading weights from {model_weights}")
        model.load_state_dict(torch.load(model_weights, map_location=torch.device(device)))
    model.to(device)
    return model
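

if __name__ == "__main__":
    # Smoke-test sketch, not part of the original file. "xlm-roberta-base" is
    # only an illustrative checkpoint (downloading it needs network access);
    # substitute any compatible model. Runs on CPU to stay portable.
    from transformers import AutoTokenizer

    name = "xlm-roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = make_model(name, model_type="torch", device="cpu")
    enc = tokenizer(
        "Who wrote the module?",
        "The module was written as a QA baseline.",
        return_tensors="pt",
    )
    out = model(enc["input_ids"], enc["attention_mask"])
    print(out.start_logits.shape, out.end_logits.shape)  # (1, seq_len) each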