diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 71d4ca5d888..19bd07b88da 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -694,6 +694,7 @@ def prepare_for_extend(self):
         #     (req.req_pool_idx, slice(pre_len, seq_len)),
         #     out_cache_loc[pt : pt + req.extend_input_len],
         # )
+        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -1146,6 +1147,7 @@ def write_req_to_token_pool_triton(
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)

+    # TODO: optimize this?
     cumsum_start = 0
     for i in range(pid):
         cumsum_start += tl.load(extend_lens + i)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index d839125448d..ea7c8d89a58 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -275,7 +275,7 @@ def compute_position_triton(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
     extend_start_loc = torch.empty(
-        batch_size, dtype=torch.int64, device=extend_seq_lens.device
+        batch_size, dtype=torch.int32, device=extend_seq_lens.device
     )

     # Launch kernel
@@ -302,6 +302,7 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid)
     seq_len = tl.load(extend_seq_lens + pid)

+    # TODO: optimize this?
     cumsum_start = 0
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)
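
Note on the two "TODO: optimize this?" comments: both Triton kernels derive cumsum_start with a serial loop over every program with a lower pid, which is O(pid) loads per program and O(batch_size^2) loads overall. A minimal sketch of one possible fix (not part of this diff) follows: precompute the exclusive prefix sum of extend_seq_lens once on the host with torch.cumsum and pass it into the kernel, so each program does a single O(1) load. The store pattern (position = prefix_len + local offset) is assumed to match the surrounding compute_position_kernel; the names compute_position_kernel_v2 and compute_position_sketch are hypothetical.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def compute_position_kernel_v2(
        positions,
        extend_start_loc,    # precomputed exclusive prefix sum of extend_seq_lens
        extend_prefix_lens,
        extend_seq_lens,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(0)
        prefix_len = tl.load(extend_prefix_lens + pid)
        seq_len = tl.load(extend_seq_lens + pid)
        # One O(1) load replaces the O(pid) serial loop flagged by the TODOs.
        cumsum_start = tl.load(extend_start_loc + pid)

        for i in range(tl.cdiv(seq_len, BLOCK_SIZE)):
            offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            tl.store(
                positions + cumsum_start + offset,
                (prefix_len + offset).to(tl.int64),
                mask=offset < seq_len,
            )


    def compute_position_sketch(extend_prefix_lens, extend_seq_lens, extend_seq_lens_sum):
        device = extend_seq_lens.device
        positions = torch.empty(extend_seq_lens_sum, dtype=torch.int64, device=device)
        # Exclusive prefix sum: start offset of each request's extend tokens.
        extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
        batch_size = extend_seq_lens.shape[0]
        compute_position_kernel_v2[(batch_size,)](
            positions,
            extend_start_loc,
            extend_prefix_lens,
            extend_seq_lens,
            BLOCK_SIZE=128,
        )
        return positions, extend_start_loc

The same exclusive-cumsum trick would apply to the extend_lens loop in write_req_to_token_pool_triton; whether it is worth the extra host-side op depends on typical batch sizes, since the serial loop is cheap for small batches.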