Add async copying to input preparation (#497)
This PR introduces asynchronous copying into _prepare_prompt and
_prepare_decode: input tensors are now built on the CPU and then
transferred to the device with non_blocking=True, which makes copying
faster. It also moves the precompute_indices_and_offsets logic into the
model's forward pass (as HpuModelAdapter._set_indices_and_offsets) to
avoid unnecessary H2D copying.
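
A minimal sketch of the pattern this change applies (illustrative values only; torch is the sole dependency, and a CUDA device stands in here for the HPU this file targets):

import torch

# Build inputs on the host, then issue a non-blocking host-to-device
# copy, instead of constructing each tensor directly on the device
# (which serializes one synchronous transfer per tensor).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Before: created on the device, one blocking copy per call.
tokens_before = torch.tensor([[1, 2, 3]], dtype=torch.long, device=device)

# After: created on the CPU, then queued with non_blocking=True so
# several transfers can be issued back to back.
tokens_cpu = torch.tensor([[1, 2, 3]], dtype=torch.long, device='cpu')
tokens_after = tokens_cpu.to(device, non_blocking=True)

Note that non-blocking copies only truly overlap with host work when the source sits in pinned memory; a sketch of that variant appears after the end of the diff.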
jkaniecki authored Nov 18, 2024
1 parent fb308c9 commit 7c5038c
Showing 1 changed file with 57 additions and 34 deletions.

vllm/worker/hpu_model_runner.py (57 additions, 34 deletions)
@@ -278,17 +278,6 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt):
-    slot_mapping = slot_mapping.flatten()
-    indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-    if is_prompt:
-        indices = indices.unflatten(0, (-1, block_size))[:, 0]
-        offsets = None
-    else:
-        offsets = torch.fmod(slot_mapping, block_size)
-    return indices, offsets
-
-
 class HpuModelAdapter:
 
     def __init__(self, model, block_size, dtype, enforce_eager):
@@ -382,6 +371,18 @@ def _set_block_scales(self, metadata, device):
         metadata = metadata._replace(block_scales=block_scales)
         return metadata
 
+    def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
+        slot_mapping = metadata.slot_mapping.flatten()
+        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+        if is_prompt:
+            indices = indices.unflatten(0, (-1, block_size))[:, 0]
+            offsets = None
+        else:
+            offsets = torch.fmod(slot_mapping, block_size)
+        metadata = metadata._replace(block_offsets=offsets,
+                                     block_indices=indices)
+        return metadata
+
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
         if attn_metadata.is_prompt:
@@ -391,6 +392,9 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
             attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
                                                     device, dtype)
             attn_metadata = self._set_block_scales(attn_metadata, device)
+        attn_metadata = self._set_indices_and_offsets(attn_metadata,
+                                                      self.block_size,
+                                                      attn_metadata.is_prompt)
         return attn_metadata
 
     def forward(self, *args, **kwargs):
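
For reference, here is a standalone sketch of what the relocated indices/offsets computation produces, using made-up block_size and slot_mapping values (torch only):

import torch

block_size = 4

# Decode case: one slot per sequence; each slot maps to a block index
# plus an offset within that block.
slot_mapping = torch.tensor([5, 6, 13])
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
offsets = torch.fmod(slot_mapping, block_size)
print(indices)  # tensor([1, 1, 3])
print(offsets)  # tensor([1, 2, 1])

# Prompt case: slots are contiguous, so only the first index of every
# block_size-sized group is kept and offsets are unnecessary.
slot_mapping = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
indices = indices.unflatten(0, (-1, block_size))[:, 0]
print(indices)  # tensor([2, 5])

Running this inside forward, after the inputs have already been transferred asynchronously, is what lets the runner drop the separate H2D copies the PR description mentions.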
@@ -956,45 +960,56 @@ def _prepare_prompt(
 
             prefix_block_list_tensor = torch.tensor(prefix_block_list,
                                                     dtype=torch.long,
-                                                    device=self.device)
+                                                    device='cpu')
         else:
             prefix_block_list_tensor = None
 
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
                                             dtype=torch.long,
-                                            device=self.device)
+                                            device='cpu')
 
         input_positions = make_tensor_with_pad(input_positions,
                                                max_len=max_prompt_len,
                                                pad=0,
                                                dtype=torch.long,
-                                               device=self.device)
+                                               device='cpu')
 
         slot_mapping = make_tensor_with_pad(slot_mapping,
                                             max_len=max_prompt_len,
                                             pad=_PAD_SLOT_ID,
                                             dtype=torch.long,
-                                            device=self.device)
+                                            device='cpu')
 
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.long,
-                                       device=self.device)
+                                       device='cpu')
 
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.long,
-                                           device=self.device)
+                                           device='cpu')
 
+        if prefix_block_list_tensor:

[Review comment from @zhentaoyu, Nov 20, 2024]

Should this be "if prefix_block_list_tensor is not None:"? Otherwise it may throw a "Boolean value of Tensor with more than one value is ambiguous" error when prefix caching is enabled.

+            prefix_block_list_tensor = prefix_block_list_tensor.to(
+                self.device, non_blocking=True)
+        input_tokens = input_tokens.to(  # type: ignore
+            self.device, non_blocking=True)
+        input_positions = input_positions.to(  # type: ignore
+            self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(  # type: ignore
+            self.device, non_blocking=True)
+        seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True)
+        context_lens_tensor = context_lens_tensor.to(self.device,
+                                                     non_blocking=True)
+
-        block_indices, block_offsets = precompute_indices_and_offsets(
-            self.block_size, slot_mapping, True)
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=prefix_block_list_tensor,
             block_mapping=None,
             block_usage=None,
-            block_indices=block_indices,
-            block_offsets=block_offsets,
+            block_indices=None,
+            block_offsets=None,
             block_scales=None,
             block_groups=None,
             attn_bias=None,
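
The concern in the review comment above is easy to reproduce: using a multi-element tensor in a boolean context raises at runtime, so an explicit None check is the safe form. A minimal repro (torch only):

import torch

t = torch.tensor([1, 2, 3])
try:
    if t:  # truthiness of a multi-element tensor is undefined
        pass
except RuntimeError as err:
    # "Boolean value of Tensor with more than one value is ambiguous"
    print(err)

if t is not None:  # identity check sidesteps the ambiguity
    print("tensor is present")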
@@ -1093,14 +1108,14 @@ def _prepare_decode(
         if output is None:
             input_tokens = torch.tensor(input_tokens,
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device='cpu')
         else:
             real_batch_size = len(seq_group_metadata_list)
             input_tokens = output[:real_batch_size]
 
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
-                                       device=self.device)
+                                       device='cpu')
 
         num_decode_tokens = sum(seq_lens)
 
@@ -1142,29 +1157,37 @@ def _prepare_decode(
         block_groups = padding_fn(block_groups, -1)
         block_usage = padding_fn(block_usage, 1)
 
-        block_list = torch.tensor(block_list,
-                                  dtype=torch.int,
-                                  device=self.device)
+        block_list = torch.tensor(block_list, dtype=torch.int, device='cpu')
         block_groups = torch.tensor(block_groups,
                                     dtype=torch.int,
-                                    device=self.device)
+                                    device='cpu')
         block_usage = torch.tensor(block_usage,
                                    dtype=self.model_config.dtype,
-                                   device=self.device)
+                                   device='cpu')
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
-                                    device=self.device)
-
-        block_indices, block_offsets = precompute_indices_and_offsets(
-            self.block_size, slot_mapping, False)
+                                    device='cpu')
+
+        input_tokens = input_tokens.to(  # type: ignore
+            self.device, non_blocking=True)
+        input_positions = input_positions.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_list = block_list.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_groups = block_groups.to(  # type: ignore
+            self.device, non_blocking=True)
+        block_usage = block_usage.to(  # type: ignore
+            self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(  # type: ignore
+            self.device, non_blocking=True)
 
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
             block_mapping=None,
             block_usage=block_usage,
-            block_indices=block_indices,
-            block_offsets=block_offsets,
+            block_indices=None,
+            block_offsets=None,
             block_scales=None,
             block_groups=block_groups,
             attn_bias=None,
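
A closing caveat on the non_blocking=True calls introduced throughout this diff: on most backends, a host-to-device copy can only overlap with other work when the source tensor lives in pinned (page-locked) host memory. A sketch of that variant using the CUDA API (the HPU backend is assumed to follow the same Tensor.to convention):

import torch

if torch.cuda.is_available():
    # Pinned host memory lets the copy be enqueued and return at once.
    host = torch.arange(1 << 20, dtype=torch.long).pin_memory()
    dev = host.to('cuda', non_blocking=True)
    # ... other host-side preparation can run while the copy is in flight ...
    torch.cuda.synchronize()  # wait before relying on `dev` being complete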
