diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index f4fb89ef42834..e7c8aaa1cf814 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@1c23cdf
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7c3679d40546d..67d6f3fad093f 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -19,7 +19,6 @@
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
-from vllm_hpu_extension.bucketing import HPUBucketingContext
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
@@ -296,11 +295,11 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
 
-        if not is_fake_hpu():
+        if not is_fake_hpu() and htorch.utils.internal.is_lazy():
             block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
                                                         num_classes=batch_size)
         else:
-            # Unfortunately one_hot on CPU
+            # Unfortunately one_hot on CPU/torch.compile mode/eager mode
             # doesn't handle out of bounds classes so we need to convert
             # all negative values to 0 (block_mapping) or bs (block_groups)
             block_groups = metadata.block_groups.to(torch.long)
@@ -622,6 +621,13 @@ def __init__(
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
+        self.use_exponential_bucketing = os.environ.get(
+            'VLLM_EXPONENTIAL_BUCKETING', 'true').lower() == 'true'
+        if self.use_exponential_bucketing:
+            from vllm_hpu_extension.bucketing.exponential import HPUExponentialBucketingContext as HPUBucketingContext
+        else:
+            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
         self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
                                                  self.max_num_prefill_seqs,
                                                  self.block_size,
@@ -877,6 +883,8 @@ def _prepare_prompt(
                     "sliding window attention")
                 start_idx = max(0, seq_len - self.sliding_window)
             for i in range(context_len, seq_len):
+                if self.scheduler_config.task == 'embedding':
+                    break
                 if i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue
@@ -907,7 +915,7 @@ def _prepare_prompt(
                 lora_prompt_mapping.extend(
                     [lora_id] *
                     (max_prompt_len
-                     if seq_group_metadata.sampling_params.prompt_logprobs else 1))
+                     if seq_group_metadata.sampling_params is not None and seq_group_metadata.sampling_params.prompt_logprobs else 1))
 
         if any(context_lens):
             assert not self.scheduler_config.chunked_prefill_enabled
@@ -2019,19 +2027,6 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
-    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
-        real_batch_size = len(seq_group_metadata_list)
-        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
-            real_batch_size, is_prompt)
-        batch_size_padding = batch_size_padded - real_batch_size
-        seq_group_metadata_list = seq_group_metadata_list.copy()
-        if batch_size_padding > 0:
-            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
-                0, 0, is_prompt)
-            seq_group_metadata_list.extend(dummy_seq_group_metadata
-                                           for _ in range(batch_size_padding))
-        return seq_group_metadata_list
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -2118,8 +2113,8 @@ def execute_model(
         def try_revert_dummy_output_tokens():
             if len(cache_orig_output_tokens_len) > 0:
                 # Reuse the original output token ids length
-                for i in range(len(cache_orig_output_tokens_len)):
-                    seq_group_metadata = seq_group_metadata_list[i]
+                for i, seq_group_metadata in enumerate(
+                        seq_group_metadata_list):
                     for j, data in seq_group_metadata.seq_data.items():
                         orig_output_tokens_len = \
                             cache_orig_output_tokens_len[i][j]
@@ -2197,7 +2192,7 @@ def try_revert_dummy_output_tokens():
                 else:
                     raise RuntimeError(
                         "seq_group_metadata_list is uninitialized")
-                for seq_idx, seq_group_metadata in enumerate(
+                for i, seq_group_metadata in enumerate(
                         seq_group_metadata_list):
                     # Skip empty steps
                     seq_group_metadata.state.current_step += (
@@ -2205,10 +2200,8 @@ def try_revert_dummy_output_tokens():
                     # Cache the original output token ids
                     cache_orig_output_tokens_len.append({})
                     for j, data in seq_group_metadata.seq_data.items():
-                        cache_orig_output_tokens_len[seq_idx][j] = \
+                        cache_orig_output_tokens_len[i][j] = \
                             len(data.output_token_ids)
-                seq_group_metadata_list = self.add_dummy_seq(
-                    seq_group_metadata_list, is_prompt=False)
                 for seq_group_metadata in seq_group_metadata_list:
                     for data in seq_group_metadata.seq_data.values():
                         max_output_len = sampling_metadata.seq_groups[
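
Note on the bucketing change above: the HPU bucketing backend is now selected at runtime via the VLLM_EXPONENTIAL_BUCKETING environment variable, with the default "true" picking the exponential strategy. The sketch below restates that selection logic as a standalone helper, assuming the vllm_hpu_extension.bucketing.exponential / vllm_hpu_extension.bucketing.linear module layout shown in the diff; the helper name pick_bucketing_context is illustrative only and is not part of the patch.

    import os

    def pick_bucketing_context():
        """Return the bucketing context class chosen by VLLM_EXPONENTIAL_BUCKETING.

        Both classes are assumed to share the constructor signature used in the
        diff (max_num_seqs, max_num_prefill_seqs, block_size, ...), so callers
        such as HPUModelRunner.__init__ do not need to change.
        """
        use_exponential = os.environ.get('VLLM_EXPONENTIAL_BUCKETING',
                                         'true').lower() == 'true'
        if use_exponential:
            # Exponentially spaced buckets (the new default in this patch).
            from vllm_hpu_extension.bucketing.exponential import (
                HPUExponentialBucketingContext as HPUBucketingContext)
        else:
            # Linearly spaced buckets (the previous behaviour).
            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
        return HPUBucketingContext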