Fix blocks number calculation for Flat PA (HabanaAI#269)
Fix blocks number calculation for Flat PA by adding an empty block_table (HabanaAI#158)
iboiko-habana authored and zhouyu5 committed Sep 20, 2024
1 parent 574a796 commit 5999224
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions vllm/worker/habana_model_runner.py
@@ -173,11 +173,16 @@ def generate_prompt_buckets(bs_bucket_config,
 def generate_decode_buckets(bs_bucket_config, blocks_bucket_config,
                             max_blocks):
     buckets = []
-    for bs in warmup_range(bs_bucket_config):
-        for blocks in warmup_range(blocks_bucket_config):
+    bs_buckets = warmup_range(bs_bucket_config)
+    block_buckets = warmup_range(blocks_bucket_config)
+    bmin, bstep, bmax = blocks_bucket_config
+    last_bucket = max_blocks if (max_blocks // bstep
+                                 == 0) else (max_blocks // bstep + 1) * bstep
+    for bs in bs_buckets:
+        for blocks in block_buckets:
             if blocks < bs:
                 continue
-            if blocks > max_blocks:
+            if blocks > last_bucket:
                 break
             buckets.append((bs, blocks))
     return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0])))
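The key change in this hunk is the last_bucket cutoff: instead of discarding every bucket larger than max_blocks, the loop keeps buckets up to max_blocks rounded up to the next bstep boundary, so a block count that falls between bucket steps still has a bucket large enough to hold it. A minimal sketch of the rounding idea, with made-up bucket configs and a simplified stand-in for warmup_range (the real helper is more elaborate):

    # Simplified stand-in for warmup_range: multiples of bstep from bmin to bmax.
    def bucket_range(cfg):
        bmin, bstep, bmax = cfg
        return list(range(bmin, bmax + 1, bstep))


    def decode_block_buckets(bs_cfg, blocks_cfg, max_blocks):
        buckets = []
        bmin, bstep, bmax = blocks_cfg
        # Round max_blocks up to the next bucket boundary so the largest usable
        # block count still falls inside some bucket.
        last_bucket = (max_blocks if max_blocks % bstep == 0 else
                       (max_blocks // bstep + 1) * bstep)
        for bs in bucket_range(bs_cfg):
            for blocks in bucket_range(blocks_cfg):
                if blocks < bs or blocks > last_bucket:
                    continue
                buckets.append((bs, blocks))
        return sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))


    # With max_blocks = 300 and a block step of 128, the old cutoff (blocks > 300)
    # would drop the 384 bucket; the rounded cutoff keeps it, so 300 blocks fit.
    print(decode_block_buckets((1, 1, 2), (128, 128, 512), max_blocks=300))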
@@ -968,10 +973,12 @@ def _prepare_decode(
             seq_lens.append(seq_len)
 
             block_table = seq_group_metadata.block_tables[seq_id]
-            block_number = block_table[position // self.block_size]
-            if block_number == _PAD_BLOCK_ID:
+            if len(block_table) == 0:
+                block_number = _PAD_BLOCK_ID
+                block_table = []
                 slot = next(dummy_slots)
             else:
+                block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
             slot_mapping.append([slot])
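With batch padding now producing dummy sequence groups (see the last hunk), a padded sequence arrives here with an empty block table, so the code checks for that case directly instead of indexing into the table and comparing against _PAD_BLOCK_ID. A self-contained sketch of the slot computation under that assumption; the block size and dummy-slot values are illustrative, not the runner's real ones:

    import itertools

    _PAD_BLOCK_ID = 0   # illustrative; the real constant lives in habana_model_runner.py
    BLOCK_SIZE = 16
    # Illustrative dummy-slot source; the runner draws these from its own iterator.
    dummy_slots = itertools.cycle([_PAD_BLOCK_ID * BLOCK_SIZE])


    def decode_slot(block_table, position):
        if len(block_table) == 0:
            # Padding sequence: no KV-cache blocks were allocated, so point the
            # slot mapping at a dummy slot instead of indexing an empty table.
            return next(dummy_slots)
        block_number = block_table[position // BLOCK_SIZE]
        block_offset = position % BLOCK_SIZE
        return block_number * BLOCK_SIZE + block_offset


    print(decode_slot([], position=5))       # dummy slot for a padded sequence
    print(decode_slot([7, 9], position=20))  # 9 * 16 + 4 = 148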
@@ -996,7 +1003,7 @@ def _prepare_decode(
 
         num_decode_tokens = sum(seq_lens)
 
-        blocks_used = [len(bt) for bt in block_tables]
+        blocks_used = [len(bt) for bt in block_tables if bt]
         block_list = list(itertools.chain(*block_tables))
         block_mapping_nested: List[List[int]] = [
             [i] * b_u for i, b_u in enumerate(blocks_used)
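Filtering out empty block tables keeps blocks_used consistent with the flattened block_list (itertools.chain contributes nothing for an empty table), so the per-block batch mapping is built only from sequences that actually own blocks. A small sketch with hypothetical block tables, where the padding entry sits at the end of the batch as produced by the padding change in the next hunk:

    import itertools
    from typing import List

    # Hypothetical per-sequence block tables; the last entry is a padding
    # sequence with no blocks (padding is appended at the end of the batch).
    block_tables = [[3, 4], [8], []]

    # Count blocks only for non-empty tables so the mapping lines up with the
    # flattened block list below.
    blocks_used = [len(bt) for bt in block_tables if bt]
    block_list = list(itertools.chain(*block_tables))
    block_mapping_nested: List[List[int]] = [
        [i] * b_u for i, b_u in enumerate(blocks_used)
    ]
    block_mapping = list(itertools.chain(*block_mapping_nested))

    print(block_list)     # [3, 4, 8]
    print(block_mapping)  # [0, 0, 1] -> each block tagged with its batch index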
@@ -1084,8 +1091,9 @@ def prepare_input_tensors(
         batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
         batch_size_padding = batch_size_padded - real_batch_size
         seq_group_metadata_list = seq_group_metadata_list.copy()
-        seq_group_metadata_list.extend(seq_group_metadata_list[0]
-                                       for _ in range(batch_size_padding))
+        seq_group_metadata_list.extend(
+            self.create_dummy_seq_group_metadata(0, 0, is_prompt)
+            for _ in range(batch_size_padding))
 
         prefill_reqs = []
         decode_reqs = []
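Previously the batch was padded by repeating the first real request, which dragged that request's block tables into the padded entries and inflated the block count used for bucket selection; padding with dummy sequence-group metadata gives the padded entries empty block tables instead. A hedged sketch of the padding step, with simplified stand-ins for find_bucket and create_dummy_seq_group_metadata (illustrative only, not the runner's real helpers):

    def find_bucket(value, cfg):
        # Simplified stand-in: round the value up to the nearest bucket step.
        bmin, bstep, bmax = cfg
        return max(bmin, ((value + bstep - 1) // bstep) * bstep)


    def make_dummy_request():
        # Stand-in for create_dummy_seq_group_metadata: a padding request that
        # owns no KV-cache blocks.
        return {"block_table": [], "is_dummy": True}


    def pad_batch(requests, bucket_cfg):
        real_batch_size = len(requests)
        batch_size_padded = find_bucket(real_batch_size, bucket_cfg)
        batch_size_padding = batch_size_padded - real_batch_size
        requests = requests.copy()
        # Pad with dummy requests rather than copies of requests[0], so padding
        # contributes no real block tables to the batch.
        requests.extend(make_dummy_request() for _ in range(batch_size_padding))
        return requests


    padded = pad_batch([{"block_table": [3, 4]}], bucket_cfg=(1, 4, 64))
    print(len(padded), sum(len(r["block_table"]) for r in padded))  # 4 2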
Expand Down
