Skip to content

Commit

Permalink
fix slow sampling when repetition_penalty is set. (#584)
Browse files Browse the repository at this point in the history
This PR is to fix the slow sampling in HPU when repetition_penalty is
set in the sampling parameters.

It replaces the slow pytorch API on HPU and mitigate the dynamic shapes
in the code.

Without this PR:
SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0,
repetition_penalty=1.06, temperature=1.0, top_p=1.0, top_k=-1,
min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[],
include_stop_str_in_output=False, ignore_eos=True, max_tokens=1024,
min_tokens=0, logprobs=None, prompt_logprobs=None,
skip_special_tokens=True, spaces_between_special_tokens=True,
truncate_prompt_tokens=None, guided_decoding=None)
Warming up...
Profiling iterations: 100%|5/5 [03:32<00:00, 42.49s/it]
Avg latency: 42.49439047839987 seconds
10% percentile latency: 11.322476224999628 seconds
25% percentile latency: 11.32563829100036 seconds
50% percentile latency: 11.331052645000455 seconds
75% percentile latency: 11.333669468998778 seconds
90% percentile latency: 104.8302020711999 seconds
99% percentile latency: 160.92812163252054 seconds

With PR:
Avg latency: 11.038154767800005 seconds
10% percentile latency: 10.964674918200398 seconds
25% percentile latency: 10.964709408001 seconds
50% percentile latency: 10.966433088000485 seconds
75% percentile latency: 10.967024742998547 seconds
90% percentile latency: 11.18358270219942 seconds
99% percentile latency: 11.313517477719943 seconds

Testing code:

https://github.com/ccrhx4/huanxing.vllm-fork/blob/slow_repetition_penalty/benchmarks/reproduce.sh

The only difference about this PR and
#442 is that I do not enable
pin_memory as this feature readiness is poor on HPU.
  • Loading branch information
ccrhx4 authored Jan 7, 2025
1 parent 2d24be7 commit 27a22ab
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 16 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs)

repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)

Expand Down
49 changes: 34 additions & 15 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad)
is_pin_memory_available, make_tensor_with_pad,
make_tensor_with_pad_align)

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -523,20 +524,38 @@ def from_lists(
do_penalties = prompt_tokens or output_tokens

if do_penalties:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
if current_platform.is_hpu():
prompt_t = make_tensor_with_pad_align(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
output_t = make_tensor_with_pad_align(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
else:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
Expand Down
53 changes: 53 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import importlib.util
import inspect
import ipaddress
import math
import os
import signal
import socket
Expand Down Expand Up @@ -822,6 +823,30 @@ def make_ndarray_with_pad(
return padded_x


def make_ndarray_with_pad_align(
x: List[List[T]],
pad: T,
dtype: npt.DTypeLike,
*,
max_len_align: int = 1024,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align
padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype)

for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len_aligned
padded_x[ind, :len(blocktb)] = blocktb

return padded_x


def make_tensor_with_pad(
x: List[List[T]],
pad: T,
Expand All @@ -847,6 +872,34 @@ def make_tensor_with_pad(
return tensor


def make_tensor_with_pad_align(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len_align: int = 1024,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
max_len_aligned, max_len_aligned is max_len rounding to the nearest
`max_len_align`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad_align(x,
pad,
np_dtype,
max_len_align=max_len_align)

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()

return tensor


def async_tensor_h2d(
data: list,
dtype: torch.dtype,
Expand Down

0 comments on commit 27a22ab

Please sign in to comment.