From 3460308fba44b7449a5f04798215095c38dc5034 Mon Sep 17 00:00:00 2001 From: Zhonghua Deng Date: Thu, 19 Dec 2024 21:17:25 +0800 Subject: [PATCH] improve expand (#3) Signed-off-by: Abatom --- vllm/lora/ops/sgmv_expand.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 6a5e1d697c236..c1f100c541e38 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -53,12 +53,20 @@ def _sgmv_expand_kernel( """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) - cta_n_num = tl.cdiv(N, BLOCK_N) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + GROUP_SIZE_M: tl.constexpr = 1 + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + if pid_m * BLOCK_M > M: return lora_index = tl.load(lora_indices + cur_batch)