diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b61fb0e47d07a..8aa6646c5dcea 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -528,7 +528,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
         output_tokens_tensor, vocab_size, num_seqs)
 
     repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
-    repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
+    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
     logits = torch.where(logits > 0, logits / repetition_penalties,
                          logits * repetition_penalties)
 
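Note on the sampler.py hunk: `masked_fill_` and boolean-index assignment write the same values wherever the mask is `True`, so this rewrite is behavior-preserving; only the indexing style changes. A standalone PyTorch sketch of the equivalence (the shapes and the random mask are illustrative stand-ins, not taken from vLLM):

```python
import torch

# Stand-ins for the tensors in _apply_penalties: per-sequence penalties
# broadcast over the vocab, and a mask of vocab entries that appeared in
# neither the prompt nor the output, i.e. ~(prompt_mask | output_mask).
penalties = torch.rand(2, 8) + 1.0
unseen = torch.rand(2, 8) > 0.5

a = penalties.clone()
a.masked_fill_(unseen, 1.0)   # old form: in-place fill where mask is True

b = penalties.clone()
b[unseen] = 1.0               # new form: boolean-index assignment

assert torch.equal(a, b)      # element-for-element identical
```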
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 234d1cb304d59..9fda807d29236 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -9,8 +9,7 @@
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
                            SequenceGroupMetadata)
 from vllm.utils import (PyObjectCache, async_tensor_h2d,
-                        is_pin_memory_available, make_tensor_with_pad,
-                        make_tensor_with_pad_align)
+                        is_pin_memory_available, make_tensor_with_pad)
 
 _SAMPLING_EPS = 1e-5
 
@@ -523,38 +522,20 @@
         do_penalties = prompt_tokens or output_tokens
 
         if do_penalties:
-            if current_platform.is_hpu():
-                prompt_t = make_tensor_with_pad_align(
-                    prompt_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                    max_len_align=1024,
-                )
-                output_t = make_tensor_with_pad_align(
-                    output_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                    max_len_align=1024,
-                )
-            else:
-                prompt_t = make_tensor_with_pad(
-                    prompt_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                )
-                output_t = make_tensor_with_pad(
-                    output_tokens,
-                    vocab_size,
-                    device="cpu",
-                    dtype=torch.int64,
-                    pin_memory=pin_memory,
-                )
+            prompt_t = make_tensor_with_pad(
+                prompt_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
+            output_t = make_tensor_with_pad(
+                output_tokens,
+                vocab_size,
+                device="cpu",
+                dtype=torch.int64,
+                pin_memory=pin_memory,
+            )
         else:
             empty_tensor = torch.empty(0, device=device, dtype=torch.long)
             prompt_t = empty_tensor
@@ -564,58 +545,47 @@
             temperatures,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
         )
         top_ps_t = torch.tensor(
             top_ps,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
         )
         min_ps_t = torch.tensor(
             min_ps,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
         )
         presence_penalties_t = torch.tensor(
             presence_penalties,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
         )
         frequency_penalties_t = torch.tensor(
             frequency_penalties,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
        )
         repetition_penalties_t = torch.tensor(
             repetition_penalties,
             device="cpu",
             dtype=dtype,
+            pin_memory=pin_memory,
         )
         top_ks_t = torch.tensor(
             top_ks,
             device="cpu",
             dtype=torch.int,
+            pin_memory=pin_memory,
         )
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
 
-        if pin_memory:
-            if not current_platform.is_hpu():
-                temperatures_t.pin_memory()
-                top_ps_t.pin_memory()
-                min_ps_t.pin_memory()
-                frequency_penalties_t.pin_memory()
-                presence_penalties_t.pin_memory()
-                repetition_penalties_t.pin_memory()
-                top_ks_t.pin_memory()
-            else:
-                temperatures_t.pin_memory(device="hpu")
-                top_ps_t.pin_memory(device="hpu")
-                min_ps_t.pin_memory(device="hpu")
-                frequency_penalties_t.pin_memory(device="hpu")
-                presence_penalties_t.pin_memory(device="hpu")
-                repetition_penalties_t.pin_memory(device="hpu")
-                top_ks_t.pin_memory(device="hpu")
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
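The deleted `if pin_memory:` block was a silent no-op: `Tensor.pin_memory()` is not in-place, it returns a pinned copy, and every return value here was discarded. Passing `pin_memory=pin_memory` to `torch.tensor(...)` instead allocates the staging tensor in pinned memory up front, which is what the `non_blocking=True` copies in `return cls(...)` rely on. A minimal sketch of the difference; it assumes a CUDA-capable host, since pinning raises a `RuntimeError` on accelerator-free builds:

```python
import torch

t = torch.zeros(4)
t.pin_memory()            # returns a pinned copy; discarding it pins nothing
assert not t.is_pinned()  # this was the bug in the removed block

t = t.pin_memory()        # keeping the returned copy does pin the data
assert t.is_pinned()

# What the patch does instead: allocate pinned host memory from the start,
# so the later .to(device, non_blocking=True) is a true async transfer.
t2 = torch.tensor([0.7, 1.0], dtype=torch.float32, pin_memory=True,
                  device="cpu")
assert t2.is_pinned()
```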
- """ - np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] - padded_x = make_ndarray_with_pad_align(x, - pad, - np_dtype, - max_len_align=max_len_align) - - tensor = torch.from_numpy(padded_x).to(device) - if pin_memory: - tensor = tensor.pin_memory("hpu") + tensor = tensor.pin_memory() return tensor @@ -899,13 +843,7 @@ def async_tensor_h2d( pin_memory: bool, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" - t = torch.tensor(data, dtype=dtype, device="cpu") - if pin_memory: - if not current_platform.is_hpu(): - t.pin_memory() - else: - t.pin_memory(device="hpu") - + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True)