diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..19cd7c8
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,53 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: https://github.com/psf/black
+  rev: '22.8.0'
+  hooks:
+  - id: black
+    args:
+    - --line-length=120
+- repo: https://github.com/pycqa/isort
+  rev: '5.12.0'
+  hooks:
+  - id: isort
+    name: isort
+    files: "\\.(py)$"
+    args:
+    - --profile=black
+- repo: https://github.com/PyCQA/flake8
+  rev: '3.8.4'
+  hooks:
+  - id: flake8
+    args:
+    - --ignore=F403,F405,W504,W503,E203
+    - --max-line-length=120
+- repo: https://github.com/pre-commit/pygrep-hooks
+  rev: v1.9.0
+  hooks:
+  - id: python-check-blanket-noqa
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.3.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-added-large-files
+    args: ['--maxkb=100', '--enforce-all']
+  - id: check-json
+  - id: check-docstring-first
+  - id: check-yaml
+  - id: debug-statements
+  - id: mixed-line-ending
+- repo: https://github.com/PyCQA/pylint/
+  rev: v2.17.2
+  hooks:
+  - id: pylint
+    name: pylint
+    entry: pylint
+    language: system
+    types: [python]
+    args:
+      [
+        '--rcfile=.pylintrc',
+        '--disable=C0114,C0415,W0212,W0235,W0238,W0621,C0103,R1735,C2801,E0402,C0412,W0719,R1728,W1514,W0718,W0105,W0707,C0209,W0703,W1203'
+      ]
\ No newline at end of file
diff --git a/examples/internlm/internlm2_7b/train.py b/examples/internlm/internlm2_7b/train.py
index 8774b9f..fefe3a6 100644
--- a/examples/internlm/internlm2_7b/train.py
+++ b/examples/internlm/internlm2_7b/train.py
@@ -13,6 +13,7 @@
 from internlm.train import initialize_model
 from internlm.utils.common import parse_args
 
+from huggingface_model.dispatch_utils import hf_model_dispatch
 from huggingface_model.internlm.internlm2_7b.configuration_internlm2 import (
     InternLM2Config,
 )
@@ -28,7 +29,7 @@ def main(args):
     hf_config_initializer.register_module(gpc.config.model_type, InternLM2Config)
 
     # initialize model
-    model = initialize_model()
+    model = initialize_model(model_dispatch_func=hf_model_dispatch)
 
     # initialize train dataloader
     train_dl, dataset_types = build_train_loader_with_data_type()
diff --git a/examples/internlm/internlm_7b/train.py b/examples/internlm/internlm_7b/train.py
index 1cde856..afefc8d 100644
--- a/examples/internlm/internlm_7b/train.py
+++ b/examples/internlm/internlm_7b/train.py
@@ -13,6 +13,7 @@
 from internlm.train import initialize_model
 from internlm.utils.common import parse_args
 
+from huggingface_model.dispatch_utils import hf_model_dispatch
 from huggingface_model.internlm.internlm_7b.configuration_internlm import InternLMConfig
 from huggingface_model.internlm.internlm_7b.modeling_internlm import InternLMForCausalLM
 
@@ -24,7 +25,7 @@ def main(args):
     hf_config_initializer.register_module(gpc.config.model_type, InternLMConfig)
 
     # initialize model
-    model = initialize_model()
+    model = initialize_model(model_dispatch_func=hf_model_dispatch)
 
     # initialize train dataloader
     train_dl, dataset_types = build_train_loader_with_data_type()
diff --git a/huggingface_model/README.md b/huggingface_model/README.md
new file mode 100644
index 0000000..8a9982a
--- /dev/null
+++ b/huggingface_model/README.md
@@ -0,0 +1,115 @@
+# Adapting HuggingFace Models for InternEvo Packed and ISP Training
+
+## Background
+
+When integrating HuggingFace models with the InternEvo framework, we want packed training and ISP to be supported in order to:
+1. Improve GPU computation utilization (avoid wasting computation on meaningless padded tokens)
+2. Support training with long sequences (using the latest parallel techniques from the InternEvo framework)
+
+This requires adapting the models to support:
+1. Packed training
+2. ISP (Intern Sequence Parallelism) training
+
+## Supporting Packed Training
+
+### Example for modeling_internlm.py
+
+Step 1. Obtain `cu_seqlens` and `max_seqlen` from `gpc` for the current batch.
+
+```python
+use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
+
+if use_packed_dataset:
+    assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
+    cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+    max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+```
+
+Optional Step 2. If the original rotary embedding logic cannot meet the requirements of packed training, use InternEvo's `apply_rotary_emb`.
+Otherwise, keep the original logic and skip this step.
+
+```python
+if use_packed_dataset:
+    cos, sin = self.rotary_emb(value_states, max_seqlen)
+    cos = cos[position_ids].squeeze(0)
+    sin = sin[position_ids].squeeze(0)
+    assert sin.shape == cos.shape, "cos and sin must have the same shape"
+    _, rotary_dim = cos.shape
+    rotary_dim_half = rotary_dim // 2
+    cos_half = cos[:q_len, :rotary_dim_half]
+    sin_half = sin[:q_len, :rotary_dim_half]
+    query_states = apply_rotary_emb(query_states, cos_half, sin_half)
+    key_states = apply_rotary_emb(key_states, cos_half, sin_half)
+```
+
+Step 3. Pass `cu_seqlens` and `max_seqlen` to the flash attention varlen kernel for variable-length attention calculation.
+
+```python
+if use_packed_dataset:
+    attn_output = isp_flash_attn_varlen_func(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
+        causal=True,
+    )
+```
+
+
+### Example for modeling_internlm2.py
+
+Step 1. Obtain `cu_seqlens` and `max_seqlen` from `gpc` for the current batch.
+
+```python
+use_packed_dataset = gpc.config.data.get("use_packed_dataset", False)
+
+if use_packed_dataset:
+    assert bsz == 1, "hidden_states should be packed into bsz=1 when use_packed_dataset=True"
+    cu_seqlens = gpc.config.data[f"cu_seqlens_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+    max_seqlen = gpc.config.data[f"max_seqlen_data_rank{gpc.get_local_rank(ParallelMode.DATA)}"]
+```
+
+Step 2. Pass `cu_seqlens` and `max_seqlen` to the flash attention varlen kernel for variable-length attention calculation.
+
+```python
+if use_packed_dataset:
+    attn_output = isp_flash_attn_varlen_func(
+        query_states,
+        key_states,
+        value_states,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
+        causal=True,
+    )
+```
+
+
+## Supporting ISP Training
+
+### Automatic dispatch
+
+For simplicity, you can create the model with `hf_model_dispatch` as follows:
+
+```python
+model = initialize_model(model_dispatch_func=hf_model_dispatch)
+```
+
+You can also modify `huggingface_model/dispatch_utils/__init__.py` to add custom patterns for automatic dispatch.
+
+For the config, you need to set the ISP size as follows:
+
+```python
+parallel = dict(
+    zero1=dict(size=-1),
+    tensor=dict(size=2, mode="isp"),
+    pipeline=dict(size=1, interleaved_overlap=True),
+    weight=dict(size=1, overlap=False, memory_pool=True),
+)
+```
+
+- Set the `tensor` size and mode for ISP.
+
+### Manual code adaptation dispatch
+
+T.B.A.
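+
+Until this section is filled in, here is a minimal sketch of the manual approach, shown only for illustration: replace the HuggingFace modules with their InternEvo counterparts directly in the modeling code, mirroring what `hf_model_dispatch` does automatically. The calls below (`new_linear`, `Embedding1D`) and the parameter names ("wq", "wk", "wv", "wo", "head") follow the mappings in `huggingface_model/dispatch_utils/__init__.py`; adapt the attribute names to your own model.
+
+```python
+from internlm.model.modules.embedding import Embedding1D
+from internlm.model.modules.linear import new_linear
+
+# In the attention module: swap nn.Linear for ISP-aware linears.
+self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+self.k_proj = new_linear("wk", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+self.v_proj = new_linear("wv", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+self.o_proj = new_linear("wo", self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
+
+# In the top-level model: swap the embedding and the output head.
+self.embed_tokens = Embedding1D(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=self.padding_idx)
+self.lm_head = new_linear("head", config.hidden_size, config.vocab_size, bias=False)
+```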
\ No newline at end of file
diff --git a/huggingface_model/dispatch_utils/__init__.py b/huggingface_model/dispatch_utils/__init__.py
new file mode 100644
index 0000000..5f1b3f7
--- /dev/null
+++ b/huggingface_model/dispatch_utils/__init__.py
@@ -0,0 +1,287 @@
+# adapted from https://github.com/InternLM/xtuner/blob/main/xtuner/model/modules/dispatch/__init__.py
+
+import importlib
+from collections import abc
+from typing import Any, Optional, Type, Union
+
+# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L8
+class LazyObject:
+    """LazyObject is used to lazily initialize the imported module during
+    parsing the configuration file.
+    During parsing process, the syntax like:
+    Examples:
+        >>> import torch.nn as nn
+        >>> from mmdet.models import RetinaNet
+        >>> import mmcls.models
+        >>> import mmcls.datasets
+        >>> import mmcls
+    Will be parsed as:
+    Examples:
+        >>> # import torch.nn as nn
+        >>> nn = lazyObject('torch.nn')
+        >>> # from mmdet.models import RetinaNet
+        >>> RetinaNet = lazyObject('mmdet.models', 'RetinaNet')
+        >>> # import mmcls.models; import mmcls.datasets; import mmcls
+        >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models'])
+    ``LazyObject`` records all module information and will be further
+    referenced by the configuration file.
+    Args:
+        module (str or list or tuple): The module name to be imported.
+        imported (str, optional): The imported module name. Defaults to None.
+        location (str, optional): The filename and line number of the imported
+            module statement happened.
+    """
+
+    def __init__(self, module: Union[str, list, tuple], imported: Optional[str] = None, location: Optional[str] = None):
+        if not isinstance(module, str) and not is_seq_of(module, str):
+            raise TypeError(
+                "module should be `str`, `list`, or `tuple`"
+                f"but got {type(module)}, this might be "
+                "a bug of MMEngine, please report it to "
+                "https://github.com/open-mmlab/mmengine/issues"
+            )
+        self._module: Union[str, list, tuple] = module
+
+        if not isinstance(imported, str) and imported is not None:
+            raise TypeError(
+                "imported should be `str` or None, but got "
+                f"{type(imported)}, this might be "
+                "a bug of MMEngine, please report it to "
+                "https://github.com/open-mmlab/mmengine/issues"
+            )
+        self._imported = imported
+        self.location = location
+
+    def build(self) -> Any:
+        if isinstance(self._module, str):
+            try:
+                module = importlib.import_module(self._module)
+            except Exception as e:
+                raise type(e)(f"Failed to import {self._module} " f"in {self.location} for {e}")
+
+            if self._imported is not None:
+                if hasattr(module, self._imported):
+                    module = getattr(module, self._imported)
+                else:
+                    raise ImportError(f"Failed to import {self._imported} " f"from {self._module} in {self.location}")
+
+            return module
+        else:
+            try:
+                for module in self._module:
+                    importlib.import_module(module)  # type: ignore
+                module_name = self._module[0].split(".")[0]
+                return importlib.import_module(module_name)
+            except Exception as e:
+                raise type(e)(f"Failed to import {self.module} " f"in {self.location} for {e}")
+
+    @property
+    def module(self):
+        if isinstance(self._module, str):
+            return self._module
+        return self._module[0].split(".")[0]
+
+    def __call__(self, *args, **kwargs):
+        raise RuntimeError()
+
+    def __deepcopy__(self, memo):
+        return LazyObject(self._module, self._imported, self.location)
+
+    def __getattr__(self, name):
+        if self.location is not None:
+            location = self.location.split(", line")[0]
+        else:
+            location = self.location
+        return LazyAttr(name, self, location)
+
+    def __str__(self) -> str:
+        if self._imported is not None:
+            return self._imported
+        return self.module
+
+    __repr__ = __str__
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+
+# adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/config/lazy.py#L135
+class LazyAttr:
+    """The attribute of the LazyObject.
+    When parsing the configuration file, the imported syntax will be
+    parsed as the assignment ``LazyObject``. During the subsequent parsing
+    process, users may reference the attributes of the LazyObject.
+    To ensure that these attributes also contain information needed to
+    reconstruct the attribute itself, LazyAttr was introduced.
+    Examples:
+        >>> models = LazyObject(['mmdet.models'])
+        >>> model = dict(type=models.RetinaNet)
+        >>> print(type(model['type']))  #
+        >>> print(model['type'].build())  #
+    """  # noqa: E501
+
+    def __init__(self, name: str, source: Union["LazyObject", "LazyAttr"], location=None):
+        self.name = name
+        self.source: Union[LazyAttr, LazyObject] = source
+
+        if isinstance(self.source, LazyObject):
+            if isinstance(self.source._module, str):
+                if self.source._imported is None:
+                    self._module = self.source._module
+                else:
+                    self._module = f"{self.source._module}.{self.source}"
+            else:
+                self._module = str(self.source)
+        elif isinstance(self.source, LazyAttr):
+            self._module = f"{self.source._module}.{self.source.name}"
+        self.location = location
+
+    @property
+    def module(self):
+        return self._module
+
+    def __call__(self, *args, **kwargs: Any) -> Any:
+        raise RuntimeError()
+
+    def __getattr__(self, name: str) -> "LazyAttr":
+        return LazyAttr(name, self)
+
+    def __deepcopy__(self, memo):
+        return LazyAttr(self.name, self.source)
+
+    def build(self) -> Any:
+        obj = self.source.build()
+        try:
+            return getattr(obj, self.name)
+        except AttributeError:
+            raise ImportError(f"Failed to import {self.module}.{self.name} in " f"{self.location}")
+        except ImportError as e:
+            raise e
+
+    def __str__(self) -> str:
+        return self.name
+
+    __repr__ = __str__
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+
+
+# adapt from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/misc.py#L132
+def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None) -> bool:
+    if seq_type is None:
+        exp_seq_type = abc.Sequence
+    else:
+        assert isinstance(seq_type, type)
+        exp_seq_type = seq_type
+    if not isinstance(seq, exp_seq_type):
+        return False
+    for item in seq:
+        if not isinstance(item, expected_type):
+            return False
+    return True
+
+EMBED_REPLACE_MAPPING = dict(
+    Embedding=LazyObject("internlm.model.modules.embedding", "Embedding1D"),
+)
+
+NORM_REPLACE_MAPPING = dict(
+    InternLMRMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
+    InternLM2RMSNorm=LazyObject("internlm.model.modules.norm", "new_layer_norm"),
+)
+
+LINEAR_REPLACE_MAPPING = dict(
+    Linear=LazyObject("internlm.model.modules.linear", "new_linear"),
+)
+
+NORM2NEW_NORM_NAME_MAPPING = dict(
+    input_layernorm="rmsnorm",
+    post_attention_layernorm="rmsnorm",
+    norm="rmsnorm",
+    attention_norm="rmsnorm",
+    ffn_norm="rmsnorm",
+)
+
+LINEAR2NEW_LINEAR_NAME_MAPPING = dict(
+    q_proj="wq",
+    k_proj="wk",
+    v_proj="wv",
+    o_proj="wo",
+    gate_proj="w1",
+    down_proj="w2",
+    up_proj="w3",
+    lm_head="head",
+)
+
+
+def replace_embed(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in EMBED_REPLACE_MAPPING:
+                embed = EMBED_REPLACE_MAPPING[cls_name]
+                embed = embed.build()
+                child_new = embed(
+                    num_embeddings=child.num_embeddings,
+                    embedding_dim=child.embedding_dim,
+                    padding_idx=child.padding_idx,
+                ).to(device=child.weight.device, dtype=child.weight.dtype)
+                setattr(module, name, child_new)
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
+def replace_norm(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in NORM_REPLACE_MAPPING:
+                norm = NORM_REPLACE_MAPPING[cls_name]
+                norm = norm.build()
+                child_new = norm(
+                    norm_type=NORM2NEW_NORM_NAME_MAPPING[name],
+                    normalized_shape=child.weight.shape,
+                    eps=child.variance_epsilon,
+                ).to(device=child.weight.device, dtype=child.weight.dtype)
+                setattr(module, name, child_new)
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
+def replace_linear(model):
+    def traverse(module):
+        for name, child in module.named_children():
+            cls_name = type(child).__name__
+            if cls_name in LINEAR_REPLACE_MAPPING:
+                linear = LINEAR_REPLACE_MAPPING[cls_name]
+                linear = linear.build()
+                child_new = linear(
+                    name=LINEAR2NEW_LINEAR_NAME_MAPPING.get(name, name),
+                    in_features=child.in_features,
+                    out_features=child.out_features,
+                    bias=child.bias is not None,
+                ).to(device=child.weight.device, dtype=child.weight.dtype)
+                setattr(module, name, child_new)
+            else:
+                traverse(child)
+
+    traverse(model)
+
+
+def hf_model_dispatch(model):
+    replace_embed(model)
+    replace_norm(model)
+    replace_linear(model)
+
+
+__all__ = ["hf_model_dispatch"]
\ No newline at end of file
diff --git a/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py b/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py
index 774d47f..23b0f32 100644
--- a/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py
+++ b/huggingface_model/internlm/internlm2_7b/modeling_internlm2.py
@@ -46,6 +46,7 @@
 )
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func
 
 try:
     from transformers.generation.streamers import BaseStreamer
@@ -78,7 +79,6 @@ def _get_unpad_data(attention_mask):
         max_seqlen_in_batch,
     )
 
-
 class InternLM2RMSNorm(nn.Module):
     """InternLM2RMSNorm is equivalent to T5LayerNorm."""
 
@@ -485,22 +485,18 @@ def forward(
 #        )
 
         if use_packed_dataset:
-            attn_output = flash_attn_varlen_func(
-                query_states.flatten(0, 1),
-                key_states.flatten(0, 1),
-                value_states.flatten(0, 1),
-                cu_seqlens,
-                cu_seqlens,
-                max_seqlen,
-                max_seqlen,
-                dropout_rate,
-                softmax_scale=None,
+            attn_output = isp_flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
                 causal=True,
-                return_attn_probs=False,
-            ).unsqueeze(0)
+                attention_dropout=dropout_rate,
+            )
         else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout_rate, softmax_scale=None, causal=True, return_attn_probs=False,
+            attn_output = isp_flash_attn_func(
+                query_states, key_states, value_states, causal=True, attention_dropout=dropout_rate,
             )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
diff --git a/huggingface_model/internlm/internlm_7b/modeling_internlm.py b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
index 7b41d93..ada307b 100644
--- a/huggingface_model/internlm/internlm_7b/modeling_internlm.py
+++ b/huggingface_model/internlm/internlm_7b/modeling_internlm.py
@@ -19,7 +19,6 @@
 import threading
 from typing import List, Optional, Tuple, Union
 
-from internlm.core.context.parallel_context import IS_REPLICA_ZERO_PARALLEL
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -41,9 +40,8 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.ops.rotary_emb import apply_rotary_emb
-from internlm.model.modules.embedding import Embedding1D
-from internlm.model.modules.linear import new_linear
-from internlm.model.ops.attention import hf_q_k_v_with_cu_seqlens, hf_q_k_v_without_cu_seqlens
+from internlm.model.ops.attention import isp_flash_attn_varlen_func, isp_flash_attn_func
+
 try:
     from transformers.generation.streamers import BaseStreamer
 
@@ -126,8 +124,6 @@ def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        for param in self.parameters():
-            setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
 
     def forward(self, hidden_states):
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -281,9 +277,9 @@
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_proj = new_linear("w1", hidden_size, intermediate_size, bias=False)
-        self.down_proj = new_linear("w2", intermediate_size, hidden_size, bias=False)
-        self.up_proj = new_linear("w3", hidden_size, intermediate_size, bias=False)
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
         self.act_fn = ACT2FN[hidden_act]
 
     def forward(self, x):
@@ -307,10 +303,10 @@ def __init__(self, config: InternLMConfig):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = new_linear("wq", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = new_linear("wk", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.v_proj = new_linear("wv", self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.o_proj = new_linear("wo", self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()
         self.is_causal = True
 
@@ -487,17 +483,17 @@ def forward(
 #        )
 
         if use_packed_dataset:
-            attn_output = hf_q_k_v_with_cu_seqlens(
+            attn_output = isp_flash_attn_varlen_func(
                 query_states,
                 key_states,
                 value_states,
-                cumulative_len=cu_seqlens,
+                cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
-                dropout_p=0.0,
+                causal=True,
             )
         else:
-            attn_output = hf_q_k_v_without_cu_seqlens(
-                query_states, key_states, value_states, dropout_p=0.0, softmax_scale=None, causal=True,
+            attn_output = isp_flash_attn_func(
+                query_states, key_states, value_states, causal=True,
             )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -710,7 +706,7 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, Embedding1D):
+        elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
@@ -794,8 +790,8 @@ def __init__(self, config: InternLMConfig):
         self.vocab_size = config.vocab_size
         self.config = config
 
-        self.embed_tokens = Embedding1D(num_embeddings=config.vocab_size, embedding_dim=config.hidden_size, padding_idx=self.padding_idx)
-
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
         self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -975,7 +971,7 @@ def __init__(self, config):
         super().__init__(config)
         self.model = InternLMModel(config)
 
-        self.lm_head = new_linear("head", config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1371,4 +1367,4 @@ def forward(
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
\ No newline at end of file