-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathfinetuning.py
182 lines (139 loc) · 4.94 KB
/
finetuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# -*- coding: utf-8 -*-
"""
Created on Tue May 16 16:52:47 2023
@author: Administrator
"""
import os
# Both environment flags are set before any of the TensorFlow-backed sznlp
# modules are imported further down.
# NOTE(review): RECOMPUTE is presumably read by sznlp.my_bert4keras.backend at
# import time to enable recomputation (gradient checkpointing) -- confirm there.
os.environ["RECOMPUTE"] = "1"
# Raises the threshold of TensorFlow's native (C++) logging to reduce console noise.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
epochs = 20 # total number of training epochs
batch_size = 20 # choose according to the available GPU memory
config_path = "misaka-v3/misaka_v3.json" # path of the model config file
dirs = "data/" # dataset directory; every csv file found in it is read
model_load_weight = "misaka_v3.h5" # path of the weights to load before training
max_input_len = 128 # maximum input length (in tokens)
max_output_len = 512 # maximum output length (in tokens)
model_save_weight = "expand.h5" # path the fine-tuned weights are saved to
from sklearn.utils import shuffle
import numpy as np
import pandas as pd
from sznlp.my_bert4keras.tokenizers import Tokenizer
tokenizer = Tokenizer("misaka-v3/vocab.txt", do_lower_case=True)
from sznlp.my_bert4keras.backend import tf, keras, K
from sznlp.my_bert4keras.layers import Loss
import time
# Poll once per second until TensorFlow reports an available GPU; the script
# does not continue without one.
while True:
    if tf.test.is_gpu_available():
        break
    print("fingding gpu")
    time.sleep(1)
from sznlp.my_bert4keras.snippets import sequence_padding
print(tf.__version__)
from tqdm import tqdm
from sznlp.my_bert4keras.optimizers import Adam, AdaFactor
from sznlp.my_bert4keras.optimizers import extend_with_weight_decay, Tiger
from sznlp.misaka_models import *
from sznlp.my_bert4keras.optimizers import extend_with_layer_adaptation
from sznlp.my_bert4keras.optimizers import extend_with_piecewise_linear_lr
from sznlp.my_bert4keras.optimizers import extend_with_gradient_accumulation
from sznlp.my_bert4keras.optimizers import extend_with_piecewise_linear_lr
from sznlp.my_bert4keras.models import build_transformer_model
class CrossEntropy(Loss):
    """Masked cross-entropy loss for next-token prediction.

    Positions whose target token id is 0 are excluded from both the loss and
    the reported accuracy.
    """

    def compute_loss(self, inputs, mask=None):
        """Return the masked mean cross-entropy, multiplied by 1000.

        ``inputs`` unpacks to ``(y_true, y_pred)``: the target token ids and
        the model predictions. A masked ``accuracy`` metric is registered on
        the layer as a side effect. ``mask`` is accepted but not used.
        """
        token_ids, scores = inputs
        # Identity activation created with dtype float32.
        # NOTE(review): presumably here to obtain float32 predictions under
        # mixed precision -- confirm.
        scores = keras.layers.Activation("linear", dtype="float32")(scores)

        # Shift by one step: the prediction at position t is compared with
        # the token at position t + 1.
        targets = token_ids[:, 1:]
        scores = scores[:, :-1]

        # 1.0 where the target id is non-zero, 0.0 elsewhere
        # (id 0 is presumably the padding value -- confirm).
        weights = K.cast(K.not_equal(targets, 0), scores.dtype)
        n_valid = K.sum(weights)

        hits = K.cast(
            keras.metrics.sparse_categorical_accuracy(targets, scores),
            scores.dtype,
        )
        self.add_metric(K.sum(hits * weights) / n_valid, name="accuracy")

        # from_logits is left at its default, so `scores` is consumed as-is.
        token_loss = K.sparse_categorical_crossentropy(targets, scores)
        mean_loss = K.sum(token_loss * weights) / n_valid
        return mean_loss * 1000
# Build the Misaka_V3 network described by config_path. With
# return_keras_model=False the builder hands back the wrapper object, whose
# .model, .encoder and .decoder attributes are used below.
# NOTE: no distribution strategy is defined in this script; for multi-device
# training this section would have to run inside `with strategy.scope():`.
misaka = build_transformer_model(
    config_path=config_path,
    model=Misaka_V3,
    return_keras_model=False,
)
model = misaka.model
# The loss layer receives every model input except the first, followed by the
# model outputs.
# NOTE(review): the argument 1 presumably selects which of those tensors the
# Loss layer returns as its own output -- confirm in layers.Loss.
output = CrossEntropy(1)(model.inputs[1:] + model.outputs)
train_model = keras.models.Model(model.inputs, output)
encoder = misaka.encoder
decoder = misaka.decoder
optimizer = AdaFactor(
    learning_rate=2e-5,
)
# No `loss=` argument is passed; the loss is presumably attached to the graph
# by the CrossEntropy layer itself.
train_model.compile(optimizer=optimizer)
model.summary()

# Weight loading is best effort: first try the checkpoint against the whole
# training model, then against the encoder only. Training continues even when
# both attempts fail, but the reasons are now printed instead of discarded.
# `except Exception` (rather than a bare `except`) lets Ctrl-C and SystemExit
# propagate instead of being mistaken for a failed load.
try:
    train_model.load_weights(model_load_weight, by_name=True)
    print("成功加载权重")
except Exception as full_model_error:
    try:
        misaka.encoder.load_weights(model_load_weight, by_name=True)
        print("成功加载权重编码器")
    except Exception as encoder_error:
        print("模型加载失败")
        print("full model:", repr(full_model_error))
        print("encoder only:", repr(encoder_error))
from tqdm import tqdm
def load_data(filename, min_output_len=128):
    """Tokenize one CSV file into encoder/decoder token-id sequences.

    The first CSV column is dropped. Of the remaining columns, the first is
    used as the input text and the second as the output text. Rows are
    shuffled before they are processed.

    Args:
        filename: path of the CSV file to read.
        min_output_len: samples whose tokenized output has fewer ids than
            this are discarded. Defaults to 128, the value that used to be
            hard-coded.

    Returns:
        ``[encoders, decoders]``: two parallel lists of token-id lists.
    """
    rows = shuffle(pd.read_csv(filename).values[:, 1:])
    encoders, decoders = [], []
    for row in tqdm(rows):
        # A sample needs both an input and an output column.
        if len(row) < 2:
            continue
        source_text, target_text = row[:2]
        # Empty CSV cells come back from pandas as float NaN. The previous
        # `for a in t[:3]: ... continue` loop could not skip such rows because
        # its `continue` only affected that inner loop; they were dropped by
        # accident through the blanket `except`. Skip them explicitly instead.
        if not (isinstance(source_text, str) and isinstance(target_text, str)):
            continue
        try:
            # Keep only the last max_input_len ids of the input.
            source_ids = tokenizer.encode(source_text.replace("氼。", "氼"))[0][
                -max_input_len:
            ]
            target_ids = tokenizer.encode(
                target_text.replace("氼。", "氼"), maxlen=max_output_len
            )[0]
            # Left-truncation may have removed the start token, so the first
            # id is always overwritten with it.
            source_ids[0] = tokenizer._token_start_id
        except Exception:
            # Best effort, as before: a sample that cannot be tokenized is
            # skipped. Unlike a bare `except`, this lets Ctrl-C through.
            continue
        if len(target_ids) < min_output_len:
            continue
        encoders.append(source_ids)
        decoders.append(target_ids)
    return [encoders, decoders]
# Gather the tokenized samples of every CSV file found directly under `dirs`.
files = os.listdir(dirs)
x, y = [], []
print("开启数据加载")
for i, filename in enumerate(files):
    print(i, "/", len(files))
    # Substring match kept on purpose: any name containing ".csv" qualifies,
    # exactly as before.
    if ".csv" not in filename.lower():
        continue
    # os.path.join is correct whether or not `dirs` ends with a separator;
    # plain concatenation produced e.g. "datafile.csv" for dirs = "data".
    x_t, y_t = load_data(os.path.join(dirs, filename))
    x.extend(x_t)
    y.extend(y_t)
# Fail early with a clear message instead of an obscure error further down
# when nothing could be loaded.
if not x:
    raise ValueError("no training samples were loaded from %r" % (dirs,))
print("开启数据填充")
# Pad the variable-length id lists into rectangular arrays.
x = sequence_padding(x)
y = sequence_padding(y)
print(x.shape, y.shape)
# Not referenced anywhere in the visible script; kept so the module namespace
# stays unchanged.
a = []
num = 0
class Save(keras.callbacks.Callback):
    """Keras callback that checkpoints the training model after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Write the weights of ``train_model`` to ``model_save_weight``.

        ``epoch`` and ``logs`` are supplied by Keras and are not used; the
        same path is written every time.
        """
        checkpoint_path = model_save_weight
        train_model.save_weights(checkpoint_path)
# Run the fine-tuning. No separate target array is passed: both padded arrays
# go in as model inputs.
evaluator = Save()
train_model.fit(
    x=[x, y],
    batch_size=batch_size,
    epochs=epochs,
    shuffle=True,
    callbacks=[evaluator],
    verbose=1,
)
# The callback already saves after each epoch; this writes the weights once
# more after fit() has returned.
train_model.save_weights(model_save_weight)