add auto_dispatch as option
zigzagcai committed Aug 15, 2024
1 parent 9e05051 commit aa894cc
Showing 5 changed files with 140 additions and 52 deletions.
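The commit makes automatic module dispatching opt-in. Below is a minimal sketch of the two modes, assuming hf_model_dispatch is imported from huggingface_model.dispatch_utils; initialize_model and the remaining imports follow the collapsed portions of the train.py examples shown further down.

from functools import partial

from huggingface_model.dispatch_utils import hf_model_dispatch  # assumed import path

# Default (auto_dispatch=False): check-only mode. Embedding/Linear/RMSNorm modules
# with InternEvo parallel counterparts are only reported via logger warnings.
model = initialize_model(model_dispatch_func=hf_model_dispatch)

# Opt-in: replace those modules in place and re-initialize weights through the
# model's registered reset_parameters function.
model = initialize_model(model_dispatch_func=partial(hf_model_dispatch, auto_dispatch=True))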
4 changes: 3 additions & 1 deletion examples/internlm/internlm2_7b/train.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

from internlm.core.context import global_context as gpc
from internlm.core.trainer_builder import TrainerBuilder
from internlm.data import (
@@ -29,7 +31,7 @@ def main(args):
hf_config_initializer.register_module(gpc.config.model_type, InternLM2Config)

# initialize model
-    model = initialize_model(model_dispatch_func=hf_model_dispatch)
+    model = initialize_model(model_dispatch_func=partial(hf_model_dispatch, auto_dispatch=True))

# initialize train dataloader
train_dl, dataset_types = build_train_loader_with_data_type()
4 changes: 3 additions & 1 deletion examples/internlm/internlm_7b/train.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from functools import partial

from internlm.core.context import global_context as gpc
from internlm.core.trainer_builder import TrainerBuilder
from internlm.data import (
@@ -25,7 +27,7 @@ def main(args):
hf_config_initializer.register_module(gpc.config.model_type, InternLMConfig)

# initialize model
-    model = initialize_model(model_dispatch_func=hf_model_dispatch)
+    model = initialize_model(model_dispatch_func=partial(hf_model_dispatch, auto_dispatch=True))

# initialize train dataloader
train_dl, dataset_types = build_train_loader_with_data_type()
112 changes: 84 additions & 28 deletions huggingface_model/dispatch_utils/__init__.py
@@ -1,10 +1,12 @@
# adapted from https://github.com/InternLM/xtuner/blob/main/xtuner/model/modules/dispatch/__init__.py

import importlib
from collections import abc
from typing import Any, Optional, Type, Union

from internlm.core.context.parallel_context import global_context as gpc
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
@@ -190,27 +192,23 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
return False
return True


EMBED_REPLACE_MAPPING = dict(
Embedding=LazyObject("internlm.model.modules.embedding", "Embedding1D"),
)

NORM_REPLACE_MAPPING = dict(
InternLMRMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
InternLM2RMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
)

LINEAR_REPLACE_MAPPING = dict(
Linear=LazyObject("internlm.model.modules.linear", "new_linear"),
)

NORM2NEW_NORM_NAME_MAPPING = dict(
input_layernorm="rmsnorm",
post_attention_layernorm="rmsnorm",
norm="rmsnorm",
attention_norm="rmsnorm",
ffn_norm="rmsnorm",

NORM_REPLACE_MAPPING = dict(
InternLMRMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
InternLM2RMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
)


LINEAR2NEW_LINEAR_NAME_MAPPING = dict(
q_proj="wq",
k_proj="wk",
@@ -223,8 +221,18 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
)


NORM2NEW_NORM_NAME_MAPPING = dict(
input_layernorm="rmsnorm",
post_attention_layernorm="rmsnorm",
norm="rmsnorm",
attention_norm="rmsnorm",
ffn_norm="rmsnorm",
)


RESET_PARAM_FUNC_MAPPING = dict(
internlm2_7b=LazyObject("huggingface_model.internlm.internlm2_7b", "reset_parameters"),
internlm_7b=LazyObject("huggingface_model.internlm.internlm_7b", "reset_parameters"),
)


@@ -247,6 +255,26 @@ def traverse(module):
traverse(model)


def replace_linear(model):
def traverse(module):
for name, child in module.named_children():
cls_name = type(child).__name__
if cls_name in LINEAR_REPLACE_MAPPING:
linear = LINEAR_REPLACE_MAPPING[cls_name]
linear = linear.build()
child_new = linear(
name=LINEAR2NEW_LINEAR_NAME_MAPPING.get(name, name),
in_features=child.in_features,
out_features=child.out_features,
bias=child.bias is not None,
).to(device=child.weight.device, dtype=child.weight.dtype)
setattr(module, name, child_new)
else:
traverse(child)

traverse(model)


def replace_norm(model):
def traverse(module):
for name, child in module.named_children():
@@ -266,33 +294,61 @@ def traverse(module):
traverse(model)


def replace_linear(model):
def check_embed(model):
def traverse(module):
for name, child in module.named_children():
cls_name = type(child).__name__
if cls_name in EMBED_REPLACE_MAPPING:
embed = EMBED_REPLACE_MAPPING[cls_name]
embed = embed.build()
logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {embed.__name__}")
else:
traverse(child)

traverse(model)


def check_linear(model):
def traverse(module):
for name, child in module.named_children():
cls_name = type(child).__name__
if cls_name in LINEAR_REPLACE_MAPPING:
linear = LINEAR_REPLACE_MAPPING[cls_name]
linear = linear.build()
child_new = linear(
name=LINEAR2NEW_LINEAR_NAME_MAPPING.get(name, name),
in_features=child.in_features,
out_features=child.out_features,
bias=child.bias is not None,
).to(device=child.weight.device, dtype=child.weight.dtype)
setattr(module, name, child_new)
logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {linear.__name__}")
else:
traverse(child)

traverse(model)


def check_norm(model):
def traverse(module):
for name, child in module.named_children():
cls_name = type(child).__name__
if cls_name in NORM_REPLACE_MAPPING:
norm = NORM_REPLACE_MAPPING[cls_name]
norm = norm.build()
logger.warning(f"{name} of type {cls_name} is suggested to be replaced with type {norm.__name__}")
else:
traverse(child)

traverse(model)


-def hf_model_dispatch(model):
-    replace_embed(model)
-    replace_norm(model)
-    replace_linear(model)
-    reset_parameters = RESET_PARAM_FUNC_MAPPING.get(gpc.config.HF_MODEL_NAME.split("/")[1].replace("-", "_"), None)
-    assert reset_parameters is not None, "reset_parameters need to be implemented."
-    reset_parameters = reset_parameters.build()
-    reset_parameters(model)
+def hf_model_dispatch(model, auto_dispatch=False):
+    if auto_dispatch:
+        replace_embed(model)
+        replace_linear(model)
+        replace_norm(model)
+        reset_parameters = RESET_PARAM_FUNC_MAPPING.get(gpc.config.HF_MODEL_NAME.split("/")[1].replace("-", "_"), None)
+        assert reset_parameters is not None, "In auto_dispatch mode, function reset_parameters need to be implemented."
+        reset_parameters = reset_parameters.build()
+        reset_parameters(model)
+    else:
+        check_embed(model)
+        check_linear(model)
+        check_norm(model)


__all__ = ["hf_model_dispatch"]
__all__ = ["hf_model_dispatch"]
39 changes: 17 additions & 22 deletions huggingface_model/internlm/internlm2_7b/__init__.py
@@ -1,41 +1,36 @@
import torch
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal

from .configuration_internlm2 import InternLM2Config
from .modeling_internlm2 import InternLM2ForCausalLM

from internlm.initialize.initialize_tensor import (
normal_,
scaled_init_method_normal,
)

import torch

def reset_attn_parameters(layer_idx, layer, use_scaled_init=True):
def reset_attn_parameters(layer_idx, layer, use_scaled_init=True, std=0.02):
for name, param in layer.attention.named_parameters():
if param.ndim == 1:
if param.ndim == 1: # bias
param.data.zero_()
elif "wq" in name or "wk" in name or "wv" in name:
normal_(std=0.02)(param.data)
normal_(std=std)(param.data)
elif use_scaled_init: # wo
scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
else:
normal_(std=0.02)(param.data)
scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
else: # wo
normal_(std=std)(param.data)

for name, param in layer.feed_forward.named_parameters():
if use_scaled_init:
scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
else:
normal_(std=0.02)(param.data)
normal_(std=std)(param.data)


def reset_parameters(model):
def reset_parameters(model, std=0.02):
with torch.no_grad():
for _, param in model.model.tok_embeddings.named_parameters():
normal_(std=0.02)(param)
normal_(std=std)(param)
for layer_idx, layer in enumerate(model.model.layers):
reset_attn_parameters(layer_idx, layer)
for _, param in model.output.named_parameters():
normal_(std=0.02)(param)
normal_(std=std)(param)


__all__ = [
"InternLM2Config",
"InternLM2ForCausalLM",
"reset_parameters"
]
__all__ = ["InternLM2Config", "InternLM2ForCausalLM", "reset_parameters"]
33 changes: 33 additions & 0 deletions huggingface_model/internlm/internlm_7b/__init__.py
@@ -1,7 +1,40 @@
import torch
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal

from .configuration_internlm import InternLMConfig
from .modeling_internlm import InternLMForCausalLM


def reset_attn_parameters(layer_idx, layer, use_scaled_init=True, std=0.02):
for name, param in layer.self_attn.named_parameters():
if param.ndim == 1: # bias
param.data.zero_()
elif "q_proj" in name or "k_proj" in name or "v_proj" in name:
normal_(std=std)(param.data)
elif use_scaled_init: # wo
scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
else: # wo
normal_(std=std)(param.data)

for name, param in layer.mlp.named_parameters():
if use_scaled_init:
scaled_init_method_normal(sigma=std, num_layers=layer_idx + 1)(param.data)
else:
normal_(std=std)(param.data)


def reset_parameters(model, std=0.02):
with torch.no_grad():
for _, param in model.model.embed_tokens.named_parameters():
normal_(std=std)(param)
for layer_idx, layer in enumerate(model.model.layers):
reset_attn_parameters(layer_idx, layer)
for _, param in model.lm_head.named_parameters():
normal_(std=std)(param)


__all__ = [
"InternLMConfig",
"InternLMForCausalLM",
"reset_parameters"
]
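
For reference, a sketch of calling the new InternLM-7B reset helper by hand; this is hypothetical standalone usage (in the training flow hf_model_dispatch performs the reset automatically when auto_dispatch=True), and it assumes the config can be constructed with its defaults:

from huggingface_model.internlm.internlm_7b import (
    InternLMConfig,
    InternLMForCausalLM,
    reset_parameters,
)

model = InternLMForCausalLM(InternLMConfig())  # config arguments omitted; default values assumed
reset_parameters(model, std=0.02)  # zero biases; normal_ for embeddings, q/k/v and lm_head; scaled init for o_proj and MLP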
