Add exponential bucketing integration
kzawora-intel committed Dec 17, 2024
1 parent da61ecf commit 7cbe922
Showing 2 changed files with 18 additions and 24 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1c23cdf
40 changes: 17 additions & 23 deletions vllm/worker/hpu_model_runner.py
@@ -19,7 +19,6 @@
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
-from vllm_hpu_extension.bucketing import HPUBucketingContext
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.ops import batch2block, block2batch
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -296,11 +295,11 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))

-if not is_fake_hpu():
+if not is_fake_hpu() and htorch.utils.internal.is_lazy():
block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
num_classes=batch_size)
else:
-# Unfortunately one_hot on CPU
+# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# doesn't handle out of bounds classes so we need to convert
# all negative values to 0 (block_mapping) or bs (block_groups)
block_groups = metadata.block_groups.to(torch.long)
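A note on the comment in the else branch above: outside of HPU lazy mode, torch.nn.functional.one_hot rejects negative class indices, so the padding entries in block_groups (marked with negative values) have to be remapped to a valid class before the block mapping can be built. The sketch below only illustrates that constraint and one possible workaround; the tensors and the zero-remapping are examples, not the exact fallback code in _set_block_mapping.

```python
import torch

block_groups = torch.tensor([0, 1, 1, -1])  # -1 marks a padded block slot
batch_size = 2

# In eager or torch.compile mode this raises a RuntimeError, because one_hot
# rejects negative class values:
# torch.nn.functional.one_hot(block_groups, num_classes=batch_size)

# Workaround: remap out-of-bounds entries to a valid class, build the one-hot
# mapping, then zero out the rows that belonged to padding.
valid = block_groups >= 0
safe_groups = torch.where(valid, block_groups, torch.zeros_like(block_groups))
block_mapping = torch.nn.functional.one_hot(safe_groups, num_classes=batch_size)
block_mapping = block_mapping * valid.unsqueeze(-1)
print(block_mapping)  # rows for real blocks are one-hot, the padded row is all zeros
```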
@@ -622,6 +621,13 @@ def __init__(
self.profiler_counter_helper = HabanaProfilerCounterHelper()
self.seen_configs: set = set()
self._mem_margin: Optional[int] = None
+self.use_exponential_bucketing = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
+'true').lower() == 'true'
Check failure (GitHub Actions / ruff 3.12): vllm/worker/hpu_model_runner.py:624:81: E501 Line too long (85 > 80)
+if self.use_exponential_bucketing:
+from vllm_hpu_extension.bucketing.exponential import HPUExponentialBucketingContext as HPUBucketingContext
Check failure (GitHub Actions / ruff 3.12): vllm/worker/hpu_model_runner.py:627:81: E501 Line too long (118 > 80)
+else:
+from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
self.max_num_prefill_seqs,
self.block_size,
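The new lines above make the bucketing implementation selectable through a single environment variable, with exponential bucketing enabled by default. Below is a self-contained sketch of the same selection pattern; the two placeholder classes merely stand in for the HPUExponentialBucketingContext and linear HPUBucketingContext classes that the real code imports from vllm_hpu_extension, and are not the actual implementations.

```python
import os

class ExponentialBucketingContext:  # hypothetical stand-in for the extension class
    pass

class LinearBucketingContext:  # hypothetical stand-in for the extension class
    pass

# Same pattern as the diff: the env var defaults to 'true', so exponential
# bucketing is opt-out rather than opt-in.
use_exponential = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
                                 'true').lower() == 'true'
BucketingContext = (ExponentialBucketingContext
                    if use_exponential else LinearBucketingContext)
print(BucketingContext.__name__)
```

Because the default is 'true', exporting VLLM_EXPONENTIAL_BUCKETING=false is what falls back to the previous linear bucketing behaviour.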
@@ -877,6 +883,8 @@ def _prepare_prompt(
"sliding window attention")
start_idx = max(0, seq_len - self.sliding_window)
for i in range(context_len, seq_len):
+if self.scheduler_config.task == 'embedding':
+break
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
@@ -907,7 +915,7 @@
lora_prompt_mapping.extend(
[lora_id] *
(max_prompt_len
-if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+if seq_group_metadata.sampling_params is not None and seq_group_metadata.sampling_params.prompt_logprobs else 1))
Check failure (GitHub Actions / ruff 3.12): vllm/worker/hpu_model_runner.py:918:81: E501 Line too long (130 > 80)

if any(context_lens):
assert not self.scheduler_config.chunked_prefill_enabled
@@ -2019,19 +2027,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],

return lora_mask, lora_logits_mask

-def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
-real_batch_size = len(seq_group_metadata_list)
-batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-real_batch_size, is_prompt)
-batch_size_padding = batch_size_padded - real_batch_size
-seq_group_metadata_list = seq_group_metadata_list.copy()
-if batch_size_padding > 0:
-dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-0, 0, is_prompt)
-seq_group_metadata_list.extend(dummy_seq_group_metadata
-for _ in range(batch_size_padding))
-return seq_group_metadata_list
-
@torch.inference_mode()
def execute_model(
self,
@@ -2118,8 +2113,8 @@ def execute_model(
def try_revert_dummy_output_tokens():
if len(cache_orig_output_tokens_len) > 0:
# Reuse the original output token ids length
-for i in range(len(cache_orig_output_tokens_len)):
-seq_group_metadata = seq_group_metadata_list[i]
+for i, seq_group_metadata in enumerate(
+seq_group_metadata_list):
for j, data in seq_group_metadata.seq_data.items():
orig_output_tokens_len = \
cache_orig_output_tokens_len[i][j]
@@ -2141,6 +2136,7 @@ def try_revert_dummy_output_tokens():
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
+import pdb; pdb.set_trace()
Check failure (GitHub Actions / ruff 3.12): vllm/worker/hpu_model_runner.py:2139:27: E702 Multiple statements on one line (semicolon)

with self.profiler.record_event('internal', model_event_name):
hidden_states = self.model.forward(
**execute_model_kwargs,
@@ -2197,18 +2193,16 @@ def try_revert_dummy_output_tokens():
else:
raise RuntimeError(
"seq_group_metadata_list is uninitialized")
-for seq_idx, seq_group_metadata in enumerate(
+for i, seq_group_metadata in enumerate(
seq_group_metadata_list):
# Skip empty steps
seq_group_metadata.state.current_step += (
num_steps - 2)
# Cache the original output token ids
cache_orig_output_tokens_len.append({})
for j, data in seq_group_metadata.seq_data.items():
-cache_orig_output_tokens_len[seq_idx][j] = \
+cache_orig_output_tokens_len[i][j] = \
len(data.output_token_ids)
-seq_group_metadata_list = self.add_dummy_seq(
-seq_group_metadata_list, is_prompt=False)
for seq_group_metadata in seq_group_metadata_list:
for data in seq_group_metadata.seq_data.values():
max_output_len = sampling_metadata.seq_groups[
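For readers wondering what the linear-versus-exponential distinction changes in practice: the bucketing context rounds batch sizes (and similar dimensions) up to a fixed set of bucket boundaries, which is what the removed add_dummy_seq helper relied on through get_padded_batch_size. The toy comparison below uses made-up bucket schedules purely for illustration; it is not the rounding logic or the default configuration of vllm-hpu-extension.

```python
import bisect

def padded_batch_size(real_batch_size: int, buckets: list) -> int:
    """Round a batch size up to the nearest bucket boundary (illustrative of
    what bucketing_ctx.get_padded_batch_size provides)."""
    idx = bisect.bisect_left(buckets, real_batch_size)
    return buckets[min(idx, len(buckets) - 1)]

linear_buckets = list(range(2, 66, 2))            # fixed step between buckets
exponential_buckets = [2 ** i for i in range(7)]  # 1, 2, 4, ..., 64

for bs in (3, 17, 40):
    print(bs, '->',
          'linear:', padded_batch_size(bs, linear_buckets),
          'exponential:', padded_batch_size(bs, exponential_buckets))
```

Exponentially spaced buckets keep the number of distinct padded shapes (and therefore compiled graphs) small, at the cost of extra padding when a batch lands just above a boundary.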
