diff --git a/.jenkins/test_config.yaml b/.jenkins/test_config.yaml
index 3d8b2416506c7..0b9a2231d59a8 100644
--- a/.jenkins/test_config.yaml
+++ b/.jenkins/test_config.yaml
@@ -57,4 +57,4 @@ stages:
     command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_medusa_correctness.py::test_medusa_e2e_greedy_correctness
   - name: gsm8k_small_g2_tp1_eagle_spec_decode
     flavor: g2
-    command: TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
+    command: VLLM_COS_SIN_RECOMPUTE=true TORCH_COMPILE_DISABLE=true VLLM_CONTIGUOUS_PA=false VLLM_SKIP_WARMUP=True pytest -v tests/spec_decode/e2e/test_eagle_correctness.py::test_eagle_e2e_greedy_correctness
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index a601189788441..90a5f80cf5755 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -102,7 +102,9 @@ def __init__(
     def prepare_cos_sin(self,
                         positions: torch.Tensor,
-                        offsets: Optional[torch.Tensor] = None):
+                        offsets: Optional[torch.Tensor] = None,
+                        recompute_cos_sin: bool = False):
+        self.recompute_cos_sin = recompute_cos_sin
         if offsets is not None:
             offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
@@ -237,6 +239,8 @@ def forward_hpu(
         # forward, since the offset information wasn't available previously
         if hasattr(self, "scaling_factors") or self.sin is None:
             self.prepare_cos_sin(positions, offsets)
+        if self.recompute_cos_sin:
+            self.prepare_cos_sin(positions, offsets, recompute_cos_sin=True)
         num_tokens = positions.shape[0] * positions.shape[1]
         # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
         # to query hidden dimension, so the original tensors need to be
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 669124319c4f4..53c8a4b73b4e3 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1182,7 +1182,8 @@ def update(self,
               second_last_token_hidden_states: Optional[torch.Tensor] = None):
        """Update hidden states from target model invocation. Only used for
        decode steps"""
-        assert len(seq_group_metadata_list) == len(hidden_states)
+        if len(seq_group_metadata_list) < len(hidden_states):
+            hidden_states = hidden_states[:len(seq_group_metadata_list)]
         self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
         self.hidden_states = torch.cat([self.hidden_states, hidden_states])
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index d3090d313d155..b80463195ced0 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -209,6 +209,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '1').lower() in ['1', 'true'] \
                                                 and not is_fake_hpu()
+        self.recompute_cos_sin = os.getenv('VLLM_COS_SIN_RECOMPUTE',
+                                           'false').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
         self.layer_names = layer_names
@@ -370,7 +372,8 @@ def _prepare_cos_sin(self, positions):
         # At the end, we should be at the RotaryEmbedding layer.
         if hasattr(current_module, 'prepare_cos_sin'):
-            current_module.prepare_cos_sin(positions)
+            current_module.prepare_cos_sin(
+                positions, recompute_cos_sin=self.recompute_cos_sin)
         else:
             raise AttributeError(
                 "The module at the end of the path does not have \
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index cc88070fff56e..b73b8d5190b30 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -78,9 +78,7 @@ def __init__(
         is_encoder_decoder_model = self._is_encoder_decoder_model()
         ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
-        if model_runner_cls is not None:
-            ModelRunnerClass = model_runner_cls
-        elif is_encoder_decoder_model:
+        if is_encoder_decoder_model:
             ModelRunnerClass = HPUEncoderDecoderModelRunner
         self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
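For reference, a minimal self-contained sketch of the lazy-cache-plus-opt-in-recompute pattern the rotary_embedding.py hunks implement. CosSinCache and its fields are illustrative assumptions, not vLLM's actual RotaryEmbedding class; in the real change the flag is read from the VLLM_COS_SIN_RECOMPUTE environment variable in HPUModelRunner and threaded through _prepare_cos_sin, as the diff above shows.

import torch

class CosSinCache:
    # Hypothetical standalone stand-in for the cos/sin caching logic.
    def __init__(self, rotary_dim: int, base: float = 10000.0):
        self.rotary_dim = rotary_dim
        self.base = base
        self.cos = None
        self.sin = None
        self.recompute_cos_sin = False

    def prepare_cos_sin(self, positions: torch.Tensor,
                        recompute_cos_sin: bool = False):
        # Remember whether the caller wants the tables rebuilt each step,
        # mirroring what prepare_cos_sin does in the diff.
        self.recompute_cos_sin = recompute_cos_sin
        inv_freq = 1.0 / (self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) /
            self.rotary_dim))
        freqs = positions.to(torch.float32).unsqueeze(-1) * inv_freq
        self.cos, self.sin = freqs.cos(), freqs.sin()

    def forward(self, positions: torch.Tensor):
        if self.sin is None:           # first call: populate the cache
            self.prepare_cos_sin(positions)
        if self.recompute_cos_sin:     # opt-in: rebuild for current positions
            self.prepare_cos_sin(positions, recompute_cos_sin=True)
        return self.cos, self.sin

cache = CosSinCache(rotary_dim=64)
cos, sin = cache.forward(torch.arange(8))  # shapes: (8, 32)

The key design point, visible in both the sketch and forward_hpu, is that recomputation is off by default, so existing callers keep the cached tables; the .jenkins hunk opts the eagle spec-decode test into it by prefixing the pytest command with VLLM_COS_SIN_RECOMPUTE=true.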