diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index cbfeb0ec6de02..2aebdfa964619 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -55,19 +55,10 @@ def _sgmv_expand_kernel( cur_batch = tl.program_id(axis=1) slice_id = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num M = tl.load(seq_lens + cur_batch) - - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - GROUP_M: tl.constexpr = 1 - width = GROUP_M * grid_n - group_id = pid // width - first_pid_m = group_id * GROUP_M - group_idx = pid % width - group_size_m = min(grid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (group_idx % group_size_m) - pid_n = group_idx // group_size_m - if pid_m * BLOCK_M > M: return lora_index = tl.load(lora_indices + cur_batch)