diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 28962eb9ff8..4f76e57aec6 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -564,4 +564,4 @@ def main(server_args, bench_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process()
+        kill_child_process(include_self=True)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index f4c76b95853..71d4ca5d888 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1045,6 +1045,9 @@ def get_model_worker_batch(self):
         )
 
     def copy(self):
+        # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
+        _ = self.seq_lens[0].item()
+
         # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 21264f1a975..8c924c442eb 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -56,6 +56,7 @@ def __init__(
         self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
+        self.gpu_id = gpu_id
 
         # Init future mappings
         self.future_token_ids_ct = 0
@@ -73,12 +74,6 @@ def __init__(
         )
         self.forward_thread.start()
 
-        self.copy_queue = Queue()
-        self.copy_thread = threading.Thread(
-            target=self.copy_thread_func,
-        )
-        self.copy_thread.start()
-
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -104,12 +99,11 @@ def forward_thread_func(self):
     @torch.inference_mode()
     def forward_thread_func_(self):
         while True:
-            self.has_inflight_batch = False
            model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
-            self.has_inflight_batch = True
             self.launch_event = threading.Event()
+            copy_event = torch.cuda.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -142,39 +136,29 @@ def forward_thread_func_(self):
                 )
             )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
 
             self.launch_event.set()
-            self.copy_queue.put((copy_event, logits_output, next_token_ids))
-
-    def copy_thread_func(self):
-        while True:
-            copy_event, logits_output, next_token_ids = self.copy_queue.get()
-            if not copy_event:
-                break
-            while not copy_event.query():
-                time.sleep(1e-5)
-
-            if logits_output.next_token_logprobs is not None:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.tolist()
-                )
-                if logits_output.input_token_logprobs is not None:
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.tolist()
-                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
-                    )
-
-            self.output_queue.put((logits_output, next_token_ids.tolist()))
+            self.output_queue.put((copy_event, logits_output, next_token_ids))
 
     def resulve_batch_result(self, bid: int):
-        logits_output, next_token_ids = self.output_queue.get()
-        if self.has_inflight_batch:
-            # Wait until the batch is launched
-            self.launch_event.wait()
+        copy_event, logits_output, next_token_ids = self.output_queue.get()
+        while not copy_event.query():
+            time.sleep(1e-5)
+        self.launch_event.wait()
+
+        if logits_output.next_token_logprobs is not None:
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.tolist()
+            )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = (
+                    logits_output.input_token_logprobs.tolist()
+                )
+                logits_output.normalized_prompt_logprobs = (
+                    logits_output.normalized_prompt_logprobs.tolist()
+                )
+        next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
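For context on the synchronization this change relies on: the forward thread now records a `torch.cuda.Event` right after the `non_blocking=True` device-to-host copy and passes it through `output_queue`; the consumer side polls `copy_event.query()` before touching the host tensors, which is what lets the dedicated copy thread be removed. Below is a minimal, standalone sketch of that event-polling handoff. It is not code from this PR; it assumes PyTorch with a CUDA device available, and the `producer`/`consumer` names are illustrative only.

```python
# Minimal sketch (illustrative, not from this PR): hand a CUDA event plus an
# in-flight host tensor across threads, and poll the event before reading it.
import queue
import threading
import time

import torch


def producer(out_q: queue.Queue, steps: int = 4) -> None:
    for _ in range(steps):
        # Stand-in for the sampled next-token tensor produced on the GPU.
        next_token_ids = torch.randint(0, 32000, (8,), device="cuda")
        # Start an asynchronous device-to-host copy; the CPU does not wait here.
        host_ids = next_token_ids.to("cpu", non_blocking=True)
        # Record an event on the current stream *after* the copy is enqueued.
        copy_event = torch.cuda.Event()
        copy_event.record()
        # Hand over the event and the (possibly still in-flight) host tensor.
        out_q.put((copy_event, host_ids))
    out_q.put(None)  # sentinel to stop the consumer


def consumer(out_q: queue.Queue) -> None:
    while True:
        item = out_q.get()
        if item is None:
            break
        copy_event, host_ids = item
        # Poll with a tiny sleep until the copy has really finished; only then
        # is it safe to read the host tensor.
        while not copy_event.query():
            time.sleep(1e-5)
        print(host_ids.tolist())


if __name__ == "__main__":
    q = queue.Queue()
    t = threading.Thread(target=consumer, args=(q,))
    t.start()
    producer(q)
    t.join()
```

Relatedly, the `_ = self.seq_lens[0].item()` line added to `ScheduleBatch.copy()` works because calling `Tensor.item()` on a CUDA tensor blocks until the pending GPU work that produces that value has completed, which provides the stream synchronization the comment asks for.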