Skip to content

Commit

Permalink
improve expand (#3)
Browse files Browse the repository at this point in the history
Signed-off-by: Abatom <[email protected]>
  • Loading branch information
Abatom authored Dec 19, 2024
1 parent 5c88ec4 commit 3460308
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions vllm/lora/ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,20 @@ def _sgmv_expand_kernel(
"""
pid = tl.program_id(axis=0)
cur_batch = tl.program_id(axis=1)

slice_id = tl.program_id(axis=2)
cta_n_num = tl.cdiv(N, BLOCK_N)
pid_m = pid // cta_n_num
pid_n = pid % cta_n_num

M = tl.load(seq_lens + cur_batch)

num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
GROUP_SIZE_M: tl.constexpr = 1
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

if pid_m * BLOCK_M > M:
return
lora_index = tl.load(lora_indices + cur_batch)
Expand Down

0 comments on commit 3460308

Please sign in to comment.