
Commit

i messed up and now i've fixed it
kzawora-intel committed Sep 25, 2024
1 parent 0bd8366 commit de6833c
Showing 1 changed file with 8 additions and 28 deletions.
36 changes: 8 additions & 28 deletions vllm/worker/habana_model_runner.py
@@ -1334,8 +1334,8 @@ def warmup_scenario(self,
                         seq_len,
                         is_prompt,
                         kv_caches,
-                        is_profile_run=False,
-                        override_n_runs=None) -> None:
+                        is_pt_profiler_run=False,
+                        is_lora_profile_run=False) -> None:
         use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
         scenario_name = ("warmup_"
                          f"{'prompt' if is_prompt else 'decode'}_"
@@ -1367,10 +1367,8 @@ def warmup_scenario(self,
                 for idx in range(max_num_seqs)
             ]
         self.profiler.start('internal', scenario_name)
-        times = 3 if use_graphs or is_profile_run else 1
-        if override_n_runs is not None:
-            times = override_n_runs
-        if self.lora_config and not is_profile_run:
+        times = 3 if use_graphs or is_pt_profiler_run else 1
+        if self.lora_config and not is_lora_profile_run:
             lora_mapping = LoRAMapping(
                 [0] * batch_size * seq_len,
                 [0] * batch_size * seq_len,
@@ -1401,27 +1399,19 @@
             ]
         torch.hpu.synchronize()
         profiler = None
-        fwd_times = []
-        if is_profile_run and self.is_driver_worker:
+        if is_pt_profiler_run and self.is_driver_worker:
             profiler = setup_profiler()
             profiler.start()
         for _ in range(times):
             torch.hpu.synchronize()
-            start = time.perf_counter()
             inputs = self.prepare_model_input(seqs)
-            self.execute_model(inputs, kv_caches, warmup_mode=False)
+            self.execute_model(inputs, kv_caches, warmup_mode=True)
             torch.hpu.synchronize()
-            end = time.perf_counter()
-            elapsed = end - start
-            fwd_times.append(elapsed)
-            print(f'[{batch_size}x{seq_len}x{use_graphs}] tput: {batch_size/elapsed:.3f} tps, time: {elapsed*1000:.3f} ms')
             if profiler:
                 profiler.step()
         if profiler:
             profiler.stop()
         self.profiler.end()
         gc.collect()
-        return fwd_times, use_graphs

     def remove_all_loras(self):
         if not self.lora_manager:
@@ -1466,13 +1456,11 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
               f"free_mem:{free_mem}")
         logger.info(msg)

-    def warmup_all_buckets(self, buckets, is_prompt, kv_caches, override_n_runs=None):
-        bucket_times = {}
+    def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
             self.log_warmup('Prompt' if is_prompt else 'Decode', i,
                             len(buckets), batch_size, seq_len)
-            bucket_times[(batch_size, seq_len)] = self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches, override_n_runs=override_n_runs)
-        return bucket_times
+            self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

     def warmup_graphs(self,
                       strategy,
@@ -1676,14 +1664,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         logger.info(msg)
         self.profiler.end()

-        if os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS', 'false').lower() == 'true':
-            from vllm.hpu.utils import process_run_characteristics
-            n_runs = int(os.environ.get('VLLM_PROFILE_SERVER_CHARACTERISTICS_N', '5'))
-            decode_times = self.warmup_all_buckets(self.decode_buckets, False, kv_caches, override_n_runs=n_runs)
-            process_run_characteristics(decode_times, block_size=self.cache_config.block_size, prefill=False)
-            prefill_times = self.warmup_all_buckets(self.prompt_buckets, True, kv_caches, override_n_runs=n_runs)
-            process_run_characteristics(prefill_times, block_size=self.cache_config.block_size, prefill=True)
-
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
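For readers skimming the diff: a minimal sketch of how callers would use the reworked warmup_scenario() signature after this commit. The keyword arguments come from the diff above; the surrounding call sites are illustrative assumptions, not code from the repository.

    # Illustrative only: possible call patterns for the new flags.

    # Plain warmup run: `times` falls back to 3 when HPU graphs are used,
    # otherwise 1, and no PyTorch profiler is set up.
    self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)

    # PyTorch profiler run: setup_profiler() is invoked on the driver
    # worker, `times` is 3, and profiler.step() fires after each forward.
    self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches,
                         is_pt_profiler_run=True)

    # LoRA profile run: the dummy LoRAMapping setup is skipped, since
    # `self.lora_config and not is_lora_profile_run` evaluates to False.
    self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches,
                         is_lora_profile_run=True)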
