From 59992243cb23a8286c71b62b05d0e5a770321fa8 Mon Sep 17 00:00:00 2001 From: Iryna Boiko Date: Wed, 18 Sep 2024 12:09:13 +0200 Subject: [PATCH] Fix blocks number calculation for Flat PA (#269) Fix blocks number calculation for Flat PA via adding empty table_block (https://github.com/HabanaAI/vllm-fork/issues/158) The last decode block bucket is max_blocks rounded up to a multiple of the block bucket step (divisibility is checked with max_blocks % bstep). Sequences with an empty block table are assigned the pad block id and a dummy slot, and batch padding uses dummy sequence group metadata instead of copies of the first sequence group. --- vllm/worker/habana_model_runner.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index afc656d7c21fb..2870c10078c2d 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -173,11 +173,16 @@ def generate_prompt_buckets(bs_bucket_config, def generate_decode_buckets(bs_bucket_config, blocks_bucket_config, max_blocks): buckets = [] - for bs in warmup_range(bs_bucket_config): - for blocks in warmup_range(blocks_bucket_config): + bs_buckets = warmup_range(bs_bucket_config) + block_buckets = warmup_range(blocks_bucket_config) + bmin, bstep, bmax = blocks_bucket_config + last_bucket = max_blocks if (max_blocks % bstep + == 0) else (max_blocks // bstep + 1) * bstep + for bs in bs_buckets: + for blocks in block_buckets: if blocks < bs: continue - if blocks > max_blocks: + if blocks > last_bucket: break buckets.append((bs, blocks)) return list(sorted(buckets, key=lambda b: (b[0] * b[1], b[1], b[0]))) @@ -968,10 +973,12 @@ def _prepare_decode( seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - if block_number == _PAD_BLOCK_ID: + if len(block_table) == 0: + block_number = _PAD_BLOCK_ID + block_table = [] slot = next(dummy_slots) else: + block_number = block_table[position // self.block_size] block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) @@ -996,7 +1003,7 @@ def _prepare_decode( num_decode_tokens = sum(seq_lens) - blocks_used = [len(bt) for bt in 
block_tables] + blocks_used = [len(bt) for bt in block_tables if bt] block_list = list(itertools.chain(*block_tables)) block_mapping_nested: List[List[int]] = [ [i] * b_u for i, b_u in enumerate(blocks_used) @@ -1084,8 +1091,9 @@ def prepare_input_tensors( batch_size_padded = find_bucket(real_batch_size, bucket_cfg) batch_size_padding = batch_size_padded - real_batch_size seq_group_metadata_list = seq_group_metadata_list.copy() - seq_group_metadata_list.extend(seq_group_metadata_list[0] - for _ in range(batch_size_padding)) + seq_group_metadata_list.extend( + self.create_dummy_seq_group_metadata(0, 0, is_prompt) + for _ in range(batch_size_padding)) prefill_reqs = [] decode_reqs = []