From d6bdc90c4742fc5b3f365583b44e11d6aad6fdea Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 5 Dec 2024 14:53:32 +0800
Subject: [PATCH] Enable TP splitting on seq_len

Signed-off-by: Kunshang Ji
Signed-off-by: Chendi Xue
---
 vllm/distributed/utils.py            | 12 +++++
 vllm/model_executor/layers/linear.py | 76 +++++++++++++++++++++++++---
 vllm/model_executor/models/llama.py  | 17 +++++++
 3 files changed, 99 insertions(+), 6 deletions(-)

diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index dcfcb848cbe06..c20e05f2dd851 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -57,6 +57,18 @@ def split_tensor_along_last_dim(
     return tensor_list
 
 
+def split_tensor_along_x_dim(
+    tensor: torch.Tensor,
+    dim: int,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    dim_size = divide(tensor.size()[dim], num_partitions)
+    tensor_list = torch.split(tensor, dim_size, dim=dim)
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+    return tensor_list
+
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 38d33809255e4..7ad660044c29b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
+                              split_tensor_along_x_dim,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
@@ -996,13 +997,20 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 do_split: bool = False,  # should be enabled for down_proj, disabled for o_proj
+                 split_threshold: int = 128,
+                 split_size: int = 2):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
         self.collective_func = tensor_model_parallel_all_reduce
+        self.do_split = do_split
+        self.split_threshold = split_threshold
+        self.split_size = split_size
+        self.prefix = prefix
 
         # Divide the weight matrix along the last dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -1099,13 +1107,69 @@ def forward(self, input_):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self,
+
+        # input_parallel.shape is [batch_size, seq_len, hidden_size // tp_size]
+
+        # split v2:
+        # strategy: split the input tensor along dim 1 (the seq_len dim), but only when
+        # seq_len exceeds a threshold; otherwise do not split, so the decode phase is never split.
+        # why split on dim 1:
+        # dim 0 is batch_size; when batch_size == 1 there is nothing to split.
+        # dim 2 (hidden_size) is already split by TP; splitting it again would require much larger changes.
+
+        _, seq_len, _ = input_parallel.shape
+        shape_total = input_parallel.shape[0] * input_parallel.shape[1] * input_parallel.shape[2]
+        do_split = self.do_split and seq_len > 1  # never split decode
+        # NOTE: we found that splitting a tensor that is too small does not help performance.
+        # 1 * 1024 * 4096 * 3 corresponds to [batch_size, seq_len, hidden_size * 3].
+        do_split = do_split and shape_total > 1 * 1024 * 4096 * 3
+
+        if do_split:
+            input_parallels = split_tensor_along_x_dim(input_parallel, 1, self.split_size)
+            output_parallels = []
+            for input_parallel in input_parallels:
+                output_parallel = self.quant_method.apply(self,
+                                                          input_parallel,
+                                                          bias=bias_)
+                if self.reduce_results and self.tp_size > 1:
+                    output = tensor_model_parallel_all_reduce(output_parallel)
+                else:
+                    output = output_parallel
+                output_parallels.append(output)
+            output = torch.cat(output_parallels, dim=1)
+
+        else:
+            output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
-        if self.reduce_results and self.tp_size > 1:
-            output = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output = output_parallel
+            if self.reduce_results and self.tp_size > 1:
+                output = tensor_model_parallel_all_reduce(output_parallel)
+            else:
+                output = output_parallel
+
+
+        # split v1:
+        # why split on dim 0:
+        # dim 1 (seq_len): in the decode phase seq_len is always 1, so we cannot split on this dim.
+        # dim 2 (hidden_size): TP already splits this dim; splitting it again would require much larger changes.
+        # Other limitations & FIXME: VLLM_DECODE_BS_BUCKET_MIN=2 and VLLM_PROMPT_BS_BUCKET_MIN=2 must be set, otherwise the batch dim cannot be divided and split.
+        # Overheads:
+        # 1. split overhead.
+        # 2. append may add some overhead; it is unclear whether the output tensor needs to be ready.
+        # 3. tensor cat overhead; some optimization is possible here, but some copying is likely unavoidable.
+        # split = 2
+        # input_parallels = split_tensor_along_x_dim(input_parallel, 0, split)
+        # output_parallels = []
+        # for input_parallel in input_parallels:
+        #     output_parallel = self.quant_method.apply(self,
+        #                                               input_parallel,
+        #                                               bias=bias_)
+        #     if self.reduce_results and self.tp_size > 1:
+        #         output = tensor_model_parallel_all_reduce(output_parallel)
+        #     else:
+        #         output = output_parallel
+        #     output_parallels.append(output)
+        # output = torch.cat(output_parallels, dim=0)
 
         output_bias = self.bias if self.skip_bias_add else None
 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 6461a80cef331..3a263e59f31e6 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -22,6 +22,7 @@
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+import os
 
 import torch
 from torch import nn
@@ -70,6 +71,8 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        do_split: bool = False,
+        split_size: int = 2,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -85,6 +88,8 @@
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
+            do_split=do_split,
+            split_size=split_size,
         )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -113,6 +118,8 @@ def __init__(
         bias: bool = False,
         cache_config: Optional[CacheConfig] = None,
         prefix: str = "",
+        do_split: bool = False,
+        split_size: int = 2,
    ) -> None:
         super().__init__()
         layer_idx = extract_layer_index(prefix)
@@ -156,6 +163,8 @@
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
+            do_split=do_split,
+            split_size=split_size,
         )
 
         is_neox_style = True
@@ -231,6 +240,10 @@
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
+
+        split_size = int(os.environ.get("VLLM_TP_SPLIT_SIZE_BY_SEQ", "1"))
+        enable_o_proj_split = int(os.environ.get("VLLM_TP_O_PROJ_SPLIT_ENABLE", "1")) == 1
+        do_split = split_size > 1
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
@@ -244,6 +257,8 @@
             bias=attention_bias,
             cache_config=cache_config,
             prefix=f"{prefix}.self_attn",
+            do_split=do_split and enable_o_proj_split,
+            split_size=split_size,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
@@ -252,6 +267,8 @@
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
             prefix=f"{prefix}.mlp",
+            do_split=do_split,
+            split_size=split_size,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)