-
Notifications
You must be signed in to change notification settings - Fork 1
/
rwkv_generate_test.py
49 lines (45 loc) · 1.6 KB
/
rwkv_generate_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
# Select JAX as the Keras backend. This is set before `import keras` below,
# because Keras reads the KERAS_BACKEND variable when it is first imported.
os.environ['KERAS_BACKEND'] = 'jax'
# OPS_KERNEL selects the wkv operator implementation: '1' uses the
# pure-Python operator, '0' uses the CUDA operator.
os.environ['OPS_KERNEL'] = '1'
import keras
# Make bfloat16 the default float type for everything Keras builds afterwards.
keras.config.set_floatx('bfloat16')
from bert4keras3.models import build_transformer_model
from bert4keras3.tokenizers import RWKV_TOKENIZER
from bert4keras3.Models.RWKV import *
import numpy as np
# Checkpoint directory; the config, weights and vocabulary files are all
# resolved relative to it.
base_path = 'RWKV6-1.6B/'
config_path = f'{base_path}config.json'
weights_path = f'{base_path}model.weights.h5'
dict_path = f'{base_path}rwkv_vocab_v20230424.txt'

# Fixed sequence length: used to build the model and to pad every prompt.
maxlen = 2048

tokenizer = RWKV_TOKENIZER(dict_path)

# With return_keras_model=False the builder returns a wrapper object; the
# underlying Keras model is taken from its `.model` attribute below.
RWKV = build_transformer_model(
    config_path=config_path,
    model='rwkv6',
    keras_weights_path=weights_path,
    return_keras_model=False,
    sequence_length=maxlen,
    with_lm='softmax',
)
rwkv = RWKV.model

# RWKV's end marker is "\n\n", which is id 261 in the vocabulary; -1 is
# passed as end_token here instead.
# NOTE(review): presumably -1 never matches a real token id, so generation
# is not stopped early -- confirm against build_cache_model.
generate_model = RWKV.build_cache_model(
    input_lengths=[maxlen],
    end_token=-1,
    search_mode='topp',
    k=0.5,
    progress_print=True,
    index_bias=0,
)
generate_model.compile(jit_compile="auto")

# Prompt used for the benchmark runs.
text = '''\n下面是一个关于python实现'''
print('test generate')
def generate(text):
    """Run one generation pass for ``text`` and print the output with timing.

    The prompt is tokenized, right-padded with token id 0 up to ``maxlen``
    and passed to ``generate_model.predict`` as a batch of one. Zero entries
    are removed from the returned sequence, which is then decoded and
    printed, followed by the elapsed wall-clock time and a tokens-per-second
    figure.

    Args:
        text: Prompt string to feed to the model.

    Returns:
        None. Everything is reported on stdout.
    """
    import time

    prompt_ids = tokenizer.encode(text)[0]
    # A non-positive repeat count produces an empty list, so a prompt of
    # ``maxlen`` tokens or more is passed through without padding.
    padded_ids = prompt_ids + [0] * (maxlen - len(prompt_ids))
    batch = np.array([padded_ids], dtype='int32')

    started = time.time()
    output_ids = generate_model.predict([batch])[0]
    # Id 0 is the padding value used above; keep only the non-zero entries.
    output_ids = output_ids[output_ids != 0]

    # Difference between the output length and the re-tokenized prompt
    # length -- presumably the number of newly generated tokens, assuming
    # the output sequence still contains the prompt (TODO confirm).
    prompt_len = len(tokenizer.encode(text)[0])
    print(len(output_ids) - prompt_len)
    print(tokenizer.decode([output_ids])[0])

    # The clock is stopped only here, so the reported time also covers the
    # re-tokenization, decoding and printing above.
    elapsed = time.time() - started
    print('\n')
    print(f'总耗时为{elapsed!s} 秒')
    # Speed counts every non-zero output token, floor-divided by the
    # elapsed seconds.
    token_count = sum(output_ids != 0)
    print(f'推理速度为{token_count // elapsed!s} token/s')
# Run the same prompt twice in a row.
# NOTE(review): presumably the first call absorbs JIT compilation, so the
# second call shows steady-state speed -- confirm.
for _ in range(2):
    generate(text)