# export.py
from model import UIE
import torch
from torch import nn
import pnnx

# Load the pretrained UIE nano checkpoint.
model = UIE.from_pretrained('uie-nano-pytorch')


class EmbeddingRepack(nn.Module):
    """Repacks the encoder's word, position, token-type and task-type embeddings plus LayerNorm."""

    def __init__(self, model):
        super().__init__()
        self.word_embeddings = model.encoder.embeddings.word_embeddings
        self.position_embeddings = model.encoder.embeddings.position_embeddings
        self.token_type_embeddings = model.encoder.embeddings.token_type_embeddings
        self.task_type_embeddings = model.encoder.embeddings.task_type_embeddings
        self.layer_norm = model.encoder.embeddings.LayerNorm

    def forward(self, input_ids, position_ids, token_type_ids, task_type_ids):
        # Sum the four embedding lookups, then apply LayerNorm.
        inputs_embeddings = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)
        task_type_embeddings = self.task_type_embeddings(task_type_ids)
        embeddings = inputs_embeddings + token_type_embeddings + position_embeddings + task_type_embeddings
        embeddings = self.layer_norm(embeddings)
        return embeddings


class EncoderRepack(nn.Module):
    """Wraps the transformer encoder stack and returns only the last hidden state."""

    def __init__(self, model):
        super().__init__()
        self.encoder = model.encoder.encoder
        self.pooler = model.encoder.pooler  # kept for completeness; not used in forward

    def forward(self, embeddings):
        encoder_output = self.encoder(embeddings)
        return encoder_output.last_hidden_state


class UIERepack(nn.Module):
    """Full repacked UIE model: embeddings -> encoder -> start/end span heads."""

    def __init__(self, model):
        super().__init__()
        self.embedding_repack = EmbeddingRepack(model)
        self.encoder_repack = EncoderRepack(model)
        self.linear_start = model.linear_start
        self.linear_end = model.linear_end
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, position_ids, token_type_ids, task_type_ids):
        embeds = self.embedding_repack(input_ids, position_ids, token_type_ids,
                                       task_type_ids)
        sequence_output = self.encoder_repack(embeds)
        # Per-token probabilities for span start and end positions.
        start_logits = self.linear_start(sequence_output)
        start_logits = torch.squeeze(start_logits, -1)
        start_prob = self.sigmoid(start_logits)
        end_logits = self.linear_end(sequence_output)
        end_logits = torch.squeeze(end_logits, -1)
        end_prob = self.sigmoid(end_logits)
        return start_prob, end_prob


model_repack = UIERepack(model)
model_repack.eval()

# Trace inputs: token ids for a 46-token sequence, matching position ids,
# token-type ids (0 for the first segment, 1 for the second) and all-zero
# task-type ids.
x_0 = torch.tensor([[1, 36, 143, 2, 249, 136, 585, 139, 28, 1598,
                     252, 560, 1296, 1038, 32, 67, 190, 220, 1311, 1097,
                     291, 85, 19, 1498, 357, 448, 628, 12, 12, 20,
                     352, 247, 1146, 329, 1853, 22, 4734, 42, 1266, 59,
                     349, 116, 192, 664, 12044, 2]], dtype=torch.long)
x_1 = torch.tensor([[i for i in range(x_0.size(1))]], dtype=torch.long)
x_2 = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
x_3 = torch.tensor([[0] * x_0.size(1)])

# Export to pnnx; inputs2 supplies a second input shape (sequence length 16) so
# the sequence dimension can be treated as dynamic.
pnnx.export(model_repack, 'uie_nano_pnnx.pt', [x_0, x_1, x_2, x_3],
            inputs2=[torch.tensor([[1] * 16]), torch.tensor([[1] * 16]), torch.tensor([[1] * 16]),
                     torch.tensor([[1] * 16])])
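
# --- Optional sanity check (not part of the original export script) ---
# A minimal sketch: run the repacked model once on the trace inputs and print
# the output shapes, so problems surface before the exported pnnx model is used.
with torch.no_grad():
    start_prob, end_prob = model_repack(x_0, x_1, x_2, x_3)
print('start_prob:', tuple(start_prob.shape))  # expected: (1, 46)
print('end_prob:', tuple(end_prob.shape))      # expected: (1, 46)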