Commit

fix flash att in generate
Mddct committed Apr 30, 2024
1 parent e5e36fc commit 68bc61a
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions wenet/LLM/script/llama3_config.py
@@ -38,12 +38,12 @@ def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path,
        old_name = name
        # embed
        name = name.replace('tok_embeddings.weight', 'embed.weight')

        # output
        name = name.replace('output.weight', 'out.weight')
        # layers to decoders
        name = name.replace('layers', 'decoder.decoders')

        # final norm weight
        name = name.replace('norm.weight', 'decoder.final_norm.weight')
        if 'attention' in name:
            # pre ln (rms norm)
            name = name.replace('attention_norm', 'norm1')
@@ -56,21 +56,14 @@ def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path,
                                '.self_attn.linear_v.weight')
            # att out dim
            name = name.replace('attention.wo', 'self_attn.linear_out')
        elif name == 'norm_weight':
            name = name.replace('norm_weight', 'decoder.final_norm.weight')
        else:

            # mlp
            name = name.replace('feed_forward.w1', 'feed_forward.gate')
            name = name.replace('feed_forward.w3', 'feed_forward.w_1')
            name = name.replace('feed_forward.w2', 'feed_forward.w_2')

            # before mlp ln: (rms norm)
            name = name.replace('ffn_norm', 'norm2')
            # final norm
            name = name.replace('model.norm.weight',
                                'decoder.final_norm.weight')

        wenet_state_dict[name] = conformer_state_dict[old_name]
    print("Saving {} ckpt to {}...".format(config.dtype,
                                           wenet_state_dict_path))
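For context, below is a minimal, self-contained sketch of the per-key renaming that convert_to_wenet_state_dict performs in the diff above. The helper name rename_llama3_key, the wq/wk mappings, and the sample keys are illustrative assumptions rather than the repository's API, and the final-norm key is matched exactly here to keep the sketch unambiguous.

# Minimal sketch of the Llama3 -> wenet key renaming (assumptions noted above).
def rename_llama3_key(name: str) -> str:
    # top-level (non-layer) keys, matched exactly in this sketch
    if name == 'tok_embeddings.weight':
        return 'embed.weight'
    if name == 'output.weight':
        return 'out.weight'
    if name == 'norm.weight':
        return 'decoder.final_norm.weight'
    # per-layer keys: layers.{i}.* -> decoder.decoders.{i}.*
    name = name.replace('layers', 'decoder.decoders')
    if 'attention' in name:
        name = name.replace('attention_norm', 'norm1')  # pre-attention rms norm
        name = name.replace('attention.wq', 'self_attn.linear_q')
        name = name.replace('attention.wk', 'self_attn.linear_k')
        name = name.replace('attention.wv', 'self_attn.linear_v')
        name = name.replace('attention.wo', 'self_attn.linear_out')
    else:
        name = name.replace('feed_forward.w1', 'feed_forward.gate')
        name = name.replace('feed_forward.w3', 'feed_forward.w_1')
        name = name.replace('feed_forward.w2', 'feed_forward.w_2')
        name = name.replace('ffn_norm', 'norm2')  # pre-mlp rms norm
    return name

print(rename_llama3_key('layers.0.attention.wo.weight'))
# decoder.decoders.0.self_attn.linear_out.weight
print(rename_llama3_key('layers.0.feed_forward.w1.weight'))
# decoder.decoders.0.feed_forward.gate.weight

Applying such a mapping to every key of the original checkpoint, while copying the tensors unchanged, yields the wenet-style state dict that the script then saves to wenet_state_dict_path.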
