From de6833cc79a84ae9ef4cc5a9397ce11fb0d27f1c Mon Sep 17 00:00:00 2001 From: Konrad Zawora Date: Wed, 25 Sep 2024 13:26:09 +0300 Subject: [PATCH] i messed up and now i've fixed it --- vllm/worker/habana_model_runner.py | 36 +++++++----------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index 9c88d6652a4cd..c038dfe42bf5f 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -1334,8 +1334,8 @@ def warmup_scenario(self, seq_len, is_prompt, kv_caches, - is_profile_run=False, - override_n_runs=None) -> None: + is_pt_profiler_run=False, + is_lora_profile_run=False) -> None: use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) scenario_name = ("warmup_" f"{'prompt' if is_prompt else 'decode'}_" @@ -1367,10 +1367,8 @@ def warmup_scenario(self, for idx in range(max_num_seqs) ] self.profiler.start('internal', scenario_name) - times = 3 if use_graphs or is_profile_run else 1 - if override_n_runs is not None: - times = override_n_runs - if self.lora_config and not is_profile_run: + times = 3 if use_graphs or is_pt_profiler_run else 1 + if self.lora_config and not is_lora_profile_run: lora_mapping = LoRAMapping( [0] * batch_size * seq_len, [0] * batch_size * seq_len, @@ -1401,27 +1399,19 @@ def warmup_scenario(self, ] torch.hpu.synchronize() profiler = None - fwd_times = [] - if is_profile_run and self.is_driver_worker: + if is_pt_profiler_run and self.is_driver_worker: profiler = setup_profiler() profiler.start() for _ in range(times): - torch.hpu.synchronize() - start = time.perf_counter() inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches, warmup_mode=False) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() - end = time.perf_counter() - elapsed = end - start - fwd_times.append(elapsed) - print(f'[{batch_size}x{seq_len}x{use_graphs}] tput: {batch_size/elapsed:.3f} tps, time: {elapsed*1000:.3f} ms') if profiler: profiler.step() if profiler: profiler.stop() self.profiler.end() gc.collect() - return fwd_times, use_graphs def remove_all_loras(self): if not self.lora_manager: @@ -1466,13 +1456,11 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len): f"free_mem:{free_mem}") logger.info(msg) - def warmup_all_buckets(self, buckets, is_prompt, kv_caches, override_n_runs=None): - bucket_times = {} + def warmup_all_buckets(self, buckets, is_prompt, kv_caches): for i, (batch_size, seq_len) in enumerate(reversed(buckets)): self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len) - bucket_times[(batch_size, seq_len)] = self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches, override_n_runs=override_n_runs) - return bucket_times + self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) def warmup_graphs(self, strategy, @@ -1676,14 +1664,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: logger.info(msg) self.profiler.end() - if os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS', 'false').lower() == 'true': - from vllm.hpu.utils import process_run_characteristics - n_runs = int(os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS_N', '5')) - decode_times = self.warmup_all_buckets(self.decode_buckets, False, kv_caches, override_n_runs=n_runs) - process_run_characteristics(decode_times, block_size=self.cache_config.block_size, prefill=False) - prefill_times = self.warmup_all_buckets(self.prompt_buckets, True, kv_caches, override_n_runs=n_runs) - process_run_characteristics(prefill_times, block_size=self.cache_config.block_size, prefill=True) - @property def vocab_size(self) -> int: return self.model_config.get_vocab_size()