Add exponential bucketing integration #642

Status: Draft. Wants to merge 2 commits into base: habana_main.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1c23cdf
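
The dependency bump pins vllm-hpu-extension to a commit that ships the split bucketing modules imported below. A minimal sanity check, assuming the module layout used later in this diff (bucketing.linear and bucketing.exponential):

# Hypothetical check: confirm the pinned extension exposes both bucketing
# modules referenced by this PR (module paths assumed from the diff below).
import importlib.util

for mod in ("vllm_hpu_extension.bucketing.linear",
            "vllm_hpu_extension.bucketing.exponential"):
    try:
        found = importlib.util.find_spec(mod) is not None
    except ModuleNotFoundError:
        found = False
    print(f"{mod}: {'available' if found else 'missing'}")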
27 changes: 9 additions & 18 deletions vllm/core/scheduler.py
@@ -130,25 +130,16 @@
         return batch_size * max_seq_len

     def _hpu_padding_fn(self, batch_size, max_seq_len):
-        from vllm_hpu_extension.bucketing import (HPUBucketingGlobalState,
-                                                  find_bucket)
-        padded_bs = batch_size
-        padded_seq = max_seq_len
-
-        hpu_bucketing_global_state = HPUBucketingGlobalState()
-
-        bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg
-        if bs_cfg is not None:
-            padded_bs = find_bucket(batch_size, bs_cfg)
+        use_exponential_bucketing = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
+                                                   'true').lower() == 'true'
+        if use_exponential_bucketing:
+            from vllm_hpu_extension.bucketing.exponential import HPUExponentialBucketingContext as HPUBucketingContext
         else:
-            logger.warning(
-                "prompt_bs_bucket_cfg was not set! Using unpadded batch size.")
-        seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg
-        if seq_cfg is not None:
-            padded_seq = find_bucket(max_seq_len, seq_cfg)
-        else:
-            logger.warning("prompt_seq_bucket_cfg was not set! "
-                           "Using unpadded sequence length.")
+            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
+        hpu_bucketing_context = HPUBucketingContext()
+        padded_bs = hpu_bucketing_context.get_padded_prompt_batch_size(batch_size)
+        padded_seq = hpu_bucketing_context.get_padded_prompt_seq_len(max_seq_len)
         return padded_bs * padded_seq

     def _padding_fn_selector(self):

[CI: ruff (3.12) E501 Line too long at vllm/core/scheduler.py:136 (118 > 80), :141 (82 > 80), :142 (81 > 80)]
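
For context, both bucketing flavours pad a value up to the nearest precomputed bucket; they differ only in how the bucket ladder is spaced. A minimal illustration of that idea (the minimum, step, and range values here are made-up assumptions, not the extension's actual configuration):

# Illustration only: how bucketed padding behaves conceptually.
def linear_buckets(minimum, step, maximum):
    # e.g. 128, 256, 384, ... up to maximum
    return list(range(minimum, maximum + 1, step))

def exponential_buckets(minimum, maximum):
    # e.g. 128, 256, 512, 1024, ... up to maximum
    buckets = []
    value = minimum
    while value <= maximum:
        buckets.append(value)
        value *= 2
    return buckets

def find_bucket(value, buckets):
    # Pad up to the smallest bucket that can hold the value.
    return next(b for b in buckets if b >= value)

print(find_bucket(300, linear_buckets(128, 128, 2048)))   # -> 384
print(find_bucket(300, exponential_buckets(128, 2048)))   # -> 512

Coarser exponential buckets mean fewer distinct padded shapes to warm up and compile, at the cost of more padding per request; that trade-off is what the VLLM_EXPONENTIAL_BUCKETING switch exposes.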
46 changes: 20 additions & 26 deletions vllm/worker/hpu_model_runner.py
@@ -19,7 +19,6 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
-from vllm_hpu_extension.bucketing import HPUBucketingContext
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -296,11 +295,11 @@
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))

-        if not is_fake_hpu():
+        if not is_fake_hpu() and htorch.utils.internal.is_lazy():
             block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                         num_classes=batch_size)
         else:
-            # Unfortunately one_hot on CPU
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
             # doesn't handle out of bounds classes so we need to convert
             # all negative values to 0 (block_mapping) or bs (block_groups)
             block_groups = metadata.block_groups.to(torch.long)
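
The comment kept in this hunk summarizes the workaround now taken whenever lazy mode is off: one_hot rejects out-of-range classes, so padded (negative) entries are clamped first. A standalone sketch of that clamping trick, with tensor names taken from the visible context and made-up values (requires torch; the real shapes in the model runner may differ):

# Sketch of the non-lazy fallback described in the comment above.
import torch

batch_size = 4
# -1 marks padding entries that belong to no batch element.
block_groups = torch.tensor([0, 1, 1, -1, 3, -1])

# Clamp negatives before one_hot, then zero out the padded rows and push the
# padded block_groups entries into an extra "bs" bucket.
block_mapping = torch.nn.functional.relu(block_groups)
block_mapping = torch.nn.functional.one_hot(block_mapping,
                                            num_classes=batch_size)
oob_values = block_groups.lt(0)
block_mapping = block_mapping.masked_fill(oob_values.unsqueeze(-1), 0)
block_groups = block_groups.masked_fill(oob_values, batch_size)
print(block_mapping)
print(block_groups)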
@@ -622,10 +621,18 @@
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
+        self.use_exponential_bucketing = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
+                                                        'true').lower() == 'true'
+        if self.use_exponential_bucketing:
+            from vllm_hpu_extension.bucketing.exponential import HPUExponentialBucketingContext as HPUBucketingContext
+        else:
+            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
         self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
                                                  self.max_num_prefill_seqs,
                                                  self.block_size,
-                                                 self.max_num_batched_tokens)
+                                                 self.max_num_batched_tokens,
+                                                 self.max_model_len)
         self.graphed_buckets: Set[Any] = set()

         self._set_gc_threshold()

[CI: ruff (3.12) E501 Line too long at vllm/worker/hpu_model_runner.py:624 (85 > 80) and :627 (118 > 80)]
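
The same VLLM_EXPONENTIAL_BUCKETING check appears here and in vllm/core/scheduler.py above. A hypothetical shared helper (not part of this PR) that would keep the selection logic, including the long import line ruff flags, in one place:

# Hypothetical factoring of the duplicated selection logic; the imports are
# exactly the ones this PR uses.
import os

def select_bucketing_context_cls():
    use_exponential = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
                                     'true').lower() == 'true'
    if use_exponential:
        from vllm_hpu_extension.bucketing.exponential import (
            HPUExponentialBucketingContext as HPUBucketingContext)
    else:
        from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
    return HPUBucketingContext

Note that exponential bucketing is the default; setting VLLM_EXPONENTIAL_BUCKETING=false before the runner is constructed falls back to the linear, step-based buckets.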
@@ -877,6 +884,8 @@
                         "sliding window attention")
                 start_idx = max(0, seq_len - self.sliding_window)
             for i in range(context_len, seq_len):
+                if self.scheduler_config.task == 'embedding':
+                    break
                 if i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue
@@ -907,7 +916,7 @@
             lora_prompt_mapping.extend(
                 [lora_id] *
                 (max_prompt_len
-                 if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+                 if seq_group_metadata.sampling_params is not None and seq_group_metadata.sampling_params.prompt_logprobs else 1))

         if any(context_lens):
             assert not self.scheduler_config.chunked_prefill_enabled

[CI: ruff (3.12) E501 Line too long at vllm/worker/hpu_model_runner.py:919 (130 > 80)]
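
This guard makes the LoRA prompt mapping tolerate sequence groups that carry no sampling params, presumably requests like the embedding task handled earlier in this diff. An equivalent, line-length-friendly form of the new condition (the helper name is illustrative; the PR keeps the inline conditional, which is why ruff reports the 130-character line):

# Sketch of the same guard, pulled out for readability.
def lora_mapping_repeat(seq_group_metadata, max_prompt_len):
    sampling_params = seq_group_metadata.sampling_params
    needs_prompt_logprobs = (sampling_params is not None
                             and sampling_params.prompt_logprobs)
    return max_prompt_len if needs_prompt_logprobs else 1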
@@ -1418,13 +1427,13 @@

     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)

-        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
-                             False, True)
+        # self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
+        #                      False, True)
         return

     def warmup_scenario(self,

[CI: ruff (3.12) F841 at vllm/worker/hpu_model_runner.py:1430 (local variable kv_caches assigned but never used) and :1432 (local variable max_batch_size assigned but never used)]
@@ -2019,19 +2028,6 @@

         return lora_mask, lora_logits_mask

-    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
-        return seq_group_metadata_list
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2118,8 +2114,8 @@
             def try_revert_dummy_output_tokens():
                 if len(cache_orig_output_tokens_len) > 0:
                     # Reuse the original output token ids length
-                    for i in range(len(cache_orig_output_tokens_len)):
-                        seq_group_metadata = seq_group_metadata_list[i]
+                    for i, seq_group_metadata in enumerate(
+                            seq_group_metadata_list):
                         for j, data in seq_group_metadata.seq_data.items():
                             orig_output_tokens_len = \
                                 cache_orig_output_tokens_len[i][j]
@@ -2197,18 +2193,16 @@
                     else:
                         raise RuntimeError(
                             "seq_group_metadata_list is uninitialized")
-                    for seq_idx, seq_group_metadata in enumerate(
+                    for i, seq_group_metadata in enumerate(
                             seq_group_metadata_list):
                         # Skip empty steps
                         seq_group_metadata.state.current_step += (
                             num_steps - 2)
                         # Cache the original output token ids
                         cache_orig_output_tokens_len.append({})
                         for j, data in seq_group_metadata.seq_data.items():
-                            cache_orig_output_tokens_len[seq_idx][j] = \
+                            cache_orig_output_tokens_len[i][j] = \
                                 len(data.output_token_ids)
-                    seq_group_metadata_list = self.add_dummy_seq(
-                        seq_group_metadata_list, is_prompt=False)
                     for seq_group_metadata in seq_group_metadata_list:
                         for data in seq_group_metadata.seq_data.values():
                             max_output_len = sampling_metadata.seq_groups[