Commit

add tmp reset logic
zigzagcai committed Aug 14, 2024
1 parent 6bfd957 commit d4a0fd3
Showing 1 changed file with 53 additions and 10 deletions.
huggingface_model/dispatch_utils/__init__.py
@@ -4,6 +4,16 @@
 from collections import abc
 from typing import Any, Optional, Type, Union
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+
+from internlm.initialize.initialize_tensor import (
+    normal_,
+    scaled_init_method_normal,
+)
+
+import torch
+
 # adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
 class LazyObject:
     """LazyObject is used to lazily initialize the imported module during
@@ -221,8 +231,8 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
 
 
 def replace_embed(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in EMBED_REPLACE_MAPPING:
                 embed = EMBED_REPLACE_MAPPING[cls_name]
@@ -232,16 +242,16 @@ def traverse(module):
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                 ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                setattr(model, name, child_new)
             else:
                 traverse(child)
 
     traverse(model)
 
 
 def replace_norm(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in NORM_REPLACE_MAPPING:
                 norm = NORM_REPLACE_MAPPING[cls_name]
@@ -251,16 +261,16 @@ def traverse(module):
                     normalized_shape=child.weight.shape,
                     eps=child.variance_epsilon,
                 ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                setattr(model, name, child_new)
             else:
                 traverse(child)
 
     traverse(model)
 
 
 def replace_linear(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in LINEAR_REPLACE_MAPPING:
                 linear = LINEAR_REPLACE_MAPPING[cls_name]
@@ -271,17 +281,50 @@ def traverse(module):
                     out_features=child.out_features,
                     bias=child.bias is not None,
                 ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                setattr(model, name, child_new)
             else:
                 traverse(child)
 
     traverse(model)
 
 
+def reset_attn_parameters(layer_idx, layer, use_swiglu=False, use_scaled_init=True):
+    for name, param in layer.attention.named_parameters():
+        if param.ndim == 1:
+            param.data.zero_()
+        elif "wq" in name or "wk" in name or "wv" in name:
+            normal_(std=0.02)(param.data)
+        elif use_scaled_init:  # wo
+            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+        else:
+            normal_(std=0.02)(param.data)
+
+    for name, param in layer.feed_forward.named_parameters():
+        if use_swiglu:
+            if use_scaled_init and "w2" in name:
+                scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+            else:
+                # candidate: w1, w3, fused_w1_w3
+                normal_(std=0.02)(param.data)
+        else:
+            if use_scaled_init and "fc1" not in name:
+                scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+            else:
+                normal_(std=0.02)(param.data)
+
+def reset_parameters(model):
+    with torch.no_grad():
+        for _, param in model.model.tok_embeddings.named_parameters():
+            normal_(std=0.02)(param)
+        for layer_idx, layer in enumerate(model.model.layers):
+            reset_attn_parameters(layer_idx, layer)
+        for _, param in model.output.named_parameters():
+            normal_(std=0.02)(param)
+
 def hf_model_dispatch(model):
     replace_embed(model)
     replace_norm(model)
     replace_linear(model)
+    reset_parameters(model)
 
-
 __all__ = ["hf_model_dispatch"]
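
A note on the initializers used by the new reset logic: normal_(std=0.02) re-draws a tensor from N(0, 0.02^2), while scaled_init_method_normal(sigma, num_layers) shrinks the standard deviation with depth before drawing. Below is a minimal sketch of that scaling, assuming InternLM's initialize_tensor helpers follow the usual Megatron-LM definition; the function body is an illustrative stand-in, not the actual implementation.

    import math

    import torch


    def scaled_init_method_normal_sketch(sigma: float, num_layers: int):
        """Illustrative stand-in: return an initializer that draws from
        N(0, (sigma / sqrt(2 * num_layers))**2), the Megatron-style scaling
        applied above to the residual-output projections (wo, w2/fc2)."""
        std = sigma / math.sqrt(2.0 * num_layers)

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.normal_(tensor, mean=0.0, std=std)

        return init_


    # With sigma=0.02 and num_layers=12 (i.e. layer_idx=11): std ~= 0.02 / 4.90 ~= 0.0041.
    weight = torch.empty(4096, 4096)
    scaled_init_method_normal_sketch(sigma=0.02, num_layers=12)(weight)

Since the diff passes num_layers=layer_idx + 1 rather than the total depth, earlier layers are reset with a larger standard deviation than later ones.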

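For context, a hedged usage sketch of the dispatcher this commit extends; the checkpoint id and loading flags below are assumptions for illustration, not taken from this repository.

    # Hypothetical driver: load an InternLM2-style HuggingFace checkpoint, then
    # dispatch it so embeddings, norms, and linears are swapped for InternLM
    # variants and all weights are re-initialized by the new reset logic.
    from transformers import AutoModelForCausalLM

    from huggingface_model.dispatch_utils import hf_model_dispatch

    model = AutoModelForCausalLM.from_pretrained(
        "internlm/internlm2-7b",  # assumed model id
        trust_remote_code=True,
    )
    hf_model_dispatch(model)

Note that reset_parameters reaches into model.model.tok_embeddings, model.model.layers, and model.output, so it assumes an InternLM2-style module layout, and the reset discards any pretrained weights, which fits the commit's "tmp reset logic" framing for from-scratch runs.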