Commit

Add files via upload
pass-lin authored Mar 17, 2024
1 parent 5bdad0e commit ac3843d
Showing 16 changed files with 99 additions and 22 deletions.
2 changes: 1 addition & 1 deletion bert4keras3/__init__.py
@@ -1,6 +1,6 @@
#! -*- coding: utf-8 -*-

__version__ = '1.0.2'
__version__ = '1.1.2'

from bert4keras3 import backend,layers,models,snippets,tokenizers
from bert4keras3.backend import ops
Binary file modified bert4keras3/__pycache__/__init__.cpython-310.pyc
Binary file added bert4keras3/__pycache__/__init__.cpython-39.pyc
Binary file modified bert4keras3/__pycache__/backend.cpython-310.pyc
Binary file added bert4keras3/__pycache__/backend.cpython-39.pyc
Binary file modified bert4keras3/__pycache__/layers.cpython-310.pyc
Binary file added bert4keras3/__pycache__/layers.cpython-39.pyc
Binary file modified bert4keras3/__pycache__/models.cpython-310.pyc
Binary file added bert4keras3/__pycache__/models.cpython-39.pyc
Binary file modified bert4keras3/__pycache__/snippets.cpython-310.pyc
Binary file added bert4keras3/__pycache__/snippets.cpython-39.pyc
Binary file modified bert4keras3/__pycache__/tokenizers.cpython-310.pyc
Binary file added bert4keras3/__pycache__/tokenizers.cpython-39.pyc
16 changes: 15 additions & 1 deletion bert4keras3/backend.py
@@ -8,13 +8,24 @@
import tensorflow as tf
from functools import wraps
is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))
lora_model = strtobool(os.environ.get('ENABLE_LORA', '0'))
# For the JAX backend, install flash attention from https://github.com/nshepperd/flash_attn_jax/releases
enable_flashatt = strtobool(os.environ.get('FLASH_ATTN', '0'))
os.environ["KERAS_BACKEND"]=os.environ.get("KERAS_BACKEND", 'tensorflow')
backlib=os.environ["KERAS_BACKEND"]
if backlib=='tfkeras':
is_tf_keras = True
if enable_flashatt:
raise ValueError('tensorflow does not support flash-attention')
elif backlib=='torch':
import torch
if enable_flashatt:
from flash_attn import flash_attn_func
def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
return flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=is_causal,window_size=window_size)
elif backlib=='jax':
if enable_flashatt:
from flash_attn_jax import flash_mha
import jax
if is_tf_keras:
sys.modules['keras'] = tf.keras
@@ -534,7 +545,10 @@ def actual_grad_fn(*doutputs):
else:
sys.modules['keras.ops']=ops


def slices_index(x,index,axis):
shape = list(ops.shape(x))
shape[axis] = index
return ops.slice(x,ops.zeros_like(shape),shape)
custom_objects = {
'gelu_erf': gelu_erf,
'gelu_tanh': ops.gelu,
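A minimal usage sketch of the new backend flags (illustrative only, not part of this commit; it assumes flash_attn is installed for the torch backend or flash_attn_jax for the jax backend):

import os

# The flags are read at import time, so they must be set before importing bert4keras3.
os.environ['KERAS_BACKEND'] = 'torch'   # or 'jax' / 'tensorflow' / 'tfkeras'
os.environ['FLASH_ATTN'] = '1'          # enables the flash_mha wrapper defined above
os.environ['ENABLE_LORA'] = '1'         # consumed later by build_transformer_model

from bert4keras3.backend import ops, slices_index

# slices_index keeps the first `index` entries along `axis`,
# e.g. trimming a decoding cache back to its valid length.
x = ops.ones([2, 8, 16])
y = slices_index(x, 4, 1)   # shape (2, 4, 16)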
53 changes: 44 additions & 9 deletions bert4keras3/layers.py
@@ -3,11 +3,13 @@

import numpy as np

from bert4keras3.backend import keras, ops, is_tf_keras,K,tf
from bert4keras3.backend import align, sequence_masking
from bert4keras3.backend import keras, ops, is_tf_keras,K,tf,enable_flashatt
if enable_flashatt:
from bert4keras3.backend import flash_mha
from bert4keras3.backend import align, sequence_masking,backlib
from bert4keras3.backend import recompute_grad,int_shape
from bert4keras3.backend import attention_normalize,divide_no_nan
from bert4keras3.backend import sinusoidal_embeddings
from bert4keras3.backend import sinusoidal_embeddings,slices_index
from bert4keras3.backend import apply_rotary_position_embeddings
from keras import initializers, activations
from keras.layers import *
@@ -27,7 +29,10 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))
def call(self,inputs, **kwargs):
index = kwargs.get('index')
return ops.expand_dims(ops.take(inputs,index,self.axis),self.axis)
out = ops.expand_dims(ops.take(inputs,index,self.axis),self.axis)
if backlib=='torch':
return slices_index(out,index+1,1)
return out
def compute_output_shape(self, input_shape):
input_shape = list(input_shape)
input_shape[self.axis]=1
@@ -637,6 +642,16 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
vw = value_cache
else:
cache = None

if enable_flashatt:
is_causal = False
if a_bias is not None:
is_causal = True
softmax_scale = 1.
if self.attention_scale:
softmax_scale = 1 / self.key_size**0.5
o = flash_mha(qw,kw,vw,softmax_scale=softmax_scale, is_causal=is_causal)
return o,[],[]
# Attention
a = ops.einsum('bjhd,bkhd->bhjk', qw, kw)
# Handle positional encodings
@@ -645,6 +660,7 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
a = a + ops.einsum('bjhd,jkd->bhjk', qw, position_bias)
elif p_bias == 't5_relative':
position_bias = ops.transpose(inputs[n], (2, 0, 1))
#print(a.shape,position_bias.shape)
a = a + ops.expand_dims(position_bias, 0)
# Attention (continued)
if self.attention_scale:
@@ -823,21 +839,40 @@ def call(self, inputs, mask=None, a_bias=None, p_bias=None):
if p_bias == 'rotary':
q, k = apply_rotary_position_embeddings(inputs[n], q, k)
# Attention
if enable_flashatt and ops.shape(k)==ops.shape(v):
z = self.pay_flash_attention_to(q,k,v, a_bias)
else:
z = self.pay_attention_to(q,k,v,mask, a_bias)
# Compute the output
if self.self_attention==False and self.factorization:
z = self.vW_dense(z)
o = self.o_dense(u * z)
return o
def pay_flash_attention_to(self, q,k,v, a_bias):
is_causal = False
if a_bias is not None:
is_causal = True
softmax_scale = 1.
if self.attention_scale:
softmax_scale = 1 / self.key_size**0.5
if ops.ndim(q)==3:
k = ops.expand_dims(k,2)
q = ops.expand_dims(q,2)
v = ops.expand_dims(v,2)
o = flash_mha(q,k,v,softmax_scale=softmax_scale, is_causal=is_causal)
return ops.squeeze(o,2)
def pay_attention_to(self, q,k,v,mask, a_bias):
a = ops.einsum('bmd,bnd->bmn', q, k)
if self.attention_scale:
a = a / self.key_size**0.5
A = attention_normalize(a, mask, -1, self.normalization, a_bias)
if self.attention_dropout:
A = self.dropout(A)
# Compute the output
try:
z=ops.einsum('bmn,bnd->bmd', A, v)
except:
pass
if self.self_attention==False and self.factorization:
z = self.vW_dense(z)
o = self.o_dense(u * z)
return o
return z

def compute_mask(self, inputs, mask=None):
if isinstance(mask, list):
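The attention changes above follow one dispatch rule: use flash attention only when it is enabled and the key/value shapes match, otherwise fall back to the einsum-based path, with a_bias acting as a causal-mask indicator. A standalone sketch of that pattern (illustrative only; it uses keras.ops directly rather than the layer classes in this file, and flash_mha is passed in as an assumed callable):

from keras import ops

def toy_attention(q, k, v, a_bias=None, key_size=64,
                  attention_scale=True, flash_mha=None):
    # Flash path: mirrors the gating on enable_flashatt and matching K/V shapes;
    # as in the diff, a bias is treated as a request for causal masking.
    if flash_mha is not None and ops.shape(k) == ops.shape(v):
        softmax_scale = 1 / key_size ** 0.5 if attention_scale else 1.0
        return flash_mha(q, k, v, softmax_scale=softmax_scale,
                         is_causal=a_bias is not None)
    # Fallback: scaled dot-product attention over (batch, seq, heads, head_dim).
    a = ops.einsum('bjhd,bkhd->bhjk', q, k)
    if attention_scale:
        a = a / key_size ** 0.5
    if a_bias is not None:
        a = a + a_bias
    a = ops.softmax(a, axis=-1)
    return ops.einsum('bhjk,bkhd->bjhd', a, v)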
50 changes: 39 additions & 11 deletions bert4keras3/models.py
@@ -2,7 +2,7 @@
# Main models

import numpy as np
from bert4keras3.backend import tf,keras,backlib
from bert4keras3.backend import tf,keras,backlib,lora_model
from bert4keras3.layers import *
from bert4keras3.snippets import insert_arguments
from bert4keras3.snippets import delete_arguments
@@ -473,6 +473,16 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
caches = self.initial_cache(inputs[:1])
key = 0
x = inputs[key]

class start_index(keras.Layer):
def call(self,x):
z = x!=0
if index_bias>0:
t = ops.ones([ops.shape(z)[0],index_bias],dtype=z.dtype)
z = ops.slice_update(z,[0,0],t)
return ops.max(ops.sum(z,-1))-1


length = input_lengths[key]
self.cache_attention_bias=None
self.cache_position_bias=None
@@ -494,26 +504,23 @@ def cache_call(self,inputs:list,input_lengths:list,end_token,
attention_mask=self.cache_attention_bias,
position_bias=self.cache_position_bias)
z,cache = out[:-1],out[-1]

caches[index*j:index*j+j]=cache

class start_index(keras.Layer):
def call(self,x):
z = x!=0
if index_bias>0:
t = ops.ones([ops.shape(z)[0],index_bias],dtype=z.dtype)
z = ops.slice_update(z,[0,0],t)
return ops.max(ops.sum(z,-1))-1


index = self.apply(
inputs=x,
layer=start_index,
name='start_index'
)

def cond(inputs, caches, index , flags):
cond1 = ops.less(index,length-1)
cond2 = ops.logical_not(ops.all(ops.equal(inputs[key][:,index],end_token),-1))
return ops.logical_and(cond1,cond2)

def body(inputs, caches, index , flags):
def body(inputs, caches, index , flags,cache_shape_torch=None):
if progress_print:

print('\r',index,end='')
@@ -529,7 +536,10 @@ def body(inputs, caches, index , flags):
position_bias = self.compute_cache_position_bias(self_cache_update_index = index)

for i in range(self.num_hidden_layers):

layer_caches = caches[i*j:i*j+j]
if backlib=='torch':
layer_caches[0]=ops.concatenate([layer_caches[0],ops.zeros(cache_shape_torch,dtype=layer_caches[0].dtype)],axis=2)
out=self.apply_main_cache_layers(z+[layer_caches], i,self_cache_update_index=index,
cross_cache_update_index=None,
attention_mask=attention_mask,
@@ -546,13 +556,18 @@ def body(inputs, caches, index , flags):
search_in = [o,index,inputs[key],flags]
inputs[key],flags = self.Search(search_in,k=k,mode=search_mode)
return (inputs, caches, index , flags)
num_hidden_layers = self.num_hidden_layers
class WhileLayer(keras.Layer):
def call(self, x):
inputs, caches, index = x[:]
flags = ops.ones([ops.shape(caches[0])[0],1],dtype='bool')
if backlib=='torch':
cache_shape_torch = list(ops.shape(caches[0]))
cache_shape_torch[2] = 1
for i in range(num_hidden_layers):
caches[i*j]=slices_index(caches[i*j],index,2)
while cond(inputs, caches, index , flags):
inputs, caches, index , flags = body(inputs, caches, index , flags)
inputs, caches, index , flags = body(inputs, caches, index , flags,cache_shape_torch)
return (inputs, caches, index)
outs=ops.while_loop(
cond,
@@ -3639,6 +3654,7 @@ def build_transformer_model(
model='bert',
application='encoder',
return_keras_model=True,
keras_weights_path=None,
**kwargs
):
"""根据配置文件构建模型,可选加载checkpoint权重
@@ -3718,7 +3734,19 @@
shape=[1 if t==None else t for t in shape]
inputs.append(np.zeros(shape,modelin.dtype))
transformer.model.predict(inputs,verbose=3)

if keras_weights_path is not None:
transformer.model.load_weights(keras_weights_path, skip_mismatch=True)
if lora_model:

def enable_lora(t):
if isinstance(t,keras.layers.Embedding) or isinstance(t,keras.layers.Dense):
t.enable_lora(True)
for layer in transformer.model.layers:
layer.trainable=False
enable_lora(layer)
for kid in dir (layer):
t = getattr(layer,kid)
enable_lora(t)
if checkpoint_path is not None:
transformer.load_weights_from_checkpoint(checkpoint_path)

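The build_transformer_model changes add an optional keras_weights_path argument, loaded with skip_mismatch=True, and, when ENABLE_LORA=1, freeze each layer and call enable_lora(True) on its Dense/Embedding sublayers. A hedged usage sketch (the config and weight file names below are placeholders, not files from this repository):

import os
os.environ['ENABLE_LORA'] = '1'   # must be set before bert4keras3 is imported

from bert4keras3.models import build_transformer_model

model = build_transformer_model(
    config_path='config.json',              # placeholder config path
    model='bert',
    keras_weights_path='model.weights.h5',  # new argument: native Keras weights
    return_keras_model=True,
)
# With ENABLE_LORA=1 the base weights are frozen and only the LoRA
# adapters enabled on Dense/Embedding layers remain trainable.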
