From f481707e55756c53bd87a7b1ea8f90b052652efc Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Thu, 21 Nov 2024 11:33:36 +0100
Subject: [PATCH] [bucketing overhaul 2/n] Delegate bucket management to HPUBucketingContext (#530)

Co-authored-by: Konrad Zawora
---
 vllm/worker/hpu_model_runner.py | 281 ++++++++++++++++++--------------
 1 file changed, 162 insertions(+), 119 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 03c5e62c8f11e..b064da65db67d 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -87,23 +87,114 @@ class HPUBucketingGlobalState(metaclass=Singleton):
     decode_buckets: List[Tuple[int, int]] = field(init=False)


-def subtuple(obj: object,
-             typename: str,
-             to_copy: List[str],
-             to_override: Optional[Dict[str, object]] = None):
-    if obj is None:
-        return None
-    if to_override is None:
-        to_override = {}
-    fields = set(to_copy) | set(to_override.keys())
-    if type(obj) is dict:
-        values = {key: obj[key] for key in fields if key in obj}
-    else:
-        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
-    if typename not in _TYPE_CACHE:
-        _TYPE_CACHE[typename] = collections.namedtuple(typename,
-                                                       ' '.join(fields))
-    return _TYPE_CACHE[typename](**values)
+class HPUBucketingContext(metaclass=Singleton):
+    global_state = HPUBucketingGlobalState()
+
+    def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size,
+                 max_num_batched_tokens):
+        self.max_num_seqs = max_num_seqs
+        self.max_num_prefill_seqs = max_num_prefill_seqs
+        self.block_size = block_size
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self._setup_buckets()
+
+    def _setup_buckets(self) -> None:
+        align_bs = lambda x: min(self.max_num_seqs, x)
+        #FIXME: The default values should be max_model_len
+        max_prompt_seq = 1024
+        max_decode_seq = 2048
+        self.global_state.prompt_bs_bucket_cfg = read_bucket_settings(
+            'prompt',
+            'bs',
+            min=1,
+            step=align_bs(32),
+            max=self.max_num_prefill_seqs)
+        self.global_state.decode_bs_bucket_cfg = read_bucket_settings(
+            'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
+        self.global_state.prompt_seq_bucket_cfg = \
+            read_bucket_settings(
+                'prompt',
+                'seq',
+                min=self.block_size,
+                step=self.block_size,
+                max=max_prompt_seq)
+        self.global_state.decode_block_bucket_cfg = \
+            read_bucket_settings(
+                'decode',
+                'block',
+                min=self.block_size,
+                step=self.block_size,
+                max=max(self.block_size,
+                        self.max_num_seqs * max_decode_seq // self.block_size))
+
+        msg = ("Prompt bucket config (min, step, max_warmup) "
+               f"bs:{self.global_state.prompt_bs_bucket_cfg}, "
+               f"seq:{self.global_state.prompt_seq_bucket_cfg}")
+        logger.info(msg)
+
+        msg = ("Decode bucket config (min, step, max_warmup) "
+               f"bs:{self.global_state.decode_bs_bucket_cfg}, "
+               f"block:{self.global_state.decode_block_bucket_cfg}")
+        logger.info(msg)
+
+    def generate_prompt_buckets(self):
+        self.global_state.prompt_buckets, prompt_omitted_buckets = \
+            generate_prompt_buckets(
+                self.global_state.prompt_bs_bucket_cfg,
+                self.global_state.prompt_seq_bucket_cfg,
+                self.max_num_batched_tokens)
+
+        msg = (f"Generated {len(self.global_state.prompt_buckets)} "
+               f"prompt buckets [bs, seq]: \
+               {list(sorted(self.global_state.prompt_buckets))}")
+        logger.info(msg)
+
+        msg = (f"Omitted {len(prompt_omitted_buckets)} "
+               "prompt buckets due to exceeded token budget "
+               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
+        logger.info(msg)
+
+        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
+        logger.debug(msg)
+
+    def generate_decode_buckets(self, max_blocks):
+        self.global_state.decode_buckets = generate_decode_buckets(
+            self.global_state.decode_bs_bucket_cfg,
+            self.global_state.decode_block_bucket_cfg, max_blocks)
+        logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
+                    len(self.global_state.decode_buckets),
+                    list(sorted(self.global_state.decode_buckets)))
+
+    def get_padded_prompt_batch_size(self, batch_size):
+        return find_bucket(batch_size, self.global_state.prompt_bs_bucket_cfg)
+
+    def get_padded_decode_batch_size(self, batch_size):
+        return find_bucket(batch_size, self.global_state.decode_bs_bucket_cfg)
+
+    def get_padded_prompt_seq_len(self, seq_len):
+        return find_bucket(seq_len, self.global_state.prompt_seq_bucket_cfg)
+
+    def get_padded_decode_num_blocks(self, num_blocks):
+        return find_bucket(num_blocks,
+                           self.global_state.decode_block_bucket_cfg)
+
+    def get_padded_batch_size(self, batch_size, is_prompt):
+        if is_prompt:
+            return self.get_padded_prompt_batch_size(batch_size)
+        return self.get_padded_decode_batch_size(batch_size)
+
+    def get_padded_seq_or_block(self, seq_or_block, is_prompt):
+        if is_prompt:
+            return self.get_padded_prompt_seq_len(seq_or_block)
+        return self.get_padded_decode_num_blocks(seq_or_block)
+
+    @property
+    def prompt_buckets(self):
+        return self.global_state.prompt_buckets
+
+    @property
+    def decode_buckets(self):
+        return self.global_state.decode_buckets


 def read_bucket_settings(phase: str, dim: str, **defaults):
@@ -233,6 +324,25 @@ def find_bucket(value: int, config: Tuple[int, int, int]):
     return max(bmin, min(next_step, next_pow))


+def subtuple(obj: object,
+             typename: str,
+             to_copy: List[str],
+             to_override: Optional[Dict[str, object]] = None):
+    if obj is None:
+        return None
+    if to_override is None:
+        to_override = {}
+    fields = set(to_copy) | set(to_override.keys())
+    if type(obj) is dict:
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if typename not in _TYPE_CACHE:
+        _TYPE_CACHE[typename] = collections.namedtuple(typename,
+                                                       ' '.join(fields))
+    return _TYPE_CACHE[typename](**values)
+
+
 def align_workers(value, op):
     group = get_world_group().cpu_group
     world_size = torch.distributed.get_world_size()
@@ -655,8 +765,12 @@ def __init__(
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
-        self.bucketing_global_state = HPUBucketingGlobalState()
-        self._setup_buckets()
+        self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
+                                                 self.max_num_prefill_seqs,
+                                                 self.block_size,
+                                                 self.max_num_batched_tokens)
+        self.graphed_buckets: Set[Any] = set()
+        self._set_gc_threshold()
         self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
                                                 'true').lower() == 'true'
@@ -782,46 +896,6 @@ def _use_graphs(self, batch_size, seq_len, is_prompt):
     def _is_valid_bucket(self, bucket):
         return bucket[0] * bucket[1] <= self.max_num_batched_tokens

-    def _setup_buckets(self) -> None:
-        align_bs = lambda x: min(self.max_num_seqs, x)
-        #FIXME: The default values should be max_model_len
-        max_prompt_seq = 1024
-        max_decode_seq = 2048
-        self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings(
-            'prompt',
-            'bs',
-            min=1,
-            step=align_bs(32),
-            max=self.max_num_prefill_seqs)
-        self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings(
-            'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
-        self.bucketing_global_state.prompt_seq_bucket_cfg = \
-            read_bucket_settings(
-                'prompt',
-                'seq',
-                min=self.block_size,
-                step=self.block_size,
-                max=max_prompt_seq)
-        self.bucketing_global_state.decode_block_bucket_cfg = \
-            read_bucket_settings(
-                'decode',
-                'block',
-                min=self.block_size,
-                step=self.block_size,
-                max=max(self.block_size,
-                        self.max_num_seqs * max_decode_seq // self.block_size))
-        self.graphed_buckets: Set[Any] = set()
-
-        msg = ("Prompt bucket config (min, step, max_warmup) "
-               f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, "
-               f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}")
-        logger.info(msg)
-
-        msg = ("Decode bucket config (min, step, max_warmup) "
-               f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, "
-               f"block:{self.bucketing_global_state.decode_block_bucket_cfg}")
-        logger.info(msg)
-
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -937,8 +1011,7 @@ def _prepare_prompt(
         assert max_query_len > 0

         max_prompt_len = max(
-            find_bucket(max_query_len,
-                        self.bucketing_global_state.prompt_seq_bucket_cfg),
+            self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)),
             self.block_size)

         lora_ids: List[int] = []
@@ -1151,9 +1224,8 @@ def _prepare_decode(
         padding_fn = None
         if self.use_contiguous_pa:
             block_bucket_size = max(max(block_list) + 1, len(block_list))
-            block_bucket_size = find_bucket(
-                block_bucket_size,
-                self.bucketing_global_state.decode_block_bucket_cfg)
+            block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
+                block_bucket_size)
             indices: List[Any]
             indices = [None] * block_bucket_size
             for i, bid in enumerate(block_list):
@@ -1161,9 +1233,8 @@ def _prepare_decode(
             padding_fn = lambda tensor, pad_value: gather_list(
                 tensor, indices, pad_value)
         else:
-            block_bucket_size = find_bucket(
-                len(block_list),
-                self.bucketing_global_state.decode_block_bucket_cfg)
+            block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
+                len(block_list))
             padding_fn = lambda tensor, pad_value: pad_list(
                 tensor, block_bucket_size, pad_value)

@@ -1245,9 +1316,8 @@ def prepare_input_tensors(
         self.profiler.start('internal', base_event_name)

         real_batch_size = len(seq_group_metadata_list)
-        bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \
-            if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg
-        batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
         batch_size_padding = batch_size_padded - real_batch_size
         seq_group_metadata_list = seq_group_metadata_list.copy()
         if batch_size_padding > 0:
@@ -1453,9 +1523,11 @@ def create_dummy_seq_group_metadata(self,
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
-        max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
-                             self.scheduler_config.max_num_seqs)
+        max_batch_size = self.bucketing_ctx.global_state.prompt_bs_bucket_cfg[
+            -1]
+        max_seq_len = min(
+            self.bucketing_ctx.global_state.prompt_seq_bucket_cfg[-1],
+            self.max_num_batched_tokens // max_batch_size)

         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
@@ -1594,7 +1666,7 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
         free_mem = format_bytes(
             HabanaMemoryProfiler.current_free_device_memory())
         dim = "num_blocks"
-        if phase == "Prompt":
+        if "Prompt" in phase:
             dim = "seq_len"
         msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
                f"batch_size:{batch_size} "
@@ -1689,37 +1761,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             return
         self.profiler.start('internal', 'warmup')
         max_blocks = kv_caches[0][0].size(0)
-
-        self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
-            generate_prompt_buckets(
-            self.bucketing_global_state.prompt_bs_bucket_cfg,
-            self.bucketing_global_state.prompt_seq_bucket_cfg,
-            self.max_num_batched_tokens)
-
-        msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
-               f"prompt buckets [bs, seq]: \
-               {list(sorted(self.bucketing_global_state.prompt_buckets))}")
-        logger.info(msg)
-
-        msg = (f"Omitted {len(prompt_omitted_buckets)} "
-               "prompt buckets due to exceeded token budget "
-               f"(max_num_batched_tokens={self.max_num_batched_tokens})")
-        logger.info(msg)
-
-        msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
-        logger.debug(msg)
-
-        self.bucketing_global_state.decode_buckets = generate_decode_buckets(
-            self.bucketing_global_state.decode_bs_bucket_cfg,
-            self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
-        logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
-                    len(self.bucketing_global_state.decode_buckets),
-                    list(sorted(self.bucketing_global_state.decode_buckets)))
+        self.bucketing_ctx.generate_prompt_buckets()
+        self.bucketing_ctx.generate_decode_buckets(max_blocks)

         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
-            cache_size_limit = len(
-                self.bucketing_global_state.prompt_buckets) + len(
-                    self.bucketing_global_state.decode_buckets) + 1
+            cache_size_limit = len(self.bucketing_ctx.prompt_buckets) + len(
+                self.bucketing_ctx.decode_buckets) + 1
             torch._dynamo.config.cache_size_limit = max(
                 cache_size_limit, torch._dynamo.config.cache_size_limit)
             # Multiply by 8 to follow the original default ratio between
@@ -1746,10 +1793,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                 'Please update Gaudi Software Suite.')
         with compile_only_mode_context(
         ) if can_use_compile_only_mode else contextlib.nullcontext():
-            self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
-                                    True, kv_caches)
-            self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
-                                    False, kv_caches)
+            self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
+                                    kv_caches)
+            self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
+                                    kv_caches)

         if not self.enforce_eager and htorch.utils.internal.is_lazy():
             assert self.mem_margin is not None, \
@@ -1779,11 +1826,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                                                     'max_bs')
             mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
                 self.warmup_graphs(
-                prompt_strategy, self.bucketing_global_state.prompt_buckets,
+                prompt_strategy, self.bucketing_ctx.prompt_buckets,
                 True, kv_caches, prompt_available_memory)
             mem_post_decode, decode_batch_seq, decode_captured_all = \
                 self.warmup_graphs(
-                decode_strategy, self.bucketing_global_state.decode_buckets,
+                decode_strategy, self.bucketing_ctx.decode_buckets,
                 False, kv_caches, decode_available_memory)

             # Not all prompt buckets were captured, but all decode buckets
@@ -1793,9 +1840,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                     and not prompt_captured_all and decode_captured_all):
                 mem_post_prompt, _, prompt_captured_all = (
                     self.warmup_graphs(
-                        prompt_strategy,
-                        self.bucketing_global_state.prompt_buckets, True,
-                        kv_caches,
+                        prompt_strategy, self.bucketing_ctx.prompt_buckets,
+                        True, kv_caches,
                         graph_free_mem - mem_post_prompt - mem_post_decode,
                         mem_post_prompt, prompt_batch_seq))

@@ -1806,18 +1852,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                     and not decode_captured_all \
                     and prompt_captured_all:
                 mem_post_decode, _, _ = self.warmup_graphs(
-                    decode_strategy,
-                    self.bucketing_global_state.decode_buckets, False,
-                    kv_caches,
+                    decode_strategy, self.bucketing_ctx.decode_buckets,
+                    False, kv_caches,
                     graph_free_mem - mem_post_prompt - mem_post_decode,
                     mem_post_decode, decode_batch_seq)

             self.log_graph_warmup_summary(
-                self.bucketing_global_state.prompt_buckets, True,
-                mem_post_prompt)
+                self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
             self.log_graph_warmup_summary(
-                self.bucketing_global_state.decode_buckets, False,
-                mem_post_decode)
+                self.bucketing_ctx.decode_buckets, False, mem_post_decode)

         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
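Illustrative sketch, not part of the patch: after this change the runner no longer calls find_bucket() on module-level bucket configs; it asks HPUBucketingContext for padded sizes via get_padded_batch_size(), get_padded_prompt_seq_len() and get_padded_decode_num_blocks(). The stub below only mimics that interface. The (min, step, max) values and the simple round-up-to-step rule are assumptions standing in for the real read_bucket_settings()/find_bucket() behaviour (the real find_bucket also considers a power-of-two bound, as the context line "return max(bmin, min(next_step, next_pow))" above suggests).

    # Toy stand-in for the HPUBucketingContext padding API; not vLLM code.
    from typing import Tuple

    def _round_up_to_bucket(value: int, cfg: Tuple[int, int, int]) -> int:
        # Round value up to the next multiple of step, clamped to [min, max].
        # Simplified assumption; the real helper is vLLM's find_bucket().
        bmin, bstep, bmax = cfg
        return min(bmax, max(bmin, -(-value // bstep) * bstep))

    class StubBucketingContext:
        # Assumed (min, step, max) configs, chosen only for illustration.
        prompt_bs_bucket_cfg = (1, 32, 64)
        decode_bs_bucket_cfg = (1, 32, 256)
        prompt_seq_bucket_cfg = (128, 128, 1024)

        def get_padded_batch_size(self, batch_size: int, is_prompt: bool) -> int:
            cfg = (self.prompt_bs_bucket_cfg
                   if is_prompt else self.decode_bs_bucket_cfg)
            return _round_up_to_bucket(batch_size, cfg)

        def get_padded_prompt_seq_len(self, seq_len: int) -> int:
            return _round_up_to_bucket(seq_len, self.prompt_seq_bucket_cfg)

    ctx = StubBucketingContext()
    print(ctx.get_padded_batch_size(3, is_prompt=True))   # -> 32
    print(ctx.get_padded_prompt_seq_len(300))             # -> 384

The design point of the patch is visible in this shape: call sites such as prepare_input_tensors() and _prepare_decode() pass raw sizes and a prompt/decode flag, while the context owns the bucket configs and the padding rule, so bucket management lives in one place instead of being spread across the model runner.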