From d4a0fd365c242191d0eb273cc9bfc8fb61cba8f5 Mon Sep 17 00:00:00 2001
From: zigzagcai
Date: Wed, 14 Aug 2024 17:52:39 +0800
Subject: [PATCH] add tmp reset logic

---
 huggingface_model/dispatch_utils/__init__.py | 63 ++++++++++++++++----
 1 file changed, 53 insertions(+), 10 deletions(-)

diff --git a/huggingface_model/dispatch_utils/__init__.py b/huggingface_model/dispatch_utils/__init__.py
index 5f1b3f7..3454502 100644
--- a/huggingface_model/dispatch_utils/__init__.py
+++ b/huggingface_model/dispatch_utils/__init__.py
@@ -4,6 +4,16 @@
 from collections import abc
 from typing import Any, Optional, Type, Union
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+
+from internlm.initialize.initialize_tensor import (
+    normal_,
+    scaled_init_method_normal,
+)
+
+import torch
+
 # adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
 class LazyObject:
     """LazyObject is used to lazily initialize the imported module during
@@ -221,8 +231,8 @@ def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None
 
 
 def replace_embed(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in EMBED_REPLACE_MAPPING:
                 embed = EMBED_REPLACE_MAPPING[cls_name]
@@ -232,7 +242,7 @@ def traverse(module):
                     embedding_dim=child.embedding_dim,
                     padding_idx=child.padding_idx,
                 ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                setattr(model, name, child_new)
             else:
                 traverse(child)
 
@@ -240,8 +250,8 @@ def traverse(module):
 
 
 def replace_norm(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in NORM_REPLACE_MAPPING:
                 norm = NORM_REPLACE_MAPPING[cls_name]
@@ -251,7 +261,7 @@ def traverse(module):
                     normalized_shape=child.weight.shape,
                     eps=child.variance_epsilon,
                 ).to(device=child.weight.device, dtype=child.weight.dtype)
-                setattr(module, name, child_new)
+                setattr(model, name, child_new)
             else:
                 traverse(child)
 
@@ -259,8 +269,8 @@ def traverse(module):
 
 
 def replace_linear(model):
-    def traverse(module):
-        for name, child in module.named_children():
+    def traverse(model):
+        for name, child in model.named_children():
             cls_name = type(child).__name__
             if cls_name in LINEAR_REPLACE_MAPPING:
                 linear = LINEAR_REPLACE_MAPPING[cls_name]
@@ -271,17 +281,50 @@ def traverse(module):
                 out_features=child.out_features,
                 bias=child.bias is not None,
             ).to(device=child.weight.device, dtype=child.weight.dtype)
-            setattr(module, name, child_new)
+            setattr(model, name, child_new)
         else:
             traverse(child)
 
     traverse(model)
 
 
+def reset_attn_parameters(layer_idx, layer, use_swiglu=False, use_scaled_init=True):
+    for name, param in layer.attention.named_parameters():
+        if param.ndim == 1:  # biases and other 1-D parameters
+            param.data.zero_()
+        elif "wq" in name or "wk" in name or "wv" in name:  # q/k/v projections
+            normal_(std=0.02)(param.data)
+        elif use_scaled_init:  # wo, the attention output projection
+            scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+        else:
+            normal_(std=0.02)(param.data)
+
+    for name, param in layer.feed_forward.named_parameters():
+        if use_swiglu:
+            if use_scaled_init and "w2" in name:  # down projection
+                scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+            else:
+                # candidates: w1, w3, fused_w1_w3
+                normal_(std=0.02)(param.data)
+        else:
+            if use_scaled_init and "fc1" not in name:  # e.g. fc2, the output layer
+                scaled_init_method_normal(sigma=0.02, num_layers=layer_idx + 1)(param.data)
+            else:
+                normal_(std=0.02)(param.data)
+
+def reset_parameters(model):
+    with torch.no_grad():  # re-initialize in place, outside autograd
+        for _, param in model.model.tok_embeddings.named_parameters():
+            normal_(std=0.02)(param)
+        for layer_idx, layer in enumerate(model.model.layers):
+            reset_attn_parameters(layer_idx, layer)
+        for _, param in model.output.named_parameters():
+            normal_(std=0.02)(param)
+
 def hf_model_dispatch(model):
     replace_embed(model)
     replace_norm(model)
     replace_linear(model)
-
+    reset_parameters(model)
 
 __all__ = ["hf_model_dispatch"]
\ No newline at end of file
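
Usage sketch, assuming an InternLM2-style Hugging Face checkpoint whose
causal-LM wrapper exposes model.tok_embeddings, model.layers, and a
top-level output head (the attribute paths reset_parameters requires).
The checkpoint name below is only an illustrative placeholder:

    import torch
    from transformers import AutoModelForCausalLM

    from huggingface_model.dispatch_utils import hf_model_dispatch

    # Any checkpoint matching the attribute layout above should work here.
    model = AutoModelForCausalLM.from_pretrained(
        "internlm/internlm2-7b",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # Swap Embedding/RMSNorm/Linear modules per the replace mappings, then
    # re-initialize every weight in place: normal_(std=0.02) by default,
    # depth-scaled init for output projections (wo, and w2/fc2 in the MLP).
    hf_model_dispatch(model)

Note that reset_parameters invokes reset_attn_parameters with its defaults
(use_swiglu=False, use_scaled_init=True), so the SwiGLU branch of the
feed-forward reset is never taken as written.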