
Commit

add split_weights for internlm2
sallyjunjun committed Aug 14, 2024
1 parent 6bfd957 commit cd228dc
Showing 2 changed files with 58 additions and 7 deletions.
59 changes: 55 additions & 4 deletions huggingface_model/internlm/internlm2_7b/modeling_internlm2.py
@@ -46,7 +46,7 @@
)
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func
from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens

try:
from transformers.generation.streamers import BaseStreamer
@@ -485,7 +485,7 @@ def forward(
# )

if use_packed_dataset:
attn_output = isp_flash_attn_varlen_func(
attn_output = hf_q_k_v_with_cu_seqlens(
query_states,
key_states,
value_states,
@@ -495,7 +495,7 @@
attention_dropout = dropout_rate,
)
else:
attn_output = isp_flash_attn_func(
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, causal=True, attention_dropout=dropout_rate,
)

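In the hunks above, the packed-dataset branch routes attention through hf_q_k_v_with_cu_seqlens, which presumably receives cumulative sequence-length offsets so the kernel never attends across the boundaries of sequences packed into one row, while the fallback branch calls hf_q_k_v_without_cu_seqlens with plain causal attention. A minimal sketch of how such cu_seqlens offsets are built from per-sequence lengths (illustration only, not this repository's API):

import torch

# Toy example: three sequences of lengths 5, 3 and 8 packed into a single row of 16 tokens.
seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
print(cu_seqlens)  # tensor([ 0,  5,  8, 16], dtype=torch.int32)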
@@ -1178,6 +1178,57 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

def split_weights(self, first_layer, model_state_dict, state_dict, split_size, local_rank, row_dim):
for i in range(0, gpc.config.model.num_layers):
model_state_dict[f"model.layers.{i}.attention.wqkv.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{i+first_layer}.attention.wqkv.weight"),
split_size,
dim=0,
)[local_rank]
model_state_dict[f"model.layers.{i}.attention.wo.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{i+first_layer}.attention.wo.weight"),
split_size,
dim=row_dim,
)[local_rank]
model_state_dict[f"model.layers.{i}.feed_forward.w1.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w1.weight"),
split_size,
dim=0,
)[local_rank]
model_state_dict[f"model.layers.{i}.feed_forward.w3.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w3.weight"),
split_size,
dim=0,
)[local_rank]
model_state_dict[f"model.layers.{i}.feed_forward.w2.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{i+first_layer}.feed_forward.w2.weight"),
split_size,
dim=row_dim,
)[local_rank]
model_state_dict[f"model.layers.{i}.attention_norm.weight"] = state_dict.pop(
f"model.layers.{i+first_layer}.attention_norm.weight"
)
model_state_dict[f"model.layers.{i}.ffn_norm.weight"] = state_dict.pop(
f"model.layers.{i+first_layer}.ffn_norm.weight"
)

if (gpc.get_local_rank(ParallelMode.PIPELINE) - 1 == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
model_state_dict[f"model.tok_embeddings.weight"] = torch.chunk(
state_dict.pop(f"model.tok_embeddings.weight"),
split_size,
dim=1,
)[local_rank]

if gpc.is_last_rank(ParallelMode.PIPELINE):
model_state_dict[f"output.weight"] = torch.chunk(
state_dict.pop(f"output.weight"),
split_size,
dim=0,
)[local_rank]
model_state_dict[f"model.norm.weight"] = state_dict[f"model.norm.weight"]

return model_state_dict

@add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1823,4 +1874,4 @@ def forward(
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
)
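The new split_weights method shards a full InternLM2 state dict across tensor-parallel ranks: the attention wqkv, the feed-forward w1/w3 and the output head are chunked along dim 0, while wo and w2 are chunked along row_dim and the token embedding along dim 1, and each rank keeps only its local_rank chunk; norm weights are copied whole, and the embedding and output head are only handled on the relevant pipeline stages. A toy sketch of the underlying chunking pattern (assumed shapes and rank values, illustration only):

import torch

split_size, local_rank = 4, 1        # assumed tensor-parallel world size and this rank's index
w_column = torch.randn(8192, 4096)   # e.g. a w1 projection, split along the output dimension
w_row = torch.randn(4096, 8192)      # e.g. a w2 projection, split along the input dimension

shard_col = torch.chunk(w_column, split_size, dim=0)[local_rank]  # shape (2048, 4096)
shard_row = torch.chunk(w_row, split_size, dim=1)[local_rank]     # shape (4096, 2048)
print(shard_col.shape, shard_row.shape)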
6 changes: 3 additions & 3 deletions huggingface_model/internlm/internlm_7b/modeling_internlm.py
@@ -40,7 +40,7 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func
from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens


try:
@@ -483,7 +483,7 @@ def forward(
# )

if use_packed_dataset:
attn_output = isp_flash_attn_varlen_func(
attn_output = hf_q_k_v_with_cu_seqlens(
query_states,
key_states,
value_states,
causal=True,
)
else:
attn_output = isp_flash_attn_func(
attn_output = hf_q_k_v_without_cu_seqlens(
query_states, key_states, value_states, causal=True,
)
