Skip to content

Commit

Permalink
Revert "to make repetition penalty faster (#442)"
Browse files Browse the repository at this point in the history
This reverts commit cef2df0.
  • Loading branch information
michalkuligowski authored Dec 2, 2024
1 parent 49c9efa commit 775388f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 121 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs)

repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)

Expand Down
74 changes: 22 additions & 52 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
make_tensor_with_pad_align)
is_pin_memory_available, make_tensor_with_pad)

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -523,38 +522,20 @@ def from_lists(
do_penalties = prompt_tokens or output_tokens

if do_penalties:
if current_platform.is_hpu():
prompt_t = make_tensor_with_pad_align(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
output_t = make_tensor_with_pad_align(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
else:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
Expand All @@ -564,58 +545,47 @@ def from_lists(
temperatures,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
top_ps_t = torch.tensor(
top_ps,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
min_ps_t = torch.tensor(
min_ps,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
presence_penalties_t = torch.tensor(
presence_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
frequency_penalties_t = torch.tensor(
frequency_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
repetition_penalties_t = torch.tensor(
repetition_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
top_ks_t = torch.tensor(
top_ks,
device="cpu",
dtype=torch.int,
pin_memory=pin_memory,
)
# Because the memory is pinned, we can do non-blocking
# transfer to device.

if pin_memory:
if not current_platform.is_hpu():
temperatures_t.pin_memory()
top_ps_t.pin_memory()
min_ps_t.pin_memory()
frequency_penalties_t.pin_memory()
presence_penalties_t.pin_memory()
repetition_penalties_t.pin_memory()
top_ks_t.pin_memory()
else:
temperatures_t.pin_memory(device="hpu")
top_ps_t.pin_memory(device="hpu")
min_ps_t.pin_memory(device="hpu")
frequency_penalties_t.pin_memory(device="hpu")
presence_penalties_t.pin_memory(device="hpu")
repetition_penalties_t.pin_memory(device="hpu")
top_ks_t.pin_memory(device="hpu")

return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
Expand Down
74 changes: 6 additions & 68 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import importlib.util
import inspect
import ipaddress
import math
import os
import socket
import subprocess
Expand Down Expand Up @@ -751,8 +750,10 @@ def is_pin_memory_available() -> bool:
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif (current_platform.is_cpu() or current_platform.is_openvino()
or is_fake_hpu()):
elif current_platform.is_hpu():
print_warning_once("Pin memory is not supported on HPU.")
return False
elif current_platform.is_cpu() or current_platform.is_openvino():
return False
return True

Expand Down Expand Up @@ -810,31 +811,6 @@ def make_ndarray_with_pad(
return padded_x


def make_ndarray_with_pad_align(
x: List[List[T]],
pad: T,
dtype: npt.DTypeLike,
*,
max_len_align: int = 1024,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align
padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype)

for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len_aligned
padded_x[ind, :len(blocktb)] = blocktb

return padded_x


def make_tensor_with_pad(
x: List[List[T]],
pad: T,
Expand All @@ -855,39 +831,7 @@ def make_tensor_with_pad(

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
if not current_platform.is_hpu():
tensor = tensor.pin_memory()
else:
tensor = tensor.pin_memory("hpu")

return tensor


def make_tensor_with_pad_align(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len_align: int = 1024,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
max_len_aligned, max_len_aligned is max_len rounding to the nearest
`max_len_align`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad_align(x,
pad,
np_dtype,
max_len_align=max_len_align)

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory("hpu")
tensor = tensor.pin_memory()

return tensor

Expand All @@ -899,13 +843,7 @@ def async_tensor_h2d(
pin_memory: bool,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, device="cpu")
if pin_memory:
if not current_platform.is_hpu():
t.pin_memory()
else:
t.pin_memory(device="hpu")

t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)


Expand Down

0 comments on commit 775388f

Please sign in to comment.