Enable TP splitting on seq_len
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
jikunshang authored and xuechendi committed Dec 16, 2024
1 parent 2576619 commit d6bdc90
Showing 3 changed files with 99 additions and 6 deletions.
12 changes: 12 additions & 0 deletions vllm/distributed/utils.py
@@ -57,6 +57,18 @@ def split_tensor_along_last_dim(

return tensor_list

def split_tensor_along_x_dim(
tensor: torch.Tensor,
dim: int,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
"""Split a tensor into `num_partitions` equal chunks along dimension `dim`."""
dim_size = divide(tensor.size()[dim], num_partitions)
tensor_list = torch.split(tensor, dim_size, dim=dim)
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]:
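For reference, a standalone sketch (not part of the commit) of what the new helper does: split a tensor into equal chunks along a chosen dim, such that concatenating the chunks along that dim restores the original tensor. The shapes here are illustrative.

import torch

x = torch.randn(2, 8, 16)                  # [batch_size, seq_len, hidden_size]
chunk_size = x.size(1) // 2                # mirrors split_tensor_along_x_dim(x, dim=1, num_partitions=2)
chunks = torch.split(x, chunk_size, dim=1) # two chunks of shape [2, 4, 16]
assert torch.equal(torch.cat(chunks, dim=1), x)
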
76 changes: 70 additions & 6 deletions vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
split_tensor_along_x_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -996,13 +997,20 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
do_split: bool = False,  # should be enabled for down_proj, disabled for o_proj
split_threshold: int = 128,
split_size: int = 2):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
self.collective_func = tensor_model_parallel_all_reduce
self.do_split = do_split
self.split_threshold = split_threshold
self.split_size = split_size
self.prefix = prefix

# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
@@ -1099,13 +1107,69 @@ def forward(self, input_):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,

# input_parallel shape: [batch_size, seq_len, hidden_size // tp_size]

# split v2:
# Strategy: split the input tensor along dim 1 (the seq_len dim), but only when
# seq_len exceeds a threshold; otherwise do not split. In particular, the decode
# phase is never split.
# Why split on dim 1:
# dim 0 is the batch size; when batch size = 1 it cannot be split anyway.
# dim 2 (hidden_size) is already split by TP; splitting it again would require far larger changes.

_, seq_len, _ = input_parallel.shape
shape_total = input_parallel.shape[0] * input_parallel.shape[1] * input_parallel.shape[2]
do_split = self.do_split and seq_len > 1  # never split during decode
# NOTE: splitting a tensor that is too small does not help performance.
# 1 * 1024 * 4096 * 3 corresponds to [batch_size, seq_len, hidden_size * 3].
do_split = do_split and shape_total > 1 * 1024 * 4096 * 3

if do_split:
input_parallels = split_tensor_along_x_dim(input_parallel, 1, self.split_size)
output_parallels = []
for input_parallel in input_parallels:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_parallels.append(output)
output = torch.cat(output_parallels, dim=1)

else:
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel


# split v1 (kept for reference; superseded by v2 above):
# Why split on dim 0 (batch):
# dim 1 (seq_len): in the decode phase seq_len is always 1, so this dim cannot be split.
# dim 2 (hidden_size): already split by TP; splitting it again would require far larger changes.
# Other limitations & FIXME: VLLM_DECODE_BS_BUCKET_MIN=2 and VLLM_PROMPT_BS_BUCKET_MIN=2 must be set,
# otherwise the batch dim cannot be divided evenly for the split.
# Overheads:
# 1. The split itself.
# 2. The append may add some overhead; it is unclear whether the output tensor needs to be ready at that point.
# 3. The final cat. Some optimization is possible here, but some copying is probably unavoidable.
# split = 2
# input_parallels = split_tensor_along_x_dim(input_parallel, 0, split)
# output_parallels = []
# for input_parallel in input_parallels:
# output_parallel = self.quant_method.apply(self,
# input_parallel,
# bias=bias_)
# if self.reduce_results and self.tp_size > 1:
# output = tensor_model_parallel_all_reduce(output_parallel)
# else:
# output = output_parallel
# output_parallels.append(output)
# output = torch.cat(output_parallels, dim=0)

output_bias = self.bias if self.skip_bias_add else None

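A minimal single-process sketch (not part of the commit, no tensor parallelism) of why splitting on seq_len is safe: the projection is applied independently to every sequence position, so running it per chunk and concatenating along dim 1 reproduces the unsplit output. Shapes and the split factor are illustrative.

import torch
import torch.nn.functional as F

batch_size, seq_len, hidden_size = 1, 1024, 4096
split_size = 2
x = torch.randn(batch_size, seq_len, hidden_size)
w = torch.randn(hidden_size, hidden_size)      # [out_features, in_features]

reference = F.linear(x, w)                      # unsplit path

chunks = torch.split(x, seq_len // split_size, dim=1)
split_output = torch.cat([F.linear(c, w) for c in chunks], dim=1)

assert torch.allclose(reference, split_output, rtol=1e-4, atol=1e-4)
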
17 changes: 17 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -22,6 +22,7 @@
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import os

import torch
from torch import nn
@@ -70,6 +71,8 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
do_split: bool = False,
split_size: int = 2
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -85,6 +88,8 @@ def __init__(
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
do_split=do_split,
split_size=split_size
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -113,6 +118,8 @@ def __init__(
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
do_split: bool = False,
split_size: int = 2
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
@@ -156,6 +163,8 @@ def __init__(
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
do_split=do_split,
split_size=split_size,
)

is_neox_style = True
@@ -231,6 +240,10 @@ def __init__(
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)

split_size = int(os.environ.get('VLLM_TP_SPLIT_SIZE_BY_SEQ', '1'))
enable_o_proj_split = int(os.environ.get('VLLM_TP_O_PROJ_SPLIT_ENABLE', '1')) == 1
do_split = split_size > 1
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
@@ -244,6 +257,8 @@
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
do_split=do_split and enable_o_proj_split,
split_size=split_size,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
@@ -252,6 +267,8 @@
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
do_split=do_split,
split_size=split_size
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
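The feature is controlled entirely through environment variables that are read when the decoder layers are constructed. A hypothetical launch sketch (the model name and TP size are illustrative, not taken from the commit):

import os

# Must be set before the model is built, since the layers read them in __init__.
os.environ["VLLM_TP_SPLIT_SIZE_BY_SEQ"] = "2"    # >1 enables splitting; also the number of seq_len chunks
os.environ["VLLM_TP_O_PROJ_SPLIT_ENABLE"] = "1"  # set to "0" to restrict splitting to down_proj only

from vllm import LLM

llm = LLM(model="meta-llama/Llama-2-7b-hf", tensor_parallel_size=2)
print(llm.generate("Hello, my name is")[0].outputs[0].text)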
