Skip to content

Commit

Permalink
update
Browse files — browse the repository at this point in the history
  • Loading branch information
merrymercy committed Nov 16, 2024
1 parent: bee18d0 · commit: 6eccceb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ def prepare_for_extend(self):
# (req.req_pool_idx, slice(pre_len, seq_len)),
# out_cache_loc[pt : pt + req.extend_input_len],
# )
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

if self.model_config.is_encoder_decoder:
self.prepare_encoder_info_extend(input_ids, seq_lens)
Expand Down Expand Up @@ -1146,6 +1147,7 @@ def write_req_to_token_pool_triton(
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)

# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def compute_position_triton(
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
)
extend_start_loc = torch.empty(
batch_size, dtype=torch.int64, device=extend_seq_lens.device
batch_size, dtype=torch.int32, device=extend_seq_lens.device
)

# Launch kernel
Expand All @@ -302,6 +302,7 @@ def compute_position_kernel(
prefix_len = tl.load(extend_prefix_lens + pid)
seq_len = tl.load(extend_seq_lens + pid)

# TODO: optimize this?
cumsum_start = 0
for i in range(pid):
cumsum_start += tl.load(extend_seq_lens + i)
Expand Down

0 comments on commit 6eccceb

Please sign in to comment.