diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 93ff84f64f89c..70fe8670a52d4 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -278,17 +278,6 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    if is_prompt:
-        indices = indices.unflatten(0, (-1, block_size))[:, 0]
-        offsets = None
-    else:
-        offsets = torch.fmod(slot_mapping, block_size)
-    return indices, offsets
-
-
 class HpuModelAdapter:
 
     def __init__(self, model, block_size, dtype, enforce_eager):
@@ -382,6 +371,18 @@ def _set_block_scales(self, metadata, device):
         metadata = metadata._replace(block_scales=block_scales)
         return metadata
 
+    def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
+        slot_mapping = metadata.slot_mapping.flatten()
+        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+        if is_prompt:
+            indices = indices.unflatten(0, (-1, block_size))[:, 0]
+            offsets = None
+        else:
+            offsets = torch.fmod(slot_mapping, block_size)
+        metadata = metadata._replace(block_offsets=offsets,
+                                     block_indices=indices)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
@@ -391,6 +392,9 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
             attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
                                                     device, dtype)
         attn_metadata = self._set_block_scales(attn_metadata, device)
+        attn_metadata = self._set_indices_and_offsets(attn_metadata,
+                                                      self.block_size,
+                                                      attn_metadata.is_prompt)
         return attn_metadata
 
     def forward(self, *args, **kwargs):
@@ -956,7 +960,7 @@ def _prepare_prompt(
 
             prefix_block_list_tensor = torch.tensor(prefix_block_list,
                                                     dtype=torch.long,
-                                                    device=self.device)
+                                                    device='cpu')
         else:
             prefix_block_list_tensor = None
 
@@ -964,37 +968,48 @@
                                             max_len=max_prompt_len,
                                             pad=0,
                                             dtype=torch.long,
-                                            device=self.device)
+                                            device='cpu')
 
         input_positions = make_tensor_with_pad(input_positions,
                                                max_len=max_prompt_len,
                                                pad=0,
                                                dtype=torch.long,
-                                               device=self.device)
+                                               device='cpu')
 
         slot_mapping = make_tensor_with_pad(slot_mapping,
                                             max_len=max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
-                                            device=self.device)
+                                            device='cpu')
 
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.long,
-                                       device=self.device)
+                                       device='cpu')
 
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.long,
-                                           device=self.device)
+                                           device='cpu')
+
+        if prefix_block_list_tensor:
+            prefix_block_list_tensor = prefix_block_list_tensor.to(
+                self.device, non_blocking=True)
+        input_tokens = input_tokens.to(  # type: ignore
+            self.device, non_blocking=True)
+        input_positions = input_positions.to(  # type: ignore
+            self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(  # type: ignore
+            self.device, non_blocking=True)
+        seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True)
+        context_lens_tensor = context_lens_tensor.to(self.device,
+                                                     non_blocking=True)
 
-        block_indices, block_offsets = precompute_indices_and_offsets(
-            self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
-            block_indices=block_indices,
-            block_offsets=block_offsets,
+            block_indices=None,
+            block_offsets=None,
             block_scales=None,
             block_groups=None,
             attn_bias=None,
@@ -1093,14 +1108,14 @@ def _prepare_decode(
         if output is None:
             input_tokens = torch.tensor(input_tokens,
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device='cpu')
         else:
             real_batch_size = len(seq_group_metadata_list)
             input_tokens = output[:real_batch_size]
 
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
-                                       device=self.device)
+                                       device='cpu')
 
         num_decode_tokens = sum(seq_lens)
 
@@ -1142,29 +1157,37 @@
         block_groups = padding_fn(block_groups, -1)
         block_usage = padding_fn(block_usage, 1)
-        block_list = torch.tensor(block_list,
-                                  dtype=torch.int,
-                                  device=self.device)
+        block_list = torch.tensor(block_list, dtype=torch.int, device='cpu')
         block_groups = torch.tensor(block_groups,
                                     dtype=torch.int,
-                                    device=self.device)
+                                    device='cpu')
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
-                                   device=self.device)
+                                   device='cpu')
 
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
-                                    device=self.device)
-
-        block_indices, block_offsets = precompute_indices_and_offsets(
-            self.block_size, slot_mapping, False)
+                                    device='cpu')
+
+        input_tokens = input_tokens.to(  # type: ignore
+            self.device, non_blocking=True)
+        input_positions = input_positions.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_list = block_list.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_groups = block_groups.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_usage = block_usage.to(  # type: ignore
+            self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(  # type: ignore
+            self.device, non_blocking=True)
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
             block_mapping=None,
             block_usage=block_usage,
-            block_indices=block_indices,
-            block_offsets=block_offsets,
+            block_indices=None,
+            block_offsets=None,
             block_scales=None,
             block_groups=block_groups,
             attn_bias=None,
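
Taken together, the patch stages every input tensor on the CPU, moves the batch to the device with non_blocking=True copies, and defers the block index/offset computation into HpuModelAdapter so it runs in _update_metadata rather than during input preparation. The sketch below is a minimal, self-contained illustration of those two patterns; it is not part of the diff, and the block_size, slot-mapping values, and target device are assumptions chosen so it runs anywhere.

import torch

block_size = 4
device = 'cpu'  # stand-in for self.device ('hpu') so the sketch runs anywhere

# Stage on CPU first, as _prepare_prompt/_prepare_decode now do.
slot_mapping = torch.tensor([[0, 1, 2, 3], [8, 9, 10, 11]],
                            dtype=torch.long, device='cpu')

# Async host-to-device copy; non_blocking=True lets the transfer overlap
# the remaining host-side input preparation.
slot_mapping = slot_mapping.to(device, non_blocking=True)

# The math _set_indices_and_offsets applies to the flattened slot mapping:
flat = slot_mapping.flatten()
indices = torch.div(flat, block_size, rounding_mode="floor")

# Prompt path: keep one block index per block_size-sized group of slots.
prompt_indices = indices.unflatten(0, (-1, block_size))[:, 0]  # tensor([0, 2])

# Decode path: a per-slot offset within its block.
offsets = torch.fmod(flat, block_size)  # tensor([0, 1, 2, 3, 0, 1, 2, 3])

Because the indices and offsets are now derived inside the adapter from metadata that is already on the device, input preparation no longer has to touch a device tensor, which appears to be the point of replacing the standalone precompute_indices_and_offsets helper.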