[BUG FIX] [SPEC DECODE] 0.6.4 rebase causes incorrectness in spec decode; fixed in this PR (#523)

Noticed that spec decode produced incorrect results after the rebase to 0.6.4.

Identified the root causes and fixed them in this PR:
1. Incorrect return value position in batch_expansion.py (see the sketch below).
2. Contiguous PA generates faulty results in spec decode.
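
To make root cause 1 concrete, here is a tiny hypothetical illustration (not the vLLM code itself) of how a mismatch between a function's return order and the caller's unpacking order corrupts results without raising any error. The actual fix is the HPU-specific `_split_scoring_output_hpu` in the diff below, whose return order matches what the HPU contraction path unpacks.

```python
# Hypothetical illustration of root cause 1: positional returns that do not
# match the caller's unpacking order fail silently.
def split_output():
    token_ids = [1, 2, 3]
    probs = [0.5, 0.3, 0.2]
    # Returning (probs, token_ids) while the caller expects (token_ids, probs)
    # raises no error -- downstream code just consumes swapped values.
    return probs, token_ids


token_ids, probs = split_output()
print(token_ids)  # [0.5, 0.3, 0.2]  <- probabilities where token ids were expected
```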

CI added: #524

---------

Signed-off-by: Chendi.Xue <[email protected]>
xuechendi authored Nov 28, 2024
1 parent 0c62b0b commit 756485f
Showing 6 changed files with 125 additions and 25 deletions.
7 changes: 6 additions & 1 deletion .jenkins/test_config.yaml
@@ -43,4 +43,9 @@ stages:
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
- name: gsm8k_small_g2_tp2_mss
flavor: g2.s
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 2
- name: test_gsm8k_spec_decode
steps:
- name: gsm8k_small_g2_tp1_spec_decode
flavor: g2
command: cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-mss.txt -t 1
7 changes: 4 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -346,9 +346,10 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
warning_msg = ("Following weights were not initialized "
               f"from checkpoint: {weights_not_loaded}")

logger.warning(warning_msg)

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
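
For context, here is a minimal, self-contained sketch (illustrative names, standard logging instead of vLLM's logger) of the relaxed check above: the set difference of expected vs. actually-loaded weight names is now reported as a warning rather than raised as an error.

```python
import logging

logger = logging.getLogger("loader_sketch")  # stand-in for vLLM's logger

# Illustrative weight-name sets; the real loader derives these from the
# model definition and the checkpoint.
weights_to_load = {"model.embed_tokens.weight", "lm_head.weight"}
loaded_weights = {"model.embed_tokens.weight"}

weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
    logger.warning("Following weights were not initialized from checkpoint: %s",
                   weights_not_loaded)
```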
95 changes: 78 additions & 17 deletions vllm/spec_decode/batch_expansion.py
@@ -6,6 +6,7 @@

from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids)
@@ -159,11 +160,18 @@ def _contract_batch(
target_sampler_output will be contracted to.
"""
contracted_bs = len(contracted_seq_group_metadata_list)
(target_token_ids, target_probs, target_logprobs, target_hidden_states,
non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
if current_platform.is_hpu():
(target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output_hpu(
target_sampler_output, num_scoring_tokens)
else:
(target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)

# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
@@ -239,18 +247,30 @@ def _contract_batch_all_spec(
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape

(
target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_,
_,
_,
_,
) = self._split_scoring_output(target_sampler_output,
num_scoring_tokens)
if current_platform.is_hpu():
(
target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_,
_,
_,
_,
) = self._split_scoring_output_hpu(target_sampler_output,
num_scoring_tokens)
else:
(
target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_,
_,
_,
_,
) = self._split_scoring_output(target_sampler_output,
num_scoring_tokens)

# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
@@ -397,6 +417,47 @@ def _create_single_target_seq_group_metadata(
token_chunk_size=1,
)

@staticmethod
def _split_scoring_output_hpu(
sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
"""Split the target model output into speculative and non-speculative
output.
"""

# vLLM currently only supports proposal lens equal to zero or the batch
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes = (num_scoring_tokens,
sampler_output.sampled_token_ids.numel() -
num_scoring_tokens)
(spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
(
spec_logprobs,
non_spec_logprobs,
) = sampler_output.logprobs.split(split_sizes)

if sampler_output.hidden_states is not None:
(
spec_hidden_states,
non_spec_hidden_states,
) = sampler_output.hidden_states.split(split_sizes)
else:
spec_hidden_states, non_spec_hidden_states = None, None

return (spec_sampled_tokens, spec_probs, spec_logprobs,
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
non_spec_logprobs, non_spec_hidden_states)

@staticmethod
def _split_scoring_output(
sampler_output: SamplerOutput, num_scoring_tokens: int
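
As a standalone sketch (assumed shapes and sizes, not the vLLM objects themselves), this is how the `split_sizes` logic in `_split_scoring_output_hpu` partitions the flattened sampler output: the first `num_scoring_tokens` entries come from speculative scoring and the remainder are non-speculative.

```python
import torch

num_scoring_tokens = 6                       # e.g. 2 sequences * (k + 1) scored tokens
sampled_token_ids = torch.arange(8)          # 8 sampled tokens in total
sampled_token_probs = torch.rand(8, 32000)   # [num_tokens, vocab_size], illustrative

split_sizes = (num_scoring_tokens,
               sampled_token_ids.numel() - num_scoring_tokens)

spec_tokens, non_spec_tokens = sampled_token_ids.flatten().split(split_sizes)
spec_probs, non_spec_probs = sampled_token_probs.split(split_sizes)

print(spec_tokens.shape, non_spec_tokens.shape)  # torch.Size([6]) torch.Size([2])
print(spec_probs.shape, non_spec_probs.shape)    # torch.Size([6, 32000]) torch.Size([2, 32000])
```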
16 changes: 13 additions & 3 deletions vllm/spec_decode/hpu_draft_model_runner.py
@@ -50,9 +50,19 @@ def execute_model(
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if previous_hidden_states is not None:
_, block_size = model_input.input_tokens.shape
previous_hidden_states = previous_hidden_states.expand(
block_size, -1).unsqueeze(0)
batch_size, block_size = model_input.input_tokens.shape
previous_hidden_states = previous_hidden_states.unsqueeze(
dim=1).expand(-1, block_size, -1)
# because HPU will pad batch_size,
# we need to pad previous_hidden_states as well
batch_size_padding = batch_size - previous_hidden_states.shape[0]
if batch_size_padding > 0:
dummy_previous_hidden_states = torch.zeros_like(
previous_hidden_states[1:2]).expand(
batch_size_padding, -1, -1)
previous_hidden_states = torch.cat(
[previous_hidden_states, dummy_previous_hidden_states],
dim=0)
return super().execute_model(
model_input=model_input,
kv_caches=kv_caches,
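
The shape handling added above can be hard to read in diff form; the following minimal sketch (made-up sizes) shows the same two steps: broadcast one hidden state per block position, then zero-pad the batch dimension to the HPU-padded batch size.

```python
import torch

padded_batch_size, block_size, hidden_size = 4, 16, 64   # illustrative sizes
real_batch_size = 3

previous_hidden_states = torch.randn(real_batch_size, hidden_size)

# [bs, hidden] -> [bs, block_size, hidden]: one hidden state per block position.
previous_hidden_states = previous_hidden_states.unsqueeze(1).expand(-1, block_size, -1)

# HPU pads input_tokens to a bucketed batch size, so pad hidden states to match.
batch_size_padding = padded_batch_size - previous_hidden_states.shape[0]
if batch_size_padding > 0:
    dummy = torch.zeros_like(previous_hidden_states[:1]).expand(batch_size_padding, -1, -1)
    previous_hidden_states = torch.cat([previous_hidden_states, dummy], dim=0)

assert previous_hidden_states.shape == (padded_batch_size, block_size, hidden_size)
```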
20 changes: 19 additions & 1 deletion vllm/spec_decode/ngram_worker.py
@@ -4,11 +4,29 @@
import torch

from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer

if current_platform.is_cuda_alike():
DEVICE_TYPE = "cuda"
elif current_platform.is_neuron():
DEVICE_TYPE = "neuron"
elif current_platform.is_hpu():
DEVICE_TYPE = "hpu"
elif current_platform.is_openvino():
DEVICE_TYPE = "openvino"
elif current_platform.is_cpu():
DEVICE_TYPE = "cpu"
elif current_platform.is_tpu():
DEVICE_TYPE = "tpu"
elif current_platform.is_xpu():
DEVICE_TYPE = "xpu"
else:
raise ValueError(f"Unsupported platform: {current_platform}")


class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
@@ -34,7 +52,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min

def init_device(self):
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(f"{DEVICE_TYPE}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None

# Current NGramWorker only supports Top1Proposer
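
The device selection above boils down to resolving a device-type string once per platform and then building per-rank devices from it. A reduced sketch of the same pattern (only a couple of the platform branches, with a simplified fallback) looks like this:

```python
import torch

from vllm.platforms import current_platform

if current_platform.is_hpu():
    DEVICE_TYPE = "hpu"
elif current_platform.is_cuda_alike():
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "cpu"  # simplified fallback for this sketch only

local_rank = 0
device = torch.device(f"{DEVICE_TYPE}:{local_rank}")
print(device)  # e.g. hpu:0 on Gaudi, cuda:0 on NVIDIA GPUs
```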
5 changes: 5 additions & 0 deletions vllm/worker/hpu_model_runner.py
@@ -538,6 +538,11 @@ def __init__(
self._set_gc_threshold()
self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if vllm_config.speculative_config is not None \
and self.use_contiguous_pa:
raise ValueError(
"Speculative decoding is not supported with "
"contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []

Expand Down
