fix slow sampling when repetition_penalty is set.
ccrhx4 authored and michalkuligowski committed Jan 7, 2025
1 parent 2d24be7 commit 2ab9600
Showing 3 changed files with 88 additions and 16 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
@@ -528,7 +528,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                               output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
+    repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)

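Why this change helps: `tensor[mask] = value` goes through boolean advanced indexing, which materializes the mask's nonzero indices and yields a dynamically shaped intermediate; on accelerators that compile graphs per tensor shape (such as HPU) this is likely to trigger recompilation or host synchronization on every sampling step. `masked_fill_` is a single in-place elementwise kernel with static shapes. A minimal sketch of the equivalence, with toy sizes and illustrative mask contents (not taken from real sequences):

import torch

num_seqs, vocab_size = 2, 8
repetition_penalties = torch.tensor([1.2, 1.5])[:, None].repeat(1, vocab_size)
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
prompt_mask[0, 3] = True
output_mask[1, 5] = True

# Old path: boolean-index assignment (dynamic number of indices).
a = repetition_penalties.clone()
a[~(prompt_mask | output_mask)] = 1.0

# New path: in-place masked_fill_ (static shape, single kernel).
b = repetition_penalties.clone()
b.masked_fill_(~(prompt_mask | output_mask), 1.0)

assert torch.equal(a, b)  # identical results, cheaper execution path
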
49 changes: 34 additions & 15 deletions vllm/model_executor/sampling_metadata.py
@@ -9,7 +9,8 @@
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                            SequenceGroupMetadata)
 from vllm.utils import (PyObjectCache, async_tensor_h2d,
-                        is_pin_memory_available, make_tensor_with_pad)
+                        is_pin_memory_available, make_tensor_with_pad,
+                        make_tensor_with_pad_align)
 
 _SAMPLING_EPS = 1e-5

@@ -523,20 +524,38 @@ def from_lists(
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            prompt_t = make_tensor_with_pad(
-                prompt_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
-            output_t = make_tensor_with_pad(
-                output_tokens,
-                vocab_size,
-                device="cpu",
-                dtype=torch.int64,
-                pin_memory=pin_memory,
-            )
+            if current_platform.is_hpu():
+                prompt_t = make_tensor_with_pad_align(
+                    prompt_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
+                output_t = make_tensor_with_pad_align(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                    max_len_align=1024,
+                )
+            else:
+                prompt_t = make_tensor_with_pad(
+                    prompt_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                )
+                output_t = make_tensor_with_pad(
+                    output_tokens,
+                    vocab_size,
+                    device="cpu",
+                    dtype=torch.int64,
+                    pin_memory=pin_memory,
+                )
         else:
             empty_tensor = torch.empty(0, device=device, dtype=torch.long)
             prompt_t = empty_tensor
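On HPU, the token tensors are padded to a multiple of 1024 columns, presumably so that the number of distinct tensor shapes the device sees stays small: every max_len in (0, 1024] maps to one padded shape, (1024, 2048] to the next, and so on, avoiding repeated graph recompilation as sequence lengths grow. A minimal sketch of the rounding rule (`aligned_len` is a hypothetical helper for illustration, not part of this diff):

import math

def aligned_len(max_len: int, max_len_align: int = 1024) -> int:
    # Round up to the nearest multiple of max_len_align (0 stays 0).
    return math.ceil(max_len / max_len_align) * max_len_align

# All lengths in (0, 1024] share one padded shape; (1024, 2048] the next.
assert aligned_len(1) == 1024
assert aligned_len(1024) == 1024
assert aligned_len(1025) == 2048
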
53 changes: 53 additions & 0 deletions vllm/utils.py
@@ -9,6 +9,7 @@
 import importlib.util
 import inspect
 import ipaddress
+import math
 import os
 import signal
 import socket
@@ -822,6 +823,30 @@ def make_ndarray_with_pad(
     return padded_x
 
 
+def make_ndarray_with_pad_align(
+    x: List[List[T]],
+    pad: T,
+    dtype: npt.DTypeLike,
+    *,
+    max_len_align: int = 1024,
+) -> npt.NDArray:
+    """
+    Make a padded array from 2D inputs.
+    The padding is applied to the end of each inner list until it reaches
+    `max_len_aligned`, i.e. `max_len` rounded up to the nearest multiple
+    of `max_len_align`.
+    """
+    # Unlike for most functions, map is faster than a genexpr over `len`
+    max_len = max(map(len, x), default=0)
+    max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align
+    padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype)
+
+    for ind, blocktb in enumerate(x):
+        assert len(blocktb) <= max_len_aligned
+        padded_x[ind, :len(blocktb)] = blocktb
+
+    return padded_x
+
+
 def make_tensor_with_pad(
     x: List[List[T]],
     pad: T,
@@ -847,6 +872,34 @@ def make_tensor_with_pad(
     return tensor
 
 
+def make_tensor_with_pad_align(
+    x: List[List[T]],
+    pad: T,
+    dtype: torch.dtype,
+    *,
+    max_len_align: int = 1024,
+    device: Optional[Union[str, torch.device]] = None,
+    pin_memory: bool = False,
+) -> torch.Tensor:
+    """
+    Make a padded tensor from 2D inputs.
+    The padding is applied to the end of each inner list until it reaches
+    `max_len_aligned`, where `max_len_aligned` is `max_len` rounded up to
+    the nearest multiple of `max_len_align`.
+    """
+    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
+    padded_x = make_ndarray_with_pad_align(x,
+                                           pad,
+                                           np_dtype,
+                                           max_len_align=max_len_align)
+
+    tensor = torch.from_numpy(padded_x).to(device)
+    if pin_memory:
+        tensor = tensor.pin_memory()
+
+    return tensor
+
+
 def async_tensor_h2d(
     data: list,
     dtype: torch.dtype,
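A usage sketch of the two new helpers, assuming they are importable from vllm.utils as added above (`max_len_align=4` is deliberately small for illustration; the sampler uses 1024):

import numpy as np
import torch

from vllm.utils import (make_ndarray_with_pad_align,
                        make_tensor_with_pad_align)

tokens = [[1, 2, 3], [4, 5]]  # ragged token lists, as in sampling_metadata

arr = make_ndarray_with_pad_align(tokens, 0, np.int64, max_len_align=4)
print(arr.shape)  # (2, 4): max_len 3 rounded up to a multiple of 4

t = make_tensor_with_pad_align(tokens, 0, torch.int64,
                               max_len_align=4, device="cpu")
print(t.shape)    # torch.Size([2, 4])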
