From cd228dc0ad0f96ecbbb3abc24cae186ca7f0d764 Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Wed, 14 Aug 2024 19:07:20 +0800 Subject: [PATCH] add split_weights for internlm2 --- .../internlm2_7b/modeling_internlm2.py | 59 +++++++++++++++++-- .../internlm/internlm_7b/modeling_internlm.py | 6 +- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py b/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py index 23b0f32..a1b0bd6 100644 --- a/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py +++ b/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py @@ -46,7 +46,7 @@ ) from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func +from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens try: from transformers.generation.streamers import BaseStreamer @@ -485,7 +485,7 @@ def forward( # ) if use_packed_dataset: - attn_output = isp_flash_attn_varlen_func( + attn_output = hf_q_k_v_with_cu_seqlens( query_states, key_states, value_states, @@ -495,7 +495,7 @@ def forward( attention_dropout = dropout_rate, ) else: - attn_output = isp_flash_attn_func( + attn_output = hf_q_k_v_without_cu_seqlens( query_states, key_states, value_states, causal=True, attention_dropout=dropout_rate, ) @@ -1178,6 +1178,57 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + def split_weights(self, first_layer, model_state_dict, state_dict, split_size, local_rank, row_dim): + for i in range(0, gpc.config.model.num_layers): + model_state_dict[f"model.layers.{i}.attention.wqkv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{i+first_layer}.attention.wqkv.weight"), + split_size, + dim=0, + )[local_rank] + model_state_dict[f"model.layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{i+first_layer}.attention.wo.weight"), + split_size, + dim=row_dim, + )[local_rank] + model_state_dict[f"model.layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w1.weight"), + split_size, + dim=0, + )[local_rank] + model_state_dict[f"model.layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w3.weight"), + split_size, + dim=0, + )[local_rank] + model_state_dict[f"model.layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w2.weight"), + split_size, + dim=row_dim, + )[local_rank] + model_state_dict[f"model.layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{i+first_layer}.attention_norm.weight" + ) + model_state_dict[f"model.layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{i+first_layer}.ffn_norm.weight" + ) + + if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + model_state_dict[f"model.tok_embeddings.weight"] = torch.chunk( + state_dict.pop(f"model.tok_embeddings.weight"), + split_size, + dim=1, + )[local_rank] + + if gpc.is_last_rank(ParallelMode.PIPELINE): + model_state_dict[f"output.weight"] = torch.chunk( + state_dict.pop(f"output.weight"), + split_size, + dim=0, + )[local_rank] + model_state_dict[f"model.norm.weight"] = state_dict[f"model.norm.weight"] + + return model_state_dict + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1823,4 +1874,4 @@ def forward( logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) \ No newline at end of file + ) diff --git a/huggingface_model/internlm/internlm_7b/modeling_internlm.py b/huggingface_model/internlm/internlm_7b/modeling_internlm.py index ada307b..3450a58 100644 --- a/huggingface_model/internlm/internlm_7b/modeling_internlm.py +++ b/huggingface_model/internlm/internlm_7b/modeling_internlm.py @@ -40,7 +40,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.ops.rotary_emb import apply_rotary_emb -from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func +from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens try: @@ -483,7 +483,7 @@ def forward( # ) if use_packed_dataset: - attn_output = isp_flash_attn_varlen_func( + attn_output = hf_q_k_v_with_cu_seqlens( query_states, key_states, value_states, @@ -492,7 +492,7 @@ def forward( causal=True, ) else: - attn_output = isp_flash_attn_func( + attn_output = hf_q_k_v_without_cu_seqlens( query_states, key_states, value_states, causal=True, )