
Commit

Fix accuracy issue
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Dec 20, 2024
1 parent b133542 commit b6f6961
Showing 3 changed files with 32 additions and 27 deletions.
28 changes: 2 additions & 26 deletions vllm/attention/backends/hpu_attn.py
@@ -19,6 +19,7 @@
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.utils import is_fake_hpu
from vllm.model_executor.models.utils import split_and_pad_to_length

logger = init_logger(__name__)

@@ -30,28 +31,6 @@
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

def split_and_pad_to_length(input, target_length, seq_lens_tensor_list):
    # we need to copy the key and value tensors to the padded tensors
    # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
    padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0)

    padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True)
    p3d = (0, 0, 0, 0, 0, target_length - padded_tensor.size(1))
    padded_tensor = torch.nn.functional.pad(padded_tensor, p3d, value=0)
    return padded_tensor

def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list):
    # we need to copy the key and value tensors to the padded tensors
    # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
    padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype)

    start = 0
    for i in range(len(seq_lens_tensor_list)):
        padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :]
        start = start + seq_lens_tensor_list[i]

    return padded_tensor

def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -306,8 +285,6 @@ def forward(
padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))
padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1))

#seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0)
seq_lens_tensor_merged = seq_lens_tensor
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
@@ -320,7 +297,6 @@
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
seq_lens_tensor_merged = seq_lens_tensor
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
@@ -369,7 +345,7 @@ def forward(
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=seq_lens_tensor_merged,
valid_seq_lengths=seq_lens_tensor,
fsdpa_op=self.fused_scaled_dot_product_attention,
)
else:
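For context on the reshapes in the merged-prefill branch above: padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) first collapses the batch and sequence dimensions into one token dimension, then regroups that flat dimension into block_indices.size(0) equal chunks. A minimal sketch with toy shapes (the sizes below are illustrative, not the real kernel tensors):

import torch

# Toy stand-in for a padded key/value tensor: 2 sequences x 3 tokens x head_size 4.
padded = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# flatten(0, 1) merges the first two dims into a single token dim: shape [6, 4].
flat = padded.flatten(0, 1)

# unflatten(0, (n, -1)) regroups the token dim into n equal blocks; n = 3 here
# plays the role of block_indices.size(0): shape [3, 2, 4].
regrouped = flat.unflatten(0, (3, -1))
print(flat.shape, regrouped.shape)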
8 changes: 7 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -56,7 +56,7 @@
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
                    maybe_prefix, split_and_pad_to_length)

is_hpu = current_platform.is_hpu()

@@ -490,6 +490,12 @@ def forward(
"residual": residual
})

# we need to split the result before applying RMSNorm
if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt:
    max_len = attn_metadata.slot_mapping.size(1)
    seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist()
    hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
    residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

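The new block above un-merges the packed merged-prefill activations back into a padded [batch_size, max_len, hidden_size] layout before the final RMSNorm. A rough sketch of that reshaping with made-up sizes (seq_lens, hidden_size and max_len below are placeholders, not the real attn_metadata fields; it assumes the helper is importable from vllm.model_executor.models.utils):

import torch
from vllm.model_executor.models.utils import split_and_pad_to_length

# Illustrative only: three merged-prefill sequences of lengths 2, 3 and 1,
# packed along a single token dimension, hidden size 8.
seq_lens = [2, 3, 1]
hidden_size = 8
packed = torch.randn(sum(seq_lens), hidden_size)

# Pad every sequence to a common target length; the model takes this
# from attn_metadata.slot_mapping.size(1).
max_len = 4
padded = split_and_pad_to_length(packed, max_len, seq_lens)
print(padded.shape)  # torch.Size([3, 4, 8])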
23 changes: 23 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -666,3 +666,26 @@ def extract_layer_index(layer_name: str) -> int:
    assert len(int_vals) == 1, (f"layer name {layer_name} should"
                                " only contain one integer")
    return int_vals[0]

def split_and_pad_to_length(input, target_length, seq_lens_tensor_list):
    # we need to copy the key and value tensors to the padded tensors
    # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
    padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0)

    padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True)
    pad_shape = [0] * (input.dim() - 1) * 2
    pad_shape += [0, target_length - padded_tensor.size(1)]
    padded_tensor = torch.nn.functional.pad(padded_tensor, pad_shape, value=0)
    return padded_tensor

def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list):
    # we need to copy the key and value tensors to the padded tensors
    # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
    padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype)

    start = 0
    for i in range(len(seq_lens_tensor_list)):
        padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :]
        start = start + seq_lens_tensor_list[i]

    return padded_tensor
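
A minimal usage sketch for the two helpers added above (toy sizes, assuming they are importable from vllm.model_executor.models.utils): both take a token-packed tensor plus per-sequence lengths and return the same batch-padded result, split_and_pad_to_length via split/pad_sequence and split_and_pad_to_length_2 via an explicit copy loop:

import torch
from vllm.model_executor.models.utils import (split_and_pad_to_length,
                                              split_and_pad_to_length_2)

# Toy key/value-shaped input: 4 + 2 tokens from two sequences,
# 2 kv heads, head size 3 (sizes chosen only for illustration).
seq_lens = [4, 2]
x = torch.randn(sum(seq_lens), 2, 3)

out_a = split_and_pad_to_length(x, 5, seq_lens)
out_b = split_and_pad_to_length_2(x, 5, seq_lens)

print(out_a.shape)                # torch.Size([2, 5, 2, 3])
print(torch.equal(out_a, out_b))  # True -> both helpers pad identically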
