diff --git a/vllm/config.py b/vllm/config.py
index eef1c2bfb9df9..e732c84c54520 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -372,9 +372,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config,
             self.use_async_output_proc = False
             return
 
-        if device_config.device_type not in ("cuda", "tpu"):
+        if device_config.device_type not in ("cuda", "tpu", "hpu"):
             logger.warning(
-                "Async output processing is only supported for CUDA or TPU. "
+                "Async output processing is only supported for CUDA, TPU "
+                "and HPU. "
                 "Disabling it for other platforms.")
             self.use_async_output_proc = False
             return
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index bfbe4085ddd3f..f3f679dbd1878 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -428,6 +428,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     virtual_engine: int = 0
     lora_mask: Optional[torch.Tensor] = None
     lora_logits_mask: Optional[torch.Tensor] = None
+    async_callback: Optional[Callable] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -1934,6 +1935,9 @@ def execute_model(
         if not self.is_driver_worker:
            return []
 
+        if model_input.async_callback is not None:
+            model_input.async_callback()
+
         # Sample the next token.
         with self.profiler.record_event(
                 'internal', ('sample_'
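
For context (not part of the patch): the habana_model_runner change follows the same callback pattern the config check refers to for CUDA and TPU, where the model input carries an optional callable that the driver worker invokes just before sampling, so the engine can process the previous step's outputs while the current step continues. Below is a minimal standalone sketch of that control flow; FakeModelInput, process_outputs_async, and the bare execute_model function are illustrative stand-ins, not vLLM API.

    from dataclasses import dataclass
    from typing import Callable, Optional


    @dataclass
    class FakeModelInput:
        # Mirrors the new async_callback field on ModelInputForHPU:
        # an optional hook the driver worker fires before sampling.
        async_callback: Optional[Callable] = None


    def process_outputs_async() -> None:
        # Stand-in for the engine-side processing of the previous
        # step's outputs.
        print("processing previous step's outputs")


    def execute_model(model_input: FakeModelInput) -> None:
        # Same control flow as the habana_model_runner hunk: invoke the
        # callback (if one was attached) before sampling the next token.
        if model_input.async_callback is not None:
            model_input.async_callback()
        print("sampling next token")


    execute_model(FakeModelInput(async_callback=process_outputs_async))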