Skip to content

Commit

Permalink
fix destructors flow and remove finish_measurements
Browse files Browse the repository at this point in the history
  • Loading branch information
nirda7 committed Dec 8, 2024
1 parent e0e47ed commit 029f9f9
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 41 deletions.
3 changes: 0 additions & 3 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,9 +1329,6 @@ def _advance_to_next_step(
else:
seq.append_token_id(sample.output_token, sample.logprobs)

def finish_measurements(self):
self.model_executor.finish_measurements()

def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
Expand Down
4 changes: 0 additions & 4 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,6 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

def finish_measurements(self):
assert not envs.VLLM_USE_V1, "INC does not support vLLM V1"
self.llm_engine.finish_measurements() # type: ignore[attr-defined]

@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
Expand Down
9 changes: 2 additions & 7 deletions vllm/executor/hpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
msg = f"init_cache_engine took {cache_init_m.get_summary_string()}"
logger.info(msg)

def finish_measurements(self):
self.driver_worker.finish_measurements()

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
Expand Down Expand Up @@ -200,10 +197,8 @@ def start_profile(self) -> None:
def stop_profile(self) -> None:
self.driver_worker.stop_profile()

def shutdown(self) -> None:
if hasattr(self, "driver_worker") and hasattr(self.driver_worker,
'shutdown_inc'):
self.driver_worker.shutdown_inc()
def shutdown_inc(self) -> None:
self.driver_worker.shutdown_inc()


class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
Expand Down
18 changes: 8 additions & 10 deletions vllm/executor/ray_hpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,22 @@ def _init_executor(self) -> None:
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])

self.terminate_ray = True

def shutdown(self) -> None:
for worker in self.workers:
worker.__ray_terminate__.remote()
if getattr(self, 'terminate_ray', False):
for worker in self.workers:
worker.__ray_terminate__.remote()
self.terminate_ray = False
if hasattr(self, "forward_dag") and self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
self.forward_dag = None

def finish_measurements(self):
self._run_workers("finish_measurements")
def shutdown_inc(self):
self._run_workers("shutdown_inc")

def _get_worker_module_and_class(
self
Expand Down Expand Up @@ -480,9 +484,6 @@ def _compiled_ray_dag(self, enable_asyncio: bool):

return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

def __del__(self):
self.shutdown()


class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync):

Expand Down Expand Up @@ -553,6 +554,3 @@ async def _start_worker_execution_loop(self):
for worker in self.non_driver_workers
]
return await asyncio.gather(*coros)

def __del__(self):
self.shutdown()
18 changes: 4 additions & 14 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,10 +1894,6 @@ def prepare_model_input(
is_prompt=is_prompt,
virtual_engine=virtual_engine)

def finish_measurements(self):
from neural_compressor.torch.quantization import finalize_calibration
finalize_calibration(self.model.model)

def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode):
cfg = (batch_size, seq_len, is_prompt)
seen = cfg in self.seen_configs
Expand Down Expand Up @@ -2265,18 +2261,12 @@ def _make_decode_output(
return SamplerOutput(sampler_outputs)

def shutdown_inc(self):
can_finalize_inc = False
from contextlib import suppress
with suppress(AttributeError):
can_finalize_inc = (self.model_config.quantization == 'inc') and \
(self.model.model is not None) and \
self.inc_initialized_successfully and \
not getattr(self, "_is_inc_finalized", False)
can_finalize_inc = (self.model_config.quantization == 'inc') and \
(self.model.model is not None) and \
self.inc_initialized_successfully and \
not getattr(self, "_is_inc_finalized", False)
if can_finalize_inc:
from neural_compressor.torch.quantization import (
finalize_calibration)
finalize_calibration(self.model.model)
self._is_inc_finalized = True

def __del__(self):
self.shutdown_inc()
3 changes: 0 additions & 3 deletions vllm/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,6 @@ def _warm_up_model(self) -> None:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)

def finish_measurements(self):
self.model_runner.finish_measurements()

@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
Expand Down

0 comments on commit 029f9f9

Please sign in to comment.