[bucketing overhaul 2/n] Delegate bucket management to HPUBucketingContext (#530)

Co-authored-by: Konrad Zawora <[email protected]>
kdamaszk and kzawora-intel authored Nov 21, 2024
1 parent efe0268 commit f481707
Showing 1 changed file with 162 additions and 119 deletions.
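
For orientation, here is a minimal usage sketch of the new HPUBucketingContext introduced by this commit. The constructor signature and method names are taken from the diff below; the concrete values (batch sizes, block counts) are illustrative assumptions, not project defaults.

    # Minimal sketch, assuming illustrative configuration values.
    ctx = HPUBucketingContext(max_num_seqs=256,
                              max_num_prefill_seqs=16,
                              block_size=128,
                              max_num_batched_tokens=8192)

    # Warmup: generate prompt/decode buckets once (replaces the inline
    # generate_prompt_buckets/generate_decode_buckets calls in warmup_model).
    ctx.generate_prompt_buckets()
    ctx.generate_decode_buckets(max_blocks=4096)  # max_blocks comes from the KV cache

    # Runtime: pad batch size and sequence length / block count to a bucket
    # (replaces the direct find_bucket calls on the global bucket configs).
    padded_bs = ctx.get_padded_batch_size(3, is_prompt=True)
    padded_seq = ctx.get_padded_seq_or_block(100, is_prompt=True)
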
281 changes: 162 additions & 119 deletions vllm/worker/hpu_model_runner.py
@@ -87,23 +87,114 @@ class HPUBucketingGlobalState(metaclass=Singleton):
decode_buckets: List[Tuple[int, int]] = field(init=False)


def subtuple(obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields))
return _TYPE_CACHE[typename](**values)
class HPUBucketingContext(metaclass=Singleton):
global_state = HPUBucketingGlobalState()

def __init__(self, max_num_seqs, max_num_prefill_seqs, block_size,
max_num_batched_tokens):
self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.block_size = block_size
self.max_num_batched_tokens = max_num_batched_tokens
self._setup_buckets()

def _setup_buckets(self) -> None:
align_bs = lambda x: min(self.max_num_seqs, x)
#FIXME: The default values should be max_model_len
max_prompt_seq = 1024
max_decode_seq = 2048
self.global_state.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt',
'bs',
min=1,
step=align_bs(32),
max=self.max_num_prefill_seqs)
self.global_state.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
self.global_state.prompt_seq_bucket_cfg = \
read_bucket_settings(
'prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=max_prompt_seq)
self.global_state.decode_block_bucket_cfg = \
read_bucket_settings(
'decode',
'block',
min=self.block_size,
step=self.block_size,
max=max(self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size))

msg = ("Prompt bucket config (min, step, max_warmup) "
f"bs:{self.global_state.prompt_bs_bucket_cfg}, "
f"seq:{self.global_state.prompt_seq_bucket_cfg}")
logger.info(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.global_state.decode_bs_bucket_cfg}, "
f"block:{self.global_state.decode_block_bucket_cfg}")
logger.info(msg)

def generate_prompt_buckets(self):
self.global_state.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.global_state.prompt_bs_bucket_cfg,
self.global_state.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

msg = (f"Generated {len(self.global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: \
{list(sorted(self.global_state.prompt_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

def generate_decode_buckets(self, max_blocks):
self.global_state.decode_buckets = generate_decode_buckets(
self.global_state.decode_bs_bucket_cfg,
self.global_state.decode_block_bucket_cfg, max_blocks)
logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
len(self.global_state.decode_buckets),
list(sorted(self.global_state.decode_buckets)))

def get_padded_prompt_batch_size(self, batch_size):
return find_bucket(batch_size, self.global_state.prompt_bs_bucket_cfg)

def get_padded_decode_batch_size(self, batch_size):
return find_bucket(batch_size, self.global_state.decode_bs_bucket_cfg)

def get_padded_prompt_seq_len(self, seq_len):
return find_bucket(seq_len, self.global_state.prompt_seq_bucket_cfg)

def get_padded_decode_num_blocks(self, num_blocks):
return find_bucket(num_blocks,
self.global_state.decode_block_bucket_cfg)

def get_padded_batch_size(self, batch_size, is_prompt):
if is_prompt:
return self.get_padded_prompt_batch_size(batch_size)
return self.get_padded_decode_batch_size(batch_size)

def get_padded_seq_or_block(self, seq_or_block, is_prompt):
if is_prompt:
return self.get_padded_prompt_seq_len(seq_or_block)
return self.get_padded_decode_num_blocks(seq_or_block)

@property
def prompt_buckets(self):
return self.global_state.prompt_buckets

@property
def decode_buckets(self):
return self.global_state.decode_buckets


def read_bucket_settings(phase: str, dim: str, **defaults):
@@ -233,6 +324,25 @@ def find_bucket(value: int, config: Tuple[int, int, int]):
return max(bmin, min(next_step, next_pow))


def subtuple(obj: object,
typename: str,
to_copy: List[str],
to_override: Optional[Dict[str, object]] = None):
if obj is None:
return None
if to_override is None:
to_override = {}
fields = set(to_copy) | set(to_override.keys())
if type(obj) is dict:
values = {key: obj[key] for key in fields if key in obj}
else:
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
if typename not in _TYPE_CACHE:
_TYPE_CACHE[typename] = collections.namedtuple(typename,
' '.join(fields))
return _TYPE_CACHE[typename](**values)


def align_workers(value, op):
group = get_world_group().cpu_group
world_size = torch.distributed.get_world_size()
@@ -655,8 +765,12 @@ def __init__(
self.profiler_counter_helper = HabanaProfilerCounterHelper()
self.seen_configs: set = set()
self._mem_margin: Optional[int] = None
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
self.max_num_prefill_seqs,
self.block_size,
self.max_num_batched_tokens)
self.graphed_buckets: Set[Any] = set()

self._set_gc_threshold()
self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
@@ -782,46 +896,6 @@ def _use_graphs(self, batch_size, seq_len, is_prompt)
def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens

def _setup_buckets(self) -> None:
align_bs = lambda x: min(self.max_num_seqs, x)
#FIXME: The default values should be max_model_len
max_prompt_seq = 1024
max_decode_seq = 2048
self.bucketing_global_state.prompt_bs_bucket_cfg = read_bucket_settings(
'prompt',
'bs',
min=1,
step=align_bs(32),
max=self.max_num_prefill_seqs)
self.bucketing_global_state.decode_bs_bucket_cfg = read_bucket_settings(
'decode', 'bs', min=1, step=align_bs(32), max=self.max_num_seqs)
self.bucketing_global_state.prompt_seq_bucket_cfg = \
read_bucket_settings(
'prompt',
'seq',
min=self.block_size,
step=self.block_size,
max=max_prompt_seq)
self.bucketing_global_state.decode_block_bucket_cfg = \
read_bucket_settings(
'decode',
'block',
min=self.block_size,
step=self.block_size,
max=max(self.block_size,
self.max_num_seqs * max_decode_seq // self.block_size))
self.graphed_buckets: Set[Any] = set()

msg = ("Prompt bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.prompt_bs_bucket_cfg}, "
f"seq:{self.bucketing_global_state.prompt_seq_bucket_cfg}")
logger.info(msg)

msg = ("Decode bucket config (min, step, max_warmup) "
f"bs:{self.bucketing_global_state.decode_bs_bucket_cfg}, "
f"block:{self.bucketing_global_state.decode_block_bucket_cfg}")
logger.info(msg)

def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -937,8 +1011,7 @@ def _prepare_prompt(
assert max_query_len > 0

max_prompt_len = max(
find_bucket(max_query_len,
self.bucketing_global_state.prompt_seq_bucket_cfg),
self.bucketing_ctx.get_padded_prompt_seq_len(max(seq_lens)),
self.block_size)

lora_ids: List[int] = []
@@ -1151,19 +1224,17 @@ def _prepare_decode(
padding_fn = None
if self.use_contiguous_pa:
block_bucket_size = max(max(block_list) + 1, len(block_list))
block_bucket_size = find_bucket(
block_bucket_size,
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
block_bucket_size)
indices: List[Any]
indices = [None] * block_bucket_size
for i, bid in enumerate(block_list):
indices[bid] = i
padding_fn = lambda tensor, pad_value: gather_list(
tensor, indices, pad_value)
else:
block_bucket_size = find_bucket(
len(block_list),
self.bucketing_global_state.decode_block_bucket_cfg)
block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
len(block_list))
padding_fn = lambda tensor, pad_value: pad_list(
tensor, block_bucket_size, pad_value)

@@ -1245,9 +1316,8 @@ def prepare_input_tensors(
self.profiler.start('internal', base_event_name)

real_batch_size = len(seq_group_metadata_list)
bucket_cfg = self.bucketing_global_state.prompt_bs_bucket_cfg \
if is_prompt else self.bucketing_global_state.decode_bs_bucket_cfg
batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
real_batch_size, is_prompt)
batch_size_padding = batch_size_padded - real_batch_size
seq_group_metadata_list = seq_group_metadata_list.copy()
if batch_size_padding > 0:
@@ -1453,9 +1523,11 @@ def create_dummy_seq_group_metadata(self,
def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
max_batch_size = self.bucketing_ctx.global_state.prompt_bs_bucket_cfg[
-1]
max_seq_len = min(
self.bucketing_ctx.global_state.prompt_seq_bucket_cfg[-1],
self.max_num_batched_tokens // max_batch_size)

self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
@@ -1594,7 +1666,7 @@ def log_warmup(self, phase, i, max_i, batch_size, seq_len):
free_mem = format_bytes(
HabanaMemoryProfiler.current_free_device_memory())
dim = "num_blocks"
if phase == "Prompt":
if "Prompt" in phase:
dim = "seq_len"
msg = (f"[Warmup][{phase}][{i+1}/{max_i}] "
f"batch_size:{batch_size} "
@@ -1689,37 +1761,12 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
return
self.profiler.start('internal', 'warmup')
max_blocks = kv_caches[0][0].size(0)

self.bucketing_global_state.prompt_buckets, prompt_omitted_buckets = \
generate_prompt_buckets(
self.bucketing_global_state.prompt_bs_bucket_cfg,
self.bucketing_global_state.prompt_seq_bucket_cfg,
self.max_num_batched_tokens)

msg = (f"Generated {len(self.bucketing_global_state.prompt_buckets)} "
f"prompt buckets [bs, seq]: \
{list(sorted(self.bucketing_global_state.prompt_buckets))}")
logger.info(msg)

msg = (f"Omitted {len(prompt_omitted_buckets)} "
"prompt buckets due to exceeded token budget "
f"(max_num_batched_tokens={self.max_num_batched_tokens})")
logger.info(msg)

msg = f"Omitted prompt buckets: {list(sorted(prompt_omitted_buckets))}"
logger.debug(msg)

self.bucketing_global_state.decode_buckets = generate_decode_buckets(
self.bucketing_global_state.decode_bs_bucket_cfg,
self.bucketing_global_state.decode_block_bucket_cfg, max_blocks)
logger.info("Generated %d decode buckets [bs, total_blocks]: %s",
len(self.bucketing_global_state.decode_buckets),
list(sorted(self.bucketing_global_state.decode_buckets)))
self.bucketing_ctx.generate_prompt_buckets()
self.bucketing_ctx.generate_decode_buckets(max_blocks)

if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
cache_size_limit = len(
self.bucketing_global_state.prompt_buckets) + len(
self.bucketing_global_state.decode_buckets) + 1
cache_size_limit = len(self.bucketing_ctx.prompt_buckets) + len(
self.bucketing_ctx.decode_buckets) + 1
torch._dynamo.config.cache_size_limit = max(
cache_size_limit, torch._dynamo.config.cache_size_limit)
# Multiply by 8 to follow the original default ratio between
@@ -1746,10 +1793,10 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
'Please update Gaudi Software Suite.')
with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True, kv_caches)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False, kv_caches)
self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True,
kv_caches)
self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False,
kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
@@ -1779,11 +1826,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets,
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets,
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches, decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets
@@ -1793,9 +1840,8 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
and not prompt_captured_all and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy,
self.bucketing_global_state.prompt_buckets, True,
kv_caches,
prompt_strategy, self.bucketing_ctx.prompt_buckets,
True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))

@@ -1806,18 +1852,15 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy,
self.bucketing_global_state.decode_buckets, False,
kv_caches,
decode_strategy, self.bucketing_ctx.decode_buckets,
False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)

self.log_graph_warmup_summary(
self.bucketing_global_state.prompt_buckets, True,
mem_post_prompt)
self.bucketing_ctx.prompt_buckets, True, mem_post_prompt)
self.log_graph_warmup_summary(
self.bucketing_global_state.decode_buckets, False,
mem_post_decode)
self.bucketing_ctx.decode_buckets, False, mem_post_decode)

end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
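As a quick sanity check on the decode block bucket upper bound computed in _setup_buckets, max(block_size, max_num_seqs * max_decode_seq // block_size), here is a worked example with purely illustrative values:

    block_size = 128       # illustrative value, not a project default
    max_num_seqs = 256     # illustrative value
    max_decode_seq = 2048  # hard-coded default in _setup_buckets above
    decode_block_max = max(block_size, max_num_seqs * max_decode_seq // block_size)
    # 256 * 2048 // 128 = 4096, so decode buckets cover up to 4096 blocks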
