Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove assert for alibi in case of FusedSDPA. #587

Open
wants to merge 1 commit into
base: habana_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ def __init__(
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'1').lower() in ['1', 'true'] \
and not is_fake_hpu()
if self.prefill_use_fusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
Expand Down Expand Up @@ -196,7 +193,8 @@ def forward(
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
if (not self.prefill_use_fusedsdpa
or self.alibi_slopes is not None):
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
Expand Down
6 changes: 1 addition & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,6 @@ class HpuModelAdapter:

def __init__(self, model, block_size, dtype, enforce_eager):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'1').lower() in ['1', 'true'] \
and not is_fake_hpu()
self.block_size = block_size
self.dtype = dtype
if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
Expand Down Expand Up @@ -212,8 +209,7 @@ def _compile_region(self, model, name, module):

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
dtype):
if (attn_metadata is None or self.prefill_use_fusedsdpa
or not attn_metadata.is_prompt):
if (attn_metadata is None or not attn_metadata.is_prompt):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove prefill_use_fusedsdpa here?

return attn_metadata

prefill_metadata = attn_metadata
Expand Down
Loading