
Commit

fix llama3
Mddct committed Apr 30, 2024
1 parent e5e36fc commit 0e81840
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions wenet/LLM/script/llama3_config.py
@@ -38,15 +38,13 @@ def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path,
         old_name = name
         # embed
         name = name.replace('tok_embeddings.weight', 'embed.weight')
-
         # output
         name = name.replace('output.weight', 'out.weight')
         # layers to decoders
         name = name.replace('layers', 'decoder.decoders')
-
         if 'attention' in name:
             # pre ln (rms norm)
-            name = name.replace('attention_norm', 'norm1')
+            name = name.replace('attention_norm.weight', 'norm1.weight')
             # att weight
             name = name.replace('.attention.wq.weight',
                                 '.self_attn.linear_q.weight')
@@ -56,22 +54,18 @@ def convert_to_wenet_state_dict(Llama3_state_dict, wenet_state_dict_path,
                                 '.self_attn.linear_v.weight')
             # att out dim
             name = name.replace('attention.wo', 'self_attn.linear_out')
-        elif name == 'norm_weight':
-            name = name.replace('norm_weight', 'decoder.final_norm.weight')
         else:
-
             # mlp
             name = name.replace('feed_forward.w1', 'feed_forward.gate')
             name = name.replace('feed_forward.w3', 'feed_forward.w_1')
             name = name.replace('feed_forward.w2', 'feed_forward.w_2')
-
             # before mlp ln: (rms norm)
             name = name.replace('ffn_norm', 'norm2')
-            # final norm
-            name = name.replace('model.norm.weight',
-                                'decoder.final_norm.weight')
 
         wenet_state_dict[name] = conformer_state_dict[old_name]
+    # final norm weight
+    wenet_state_dict['decoder.final_norm.weight'] = conformer_state_dict[
+        'norm.weight']
     print("Saving {} ckpt to {}...".format(config.dtype,
                                            wenet_state_dict_path))
     torch.save(wenet_state_dict, wenet_state_dict_path)
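For illustration, here is a minimal standalone sketch, not the commit's code, that applies the renaming rules of the patched convert_to_wenet_state_dict to a few representative key names from Meta's Llama 3 consolidated checkpoint format. The dictionary holds dummy tensors rather than a loaded checkpoint, and only a subset of the replace rules is reproduced.

# Standalone sketch (not the repository script): re-applies the key renaming
# from the patched convert_to_wenet_state_dict to a few representative
# Llama 3 checkpoint keys. The tensors are dummies; a real run would load a
# Meta consolidated checkpoint with torch.load instead.
import torch

llama3_state_dict = {
    'tok_embeddings.weight': torch.zeros(1),
    'layers.0.attention_norm.weight': torch.zeros(1),
    'layers.0.attention.wq.weight': torch.zeros(1),
    'layers.0.attention.wo.weight': torch.zeros(1),
    'layers.0.ffn_norm.weight': torch.zeros(1),
    'layers.0.feed_forward.w1.weight': torch.zeros(1),
    'output.weight': torch.zeros(1),
    'norm.weight': torch.zeros(1),  # final RMSNorm of the model
}

wenet_state_dict = {}
for old_name, tensor in llama3_state_dict.items():
    name = old_name
    name = name.replace('tok_embeddings.weight', 'embed.weight')
    name = name.replace('output.weight', 'out.weight')
    name = name.replace('layers', 'decoder.decoders')
    if 'attention' in name:
        # after this commit the full '...weight' key is matched
        name = name.replace('attention_norm.weight', 'norm1.weight')
        name = name.replace('.attention.wq.weight', '.self_attn.linear_q.weight')
        name = name.replace('attention.wo', 'self_attn.linear_out')
    else:
        name = name.replace('feed_forward.w1', 'feed_forward.gate')
        name = name.replace('ffn_norm', 'norm2')
    wenet_state_dict[name] = tensor

# The substantive fix: Llama 3 stores the final norm as 'norm.weight', which
# the removed `elif name == 'norm_weight'` branch and the removed
# 'model.norm.weight' replace never matched, so the weight is now copied to
# 'decoder.final_norm.weight' explicitly after the loop.
wenet_state_dict['decoder.final_norm.weight'] = llama3_state_dict['norm.weight']

for key in sorted(wenet_state_dict):
    print(key)
# prints e.g. decoder.decoders.0.norm1.weight,
# decoder.decoders.0.self_attn.linear_q.weight, decoder.final_norm.weight,
# embed.weight, out.weight, ...

This is only meant to make the renaming visible; the real script handles the full set of attention and feed-forward weights and saves the result with torch.save.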