Add support for various softmax normalization options (#420)
madamczykhabana authored Oct 23, 2024
1 parent 892c090 commit 7f58ad1
Showing 4 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@fd7f2e6
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@c2801bb
1 change: 1 addition & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -223,6 +223,7 @@ def forward(
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
+block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -21,6 +21,7 @@ class HPUPagedAttentionMetadata:
block_indices: Optional[torch.Tensor]
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
+block_groups: Optional[torch.Tensor]


class HPUPagedAttention:
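
For reference, a minimal sketch of the metadata container with the new field, limited to the fields visible in the hunk above; the real HPUPagedAttentionMetadata carries additional fields and may differ in details, so the class below is explicitly named as a sketch.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class HPUPagedAttentionMetadataSketch:
    # Fields shown in the hunk above; other fields of the real class
    # are intentionally omitted from this illustration.
    block_indices: Optional[torch.Tensor]
    block_offsets: Optional[torch.Tensor]
    block_scales: Optional[torch.Tensor]
    # New in this commit: for every (padded) block slot, the index of the
    # sequence that owns it, forwarded to the attention op in hpu_attn.py
    # as block_groups=attn_metadata.block_groups.
    block_groups: Optional[torch.Tensor]
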
9 changes: 8 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -907,6 +907,7 @@ def _prepare_prompt(
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=None,
+block_groups=None,
attn_bias=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=real_num_seqs,
@@ -1028,6 +1029,8 @@ def _prepare_decode(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_list = pad_list(block_list, block_bucket_size, _PAD_BLOCK_ID)
+block_groups = pad_list(block_mapping, block_bucket_size,
+                        len(block_tables))
block_mapping = pad_list(block_mapping, block_bucket_size, -1)
block_usage = pad_list(block_usage, block_bucket_size, 1)
block_scales = pad_list(block_scales, block_bucket_size, 0.0)
@@ -1038,6 +1041,9 @@
block_mapping = torch.tensor(block_mapping,
dtype=torch.long,
device=self.device)
+block_groups = torch.tensor(block_groups,
+                            dtype=torch.long,
+                            device=self.device)
block_usage = torch.tensor(block_usage,
dtype=self.model_config.dtype,
device=self.device)
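
To make the new padding step concrete, here is a small, self-contained sketch of how block_groups is derived from block_mapping in _prepare_decode, assuming block_mapping lists, for each block slot, the index of the sequence that owns it, and using a stand-in pad_list helper (the real helper and bucketing logic live in hpu_model_runner.py):

from typing import List

import torch


def pad_list(seq: List[int], target_len: int, pad_value: int) -> List[int]:
    # Stand-in for the pad_list helper used above: extend `seq` with
    # `pad_value` until it reaches `target_len`.
    return seq + [pad_value] * (target_len - len(seq))


# Example: two sequences owning 2 and 1 KV-cache blocks respectively.
block_mapping = [0, 0, 1]   # block slot -> owning sequence index
num_seqs = 2                # len(block_tables) in the real code
block_bucket_size = 8       # bucketed number of block slots

# Padded slots get group id num_seqs (one past the last real sequence),
# while block_mapping itself keeps being padded with -1, as in the hunk above.
block_groups = pad_list(block_mapping, block_bucket_size, num_seqs)
block_mapping = pad_list(block_mapping, block_bucket_size, -1)

block_groups_t = torch.tensor(block_groups, dtype=torch.long)
print(block_groups_t)  # tensor([0, 0, 1, 2, 2, 2, 2, 2])
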
@@ -1060,6 +1066,7 @@
block_indices=block_indices,
block_offsets=block_offsets,
block_scales=block_scales,
+block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
num_prefills=0,
@@ -1271,7 +1278,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping',
'block_usage', 'slot_mapping', 'is_prompt', 'block_indices',
-'block_offsets', 'block_scales'
+'block_offsets', 'block_scales', 'block_groups'
])
return attention_metadata

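
The trim_attn_metadata change above only extends the list of fields that survive trimming. As a rough, hypothetical reconstruction of the pattern (not vllm's actual subtuple implementation), a subtuple-style helper can be sketched as a namedtuple projection of the selected attributes:

import collections
from typing import Any, List


def subtuple_sketch(obj: Any, typename: str, fields: List[str]):
    # Copy only the listed attributes of `obj` into a lightweight namedtuple,
    # so downstream code (e.g. HPU graph capture) only sees what it needs.
    trimmed_cls = collections.namedtuple(typename, fields)
    return trimmed_cls(**{name: getattr(obj, name) for name in fields})


# Usage mirroring the call site above (field list shortened for brevity):
# trimmed = subtuple_sketch(metadata, 'TrimmedAttentionMetadata',
#                           ['attn_bias', 'block_mapping', 'block_groups'])
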
