Optimize ForwardBatch with a triton kernel
merrymercy committed Nov 16, 2024
1 parent 522e631 commit 755bafd
Showing 1 changed file with 72 additions and 13 deletions.
python/sglang/srt/model_executor/forward_batch_info.py: 72 additions & 13 deletions
@@ -36,6 +36,8 @@
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
@@ -236,25 +238,16 @@ def init_new(
 
         # Init position information
         if not ret.forward_mode.is_decode():
-            ret.positions = torch.concat(
-                [
-                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
-                    for prefix_len, extend_len in zip(
-                        batch.extend_prefix_lens, batch.extend_seq_lens
-                    )
-                ],
-                axis=0,
-            )
-            ret.extend_num_tokens = batch.extend_num_tokens
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
-            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_num_tokens = batch.extend_num_tokens
+            ret.positions, ret.extend_start_loc = compute_position_triton(
+                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+            )
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
@@ -271,3 +264,69 @@ def init_new(
         model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
+
+
+def compute_position_triton(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
+):
+    """Compute positions. It is a fused version of `compute_position_torch`."""
+    batch_size = extend_seq_lens.shape[0]
+    positions = torch.empty(
+        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
+    )
+    extend_start_loc = torch.empty(
+        batch_size, dtype=torch.int64, device=extend_seq_lens.device
+    )
+
+    # Launch kernel
+    compute_position_kernel[(batch_size,)](
+        positions,
+        extend_start_loc,
+        extend_prefix_lens,
+        extend_seq_lens,
+    )
+
+    return positions, extend_start_loc
+
+
+@triton.jit
+def compute_position_kernel(
+    positions,
+    extend_start_loc,
+    extend_prefix_lens,
+    extend_seq_lens,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(0)
+
+    prefix_len = tl.load(extend_prefix_lens + pid)
+    seq_len = tl.load(extend_seq_lens + pid)
+
+    dst_start = 0
+    for i in range(pid):
+        dst_start += tl.load(extend_seq_lens + i)
+
+    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        tl.store(
+            positions + dst_start + offset, prefix_len + offset, mask=offset < seq_len
+        )
+    tl.store(extend_start_loc + pid, dst_start)
+
+
+def compute_position_torch(
+    extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
+):
+    positions = torch.concat(
+        [
+            torch.arange(
+                prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
+            )
+            for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
+        ],
+        axis=0,
+    )
+    extend_start_loc = torch.zeros_like(extend_seq_lens)
+    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+    return positions.to(torch.int64), extend_start_loc

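For reference, below is a minimal sanity-check sketch, not part of the commit, that exercises the two code paths added above on a tiny batch. It assumes a CUDA-capable GPU, a working triton install, and that compute_position_triton and compute_position_torch are importable from sglang.srt.model_executor.forward_batch_info as of this commit; the example tensors and expected values are illustrative only.

import torch

from sglang.srt.model_executor.forward_batch_info import (
    compute_position_torch,
    compute_position_triton,
)

# Two requests: a 2-token cached prefix extended by 3 tokens, and a
# 5-token prefix extended by 2 tokens.
extend_prefix_lens = torch.tensor([2, 5], dtype=torch.int32, device="cuda")
extend_seq_lens = torch.tensor([3, 2], dtype=torch.int32, device="cuda")
extend_seq_lens_sum = int(extend_seq_lens.sum().item())  # 5 tokens in total

positions, start_loc = compute_position_triton(
    extend_prefix_lens, extend_seq_lens, extend_seq_lens_sum
)
ref_positions, ref_start_loc = compute_position_torch(
    extend_prefix_lens, extend_seq_lens
)

# Expected: positions == [2, 3, 4, 5, 6] and start_loc == [0, 3].
assert torch.equal(positions, ref_positions)
assert torch.equal(start_loc, ref_start_loc.to(torch.int64))
print(positions.tolist(), start_loc.tolist())

The point of the fused version is to compute all per-request position ranges and their start offsets in a single kernel launch instead of one torch.arange per request plus a separate concat and cumsum; the torch version remains as an eager reference.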