use index_put with full block_indices
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Jan 8, 2025
1 parent 612abed commit b7d0931
Showing 3 changed files with 32 additions and 50 deletions.
23 changes: 5 additions & 18 deletions vllm/attention/backends/hpu_attn.py
@@ -139,6 +139,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     context_lens_tensor: Optional[torch.Tensor]
     enable_merged_prefill: bool = False
     seq_lens: Optional[List[int]] = None
+    slot_mapping_merged: Optional[torch.Tensor] = None
     encoder_seq_lens: Optional[List[int]] = None
     encoder_seq_lens_tensor: Optional[torch.Tensor] = None
     cross_block_indices: Optional[torch.Tensor] = None
@@ -268,7 +269,6 @@ def forward(
         block_offsets = kwargs.get('block_offsets', None)
         seq_lens_tensor = kwargs.get('seq_lens_tensor', None)
         attn_bias = kwargs.get('attn_bias', None)
-        seq_lens_tensor_list = kwargs.get('seq_lens_tensor_list', None)
         enable_merged_prefill = attn_metadata.enable_merged_prefill
         if block_indices is None:
             block_indices = attn_metadata.block_indices
@@ -278,25 +278,12 @@
             seq_lens_tensor = attn_metadata.seq_lens_tensor
         if attn_bias is None: # This is the case for prompt run
             attn_bias = attn_metadata.attn_bias
-        if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
-            max_len = attn_metadata.slot_mapping.size(1)
-            # we need to copy the key and value tensors to the padded tensors
-            # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size]
-            padded_key_tensor = split_and_pad_to_length(
-                key, max_len, seq_lens_tensor_list)
-            padded_value_tensor = split_and_pad_to_length(
-                value, max_len, seq_lens_tensor_list)
-            padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(
-                0, (block_indices.size(0), -1))
-            padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(
-                0, (block_indices.size(0), -1))
-
+        if enable_merged_prefill and attn_metadata.is_prompt and kv_cache is not None:
             key_cache, value_cache = HPUPagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
-
-            key_cache = self.k_cache(padded_key_tensor, key_cache,
+                kv_cache, self.num_kv_heads, self.head_size)
+            key_cache = self.k_cache(key, key_cache,
                                      block_indices, block_offsets)
-            value_cache = self.v_cache(padded_value_tensor, value_cache,
+            value_cache = self.v_cache(value, value_cache,
                                        block_indices, block_offsets)
         else:
             if attn_metadata.is_prompt:
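Note on this change: the merged-prefill prompt path now writes the flat key/value tensors straight into the paged cache, addressing each token by its (block index, block offset) pair instead of first padding per sequence and copying whole blocks. A minimal sketch of that kind of per-token index_put write, assuming a cache laid out as [num_blocks, block_size, num_kv_heads, head_size]; write_kv and the toy sizes are illustrative, not the actual cache op behind self.k_cache/self.v_cache:

import torch

def write_kv(cache, kv, block_indices, block_offsets):
    # cache: [num_blocks, block_size, num_kv_heads, head_size]
    # kv:    [num_tokens, num_kv_heads, head_size]
    # token i is stored at cache[block_indices[i], block_offsets[i]]
    cache.index_put_((block_indices, block_offsets), kv)
    return cache

block_size = 8
cache = torch.zeros(4, block_size, 2, 16)   # 4 blocks, 2 kv heads, head size 16
key = torch.randn(5, 2, 16)                 # 5 prompt tokens, already flat
slots = torch.tensor([0, 1, 2, 8, 9])       # flat slot ids for those tokens
cache = write_kv(cache, key, slots // block_size, slots % block_size)

With full per-token block_indices and block_offsets available, the split_and_pad_to_length copies removed above are no longer needed on this path.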
11 changes: 3 additions & 8 deletions vllm/model_executor/models/llama.py
@@ -292,7 +292,6 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        seq_lens_tensor_list: List[int],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if isinstance(hidden_states, torch.Tensor):
             skip_split = hidden_states.size()[0] == 1
@@ -314,8 +313,7 @@
         hidden_states = self.self_attn(positions=positions,
                                        hidden_states=hidden_states,
                                        kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata,
-                                       seq_lens_tensor_list=seq_lens_tensor_list)
+                                       attn_metadata=attn_metadata)
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
@@ -481,15 +479,11 @@ def forward(
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
 
-        if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
-            seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
-        else:
-            seq_lens_tensor_list = None
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
-                                            attn_metadata, residual, seq_lens_tensor_list)
+                                            attn_metadata, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
@@ -499,6 +493,7 @@
         # we need to split result before do RMSNorm
         if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
             max_len=attn_metadata.slot_mapping.size(1)
+            seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()[:attn_metadata.slot_mapping.size(0)]
             hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
             residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
         hidden_states, _ = self.norm(hidden_states, residual)
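Note on this change: seq_lens_tensor_list is no longer threaded through the decoder layers; it is rebuilt only at the end of the model forward, where split_and_pad_to_length restores a per-sequence, padded layout before the final RMSNorm. A rough sketch of the shape change that helper is assumed to perform (split_and_pad_to_length_sketch is illustrative, not the real vLLM helper):

import torch

def split_and_pad_to_length_sketch(x, max_len, seq_lens):
    # x: [sum(seq_lens) + padding, hidden] -> [len(seq_lens), max_len, hidden]
    pieces = torch.split(x[:sum(seq_lens)], seq_lens)
    return torch.stack([
        torch.nn.functional.pad(p, (0, 0, 0, max_len - p.size(0)))
        for p in pieces
    ])

hidden_states = torch.randn(7, 4)   # 7 merged prompt tokens, hidden size 4
out = split_and_pad_to_length_sketch(hidden_states, max_len=5, seq_lens=[3, 4])
print(out.shape)                    # torch.Size([2, 5, 4])

The [:attn_metadata.slot_mapping.size(0)] slice on seq_lens_tensor presumably drops the zero-length entries that _prepare_prompt_merged now pads seq_lens with (see the hpu_model_runner.py change below).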
48 changes: 24 additions & 24 deletions vllm/worker/hpu_model_runner.py
@@ -227,22 +227,7 @@ def generate_prompt_buckets(self):
             self.max_num_batched_tokens)
 
         print("prompt_buckets: ", prompt_buckets)
-        # expand
-        self.global_state.prompt_buckets = []
-        VLLM_PROMPT_BS_BUCKET_MAX = int(
-            os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', 16))
-        for bucket in prompt_buckets:
-            bs = 1
-            while bs <= VLLM_PROMPT_BS_BUCKET_MAX:
-                seq_len = bucket[1] // bs
-                if seq_len <= 32:
-                    bs = bs * 2
-                    continue
-                self.global_state.prompt_buckets.append(
-                    (bs * bucket[0], seq_len))
-                bs = bs * 2
-
-        self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len, self.global_state.prompt_buckets))
+        self.global_state.prompt_buckets = list(filter(lambda bucket: bucket[1] <= origin_max_prompt_len or bucket[0] > 1, prompt_buckets))
 
         msg = (f"Generated {len(self.global_state.prompt_buckets)} "
                f"prompt buckets [bs, seq]: "
@@ -425,13 +410,13 @@ def _set_block_scales(self, metadata, device):
         return metadata
 
     def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
-        slot_mapping = metadata.slot_mapping.flatten()
-        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
-        if is_prompt:
-            indices = indices.unflatten(0, (-1, block_size))[:, 0]
-            offsets = None
+        if metadata.enable_merged_prefill and is_prompt:
+            # remove 0 in tensor
+            slot_mapping = metadata.slot_mapping_merged
         else:
-            offsets = torch.fmod(slot_mapping, block_size)
+            slot_mapping = metadata.slot_mapping.flatten()
+        indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
+        offsets = torch.fmod(slot_mapping, block_size)
         metadata = metadata._replace(block_offsets=offsets,
                                      block_indices=indices)
         return metadata
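A slot id encodes a (block, offset) pair, so the div/fmod above recovers per-token block indices and offsets for every entry of the slot mapping; for merged prefill the mapping is slot_mapping_merged, from which pad slots were already stripped. A worked toy example (block_size and slot values are made up):

import torch

block_size = 4
slot_mapping_merged = torch.tensor([8, 9, 10, 20, 21])   # pad slots already removed

indices = torch.div(slot_mapping_merged, block_size, rounding_mode="floor")
offsets = torch.fmod(slot_mapping_merged, block_size)
print(indices.tolist())   # [2, 2, 2, 5, 5]
print(offsets.tolist())   # [0, 1, 2, 0, 1]

The old prompt branch kept only the first index of each block and left offsets as None (a whole-block write); the new code always produces full per-token indices and offsets, which is what the commit title refers to.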
@@ -481,8 +466,7 @@ def forward(self, *args, **kwargs):
             self._prepare_cos_sin(kwargs['positions'])
         if kwargs['attn_metadata'].is_prompt:
             print("Warming up HPU Graph - input_ids: ", input_ids.shape,
-                  "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor,
-                  "selected_token_indices: ", selected_token_indices)
+                  "seq_lens_tensor: ", kwargs['attn_metadata'].seq_lens_tensor.shape, "block_indices: ", kwargs['attn_metadata'].block_indices.shape)
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         hidden_states = hidden_states.index_select(0, selected_token_indices)
@@ -1244,6 +1228,8 @@ def _prepare_prompt_merged(
         #context_lens
         #prefix_block_list
 
+        slot_mapping_merged = list(itertools.chain.from_iterable(slot_mapping))
+        slot_mapping_merged = [i for i in slot_mapping_merged if i != _PAD_SLOT_ID]
         input_tokens_merged = list(itertools.chain.from_iterable(input_tokens))
         input_tokens_merged = [input_tokens_merged]
         input_positions_merged = list(
@@ -1289,6 +1275,16 @@
                                            dtype=torch.long,
                                            device='cpu')
 
+        slot_mapping_merged = make_tensor_with_pad(slot_mapping_merged,
+                                                   max_len=merged_prompt_len,
+                                                   pad=_PAD_SLOT_ID,
+                                                   dtype=torch.long,
+                                                   device='cpu')
+
+        max_prefill_bs = int(os.environ.get('VLLM_PROMPT_BS_BUCKET_MAX', '16'))
+        max_prefill_bs = max(max_prefill_bs, len(seq_lens))
+        seq_lens = seq_lens + [0] * (max_prefill_bs - len(seq_lens))
+        context_lens = context_lens + [0] * (max_prefill_bs - len(context_lens))
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.long,
                                        device='cpu')
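slot_mapping_merged is the flattened slot mapping with pad slots dropped, re-padded to merged_prompt_len; seq_lens and context_lens are zero-padded up to VLLM_PROMPT_BS_BUCKET_MAX, presumably so the tensor shapes keep matching a warmed-up bucket. A small sketch of the slot-mapping flow with toy values (PAD_SLOT_ID = -1 is a stand-in, not necessarily the runner's _PAD_SLOT_ID):

import itertools

PAD_SLOT_ID = -1      # stand-in pad value, for illustration only
merged_prompt_len = 8

slot_mapping = [[4, 5, 6, PAD_SLOT_ID], [12, 13, PAD_SLOT_ID, PAD_SLOT_ID]]
merged = list(itertools.chain.from_iterable(slot_mapping))
merged = [s for s in merged if s != PAD_SLOT_ID]        # [4, 5, 6, 12, 13]
merged = merged + [PAD_SLOT_ID] * (merged_prompt_len - len(merged))
print(merged)         # [4, 5, 6, 12, 13, -1, -1, -1]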
@@ -1306,6 +1302,8 @@
             self.device, non_blocking=True)
         slot_mapping = slot_mapping.to( # type: ignore
             self.device, non_blocking=True)
+        slot_mapping_merged = slot_mapping_merged.to( # type: ignore
+            self.device, non_blocking=True)
         seq_lens_tensor = seq_lens_tensor.to(self.device, non_blocking=True)
         context_lens_tensor = context_lens_tensor.to(self.device,
                                                      non_blocking=True)
@@ -1327,6 +1325,7 @@
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            slot_mapping_merged=slot_mapping_merged,
             multi_modal_placeholder_index_maps=
             None # FIXME(kzawora): mutli-modality will not work here
         )
@@ -1723,6 +1722,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'block_mapping',
             'block_usage',
             'slot_mapping',
+            'slot_mapping_merged',
             'is_prompt',
             'block_indices',
             'block_offsets',
