diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 991ba0da36c35..5fad3c23684ce 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -33,6 +33,7 @@ def _sgmv_expand_slice_kernel( ls_d2_ptr, cm_stride, cn_stride, + group_size, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, @@ -49,77 +50,74 @@ def _sgmv_expand_slice_kernel( times. """ pid = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - slice_id = tl.program_id(axis=2) + slice_id = tl.program_id(axis=1) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num - M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M > M: - return - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - cur_seq_start = tl.load(b_seq_start_loc + cur_batch) - offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - offset_k = tl.arange(0, BLOCK_K) - ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + for cur_batch in range(group_size): + M = tl.load(seq_lens + cur_batch) + lora_index = tl.load(lora_indices + cur_batch) + if pid_m * BLOCK_M <= M and lora_index != -1: + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - # input - cur_input_ptr = input_ptr + slice_id * input_d0_stride - a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + - ram[:, None] * input_d1_stride + - offset_k[None, :] * input_d2_stride, ) - # lora - cur_lora_ptr = tl.load(lora_ptr + slice_id).to( - tl.pointer_type(out_ptr.dtype.element_ty)) - cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) - cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) - cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + # input + cur_input_ptr = input_ptr + slice_id * input_d0_stride + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + # lora + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) - b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + - offset_k[:, None] * cur_lora_d2_stride + - rbn[None, :] * cur_lora_d1_stride) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(tl.cdiv(K, BLOCK_K)): - if EVEN_K: - tiled_a = tl.load(a_ptr) - tiled_b = tl.load(b_ptr) - else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) - if CAST_TYPE: - tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) - accumulator += tl.dot( - tiled_a, - tiled_b, - ) - a_ptr += BLOCK_K * input_d2_stride - b_ptr += BLOCK_K * cur_lora_d2_stride + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * input_d2_stride + b_ptr += BLOCK_K * cur_lora_d2_stride - tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) - cur_slice_start = tl.load(slice_start_loc + slice_id) + cur_slice_start = tl.load(slice_start_loc + slice_id) - offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) - M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < - (cur_slice_start + N)) - if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) - tiled_c += tiled_out - tl.store(c_ptr, tiled_c, mask=c_mask) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < + (cur_slice_start + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) _LORA_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} @@ -218,7 +216,6 @@ def _sgmv_expand_slice( CAST_TYPE = True grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - batches, len(lora_ptr_tensor), ) _sgmv_expand_slice_kernel[grid]( @@ -239,6 +236,7 @@ def _sgmv_expand_slice( lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), + batches, BLOCK_M, BLOCK_N, BLOCK_K,