[BUG fix] Rebase caused spec decode fix (#613)
Error reported in https://jira.habana-labs.com/browse/SW-212516

Found that two recently merged PRs break Spec Decode functionality:

1. #491 overrides the existing
WorkerWrapperBase design for speculative decoding.
```
if model_runner_cls is not None:
    ModelRunnerClass = model_runner_cls
```
is no longer needed, since model_runner_cls is now applied with the code below,
following the upstream design (a toy sketch of this wrapping pattern is shown
after this list).
```
if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)
```

2. #566 does not work in Spec
Decode Eagle mode,
because the input tensors no longer match the earlier assumption that
decode_fwd receives only one token per sequence; Spec Decode passes multiple
candidate tokens as q.
To fix that, a new ENV - "**VLLM_COS_SIN_RECOMPUTE**=true" - was added; set it
to trigger recomputation of cos and sin for spec decode (a minimal sketch of
this recompute path is shown after this list).
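
For context on item 1, here is a toy sketch of the wrapping pattern the upstream design relies on. The class and function names below are illustrative only, not the actual vLLM worker/runner classes: the worker keeps constructing its default runner, and an optional `model_runner_cls` wraps that instance instead of overriding which runner class gets instantiated.

```python
# Illustrative only: hypothetical names, not the real vLLM worker/runner classes.
class DefaultModelRunner:
    def execute(self, batch):
        return f"default({batch})"


class SpecDecodeRunnerWrapper:
    """Hypothetical spec-decode wrapper around an existing runner."""

    def __init__(self, model_runner):
        self.model_runner = model_runner  # keep the wrapped runner

    def execute(self, batch):
        # Delegate to the wrapped runner, then layer on wrapper-specific logic.
        return f"spec_decode({self.model_runner.execute(batch)})"


def build_runner(model_runner_cls=None):
    model_runner = DefaultModelRunner()
    if model_runner_cls is not None:
        # Upstream-style behavior: wrap the existing runner instead of
        # replacing which runner class the worker instantiates.
        model_runner = model_runner_cls(model_runner)
    return model_runner


print(build_runner().execute("b0"))                         # default path
print(build_runner(SpecDecodeRunnerWrapper).execute("b0"))  # wrapped path
```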
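And for item 2, a minimal sketch of what VLLM_COS_SIN_RECOMPUTE changes, under simplified assumptions: `SimpleRope` and its toy frequencies are made up for illustration, and only the env-var gating mirrors the real change. With Eagle spec decode each sequence contributes several candidate tokens, so a cos/sin cache built for one-token-per-sequence decode goes stale and must be rebuilt on every call.

```python
# Minimal sketch; SimpleRope is a stand-in, not the vLLM RotaryEmbedding.
import os

import torch


class SimpleRope:
    def __init__(self):
        self.cos = None
        self.sin = None
        # Opt-in recompute, mirroring the new VLLM_COS_SIN_RECOMPUTE env var.
        self.recompute_cos_sin = os.getenv(
            'VLLM_COS_SIN_RECOMPUTE', 'false').lower() in ['1', 'true']

    def prepare_cos_sin(self, positions: torch.Tensor):
        angles = positions.float().unsqueeze(-1) * 0.01  # toy frequencies
        self.cos, self.sin = torch.cos(angles), torch.sin(angles)

    def forward(self, positions: torch.Tensor):
        # Plain decode passes one token per sequence, so cached cos/sin keeps a
        # valid shape. Eagle spec decode passes several candidate tokens per
        # sequence, so the cache must be recomputed on every call.
        if self.cos is None or self.recompute_cos_sin:
            self.prepare_cos_sin(positions)
        return self.cos, self.sin


rope = SimpleRope()
rope.forward(torch.tensor([[5], [7]]))        # one token per sequence
rope.forward(torch.tensor([[5, 6], [7, 8]]))  # spec decode: multiple candidates
```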

---------

Signed-off-by: Chendi.Xue <[email protected]>
xuechendi authored Jan 7, 2025
1 parent 5b5bf26 commit 2d24be7
Showing 5 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .jenkins/test_config.yaml
@@ -57,4 +57,4 @@ stages:
         command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
       - name: gsm8k_small_g2_tp1_eagle_spec_decode
         flavor: g2
-        command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
+        command: VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -102,7 +102,9 @@ def __init__(

     def prepare_cos_sin(self,
                         positions: torch.Tensor,
-                        offsets: Optional[torch.Tensor] = None):
+                        offsets: Optional[torch.Tensor] = None,
+                        recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin
         if offsets is not None:
             offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
@@ -237,6 +239,8 @@ def forward_hpu(
         # forward, since the offset information wasn't available previously
         if hasattr(self, "scaling_factors") or self.sin is None:
             self.prepare_cos_sin(positions, offsets)
+        if self.recompute_cos_sin:
+            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
         num_tokens = positions.shape[0] * positions.shape[1]
         # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
         # to query hidden dimension, so the original tensors need to be
3 changes: 2 additions & 1 deletion vllm/sequence.py
@@ -1182,7 +1182,8 @@ def update(self,
                second_last_token_hidden_states: Optional[torch.Tensor] = None):
         """Update hidden states from target model invocation. Only used for
         decode steps"""
-        assert len(seq_group_metadata_list) == len(hidden_states)
+        if len(seq_group_metadata_list) < len(hidden_states):
+            hidden_states = hidden_states[:len(seq_group_metadata_list)]
         self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])

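A minimal illustration of the behavioral change above, using toy tensors rather than the actual HiddenStates class: when more hidden-state rows arrive than there are sequence groups, the extra rows are now dropped instead of tripping an assertion.

```python
import torch


def update_hidden_states(seq_group_ids, hidden_states):
    # Previously this asserted equal lengths; now extra rows are truncated.
    if len(seq_group_ids) < len(hidden_states):
        hidden_states = hidden_states[:len(seq_group_ids)]
    return hidden_states


ids = ["seq0", "seq1"]
hs = torch.randn(3, 4)  # one more row than there are sequence groups
print(update_hidden_states(ids, hs).shape)  # torch.Size([2, 4])
```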
5 changes: 4 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -209,6 +209,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '1').lower() in ['1', 'true'] \
                                         and not is_fake_hpu()
+        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
+                                           'false').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
         self.layer_names = layer_names
@@ -370,7 +372,8 @@ def _prepare_cos_sin(self, positions):

         # At the end, we should be at the RotaryEmbedding layer.
         if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(positions)
+            current_module.prepare_cos_sin(
+                positions, recompute_cos_sin=self.recompute_cos_sin)
         else:
             raise AttributeError(
                 "The module at the end of the path does not have \
4 changes: 1 addition & 3 deletions vllm/worker/hpu_worker.py
@@ -78,9 +78,7 @@ def __init__(

         is_encoder_decoder_model = self._is_encoder_decoder_model()
         ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif is_encoder_decoder_model:
+        if is_encoder_decoder_model:
             ModelRunnerClass = HPUEncoderDecoderModelRunner
         self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
