Commit

Fix illegal memory access
merrymercy committed Nov 16, 2024
1 parent 724bc11 commit bee18d0
Showing 3 changed files with 24 additions and 37 deletions.
python/sglang/bench_latency.py (2 changes: 1 addition & 1 deletion)

@@ -564,4 +564,4 @@ def main(server_args, bench_args):
    except Exception as e:
        raise e
    finally:
-        kill_child_process()
+        kill_child_process(include_self=True)
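Passing include_self=True means the benchmark's cleanup path now terminates the current process as well as its children, so an interrupted or failed run does not leave a process behind holding the GPU. The following is a minimal, hypothetical sketch of what such a helper can look like (using psutil); it is illustrative only and not sglang's actual implementation:

import os

import psutil

def kill_child_process(pid=None, include_self=False):
    # Illustrative sketch, not sglang's implementation: terminate every child
    # of `pid` (default: the current process), then optionally the process itself.
    pid = pid or os.getpid()
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass
    if include_self:
        parent.kill()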
python/sglang/srt/managers/schedule_batch.py (3 changes: 3 additions & 0 deletions)

@@ -1045,6 +1045,9 @@ def get_model_worker_batch(self):
        )

    def copy(self):
+        # We need a stream synchronization here. Otherwise, there will be cuda illegal memory access errors.
+        _ = self.seq_lens[0].item()
+
        # Only contain fields that will be used by process_batch_result
        return ScheduleBatch(
            reqs=self.reqs,
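Why a single .item() call is enough: indexing a CUDA tensor and calling .item() triggers a device-to-host copy of that one element, and the host thread blocks until the GPU work producing the value has completed. The commit uses this as a lightweight synchronization point before the batch is copied for the overlapped worker. A small sketch of the effect, assuming a CUDA device is available (the tensor name mirrors the diff; the surrounding code is illustrative):

import torch

# Enqueue GPU work; the kernel launch returns immediately and the result is
# only guaranteed to be ready in stream order.
seq_lens = torch.randint(1, 4096, (8,), device="cuda")
seq_lens = seq_lens + 1

# .item() copies one element to the host and waits for the work feeding
# seq_lens to finish, so everything queued before this point has completed
# by the time the next line runs.
first_len = seq_lens[0].item()
print(first_len)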
python/sglang/srt/managers/tp_worker_overlap_thread.py (56 changes: 20 additions & 36 deletions)

@@ -56,6 +56,7 @@ def __init__(
        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
        self.max_running_requests = self.worker.max_running_requests
        self.device = self.worker.device
+        self.gpu_id = gpu_id

        # Init future mappings
        self.future_token_ids_ct = 0
@@ -73,12 +74,6 @@ def __init__(
        )
        self.forward_thread.start()

-        self.copy_queue = Queue()
-        self.copy_thread = threading.Thread(
-            target=self.copy_thread_func,
-        )
-        self.copy_thread.start()
-
    def get_worker_info(self):
        return self.worker.get_worker_info()
@@ -104,12 +99,11 @@ def forward_thread_func(self):
    @torch.inference_mode()
    def forward_thread_func_(self):
        while True:
-            self.has_inflight_batch = False
            model_worker_batch, future_token_ids_ct = self.input_queue.get()
            if not model_worker_batch:
                break
-            self.has_inflight_batch = True
            self.launch_event = threading.Event()
+            copy_event = torch.cuda.Event()

            # Resolve future tokens in the input
            input_ids = model_worker_batch.input_ids
@@ -142,39 +136,29 @@ def forward_thread_func_(self):
                )
            )
            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-            copy_event = torch.cuda.Event(blocking=True)
            copy_event.record()

            self.launch_event.set()
-            self.copy_queue.put((copy_event, logits_output, next_token_ids))
-
-    def copy_thread_func(self):
-        while True:
-            copy_event, logits_output, next_token_ids = self.copy_queue.get()
-            if not copy_event:
-                break
-            while not copy_event.query():
-                time.sleep(1e-5)
-
-            if logits_output.next_token_logprobs is not None:
-                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs.tolist()
-                )
-            if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = (
-                    logits_output.input_token_logprobs.tolist()
-                )
-                logits_output.normalized_prompt_logprobs = (
-                    logits_output.normalized_prompt_logprobs.tolist()
-                )
-
-            self.output_queue.put((logits_output, next_token_ids.tolist()))
+            self.output_queue.put((copy_event, logits_output, next_token_ids))

    def resulve_batch_result(self, bid: int):
-        logits_output, next_token_ids = self.output_queue.get()
-        if self.has_inflight_batch:
-            # Wait until the batch is launched
-            self.launch_event.wait()
+        copy_event, logits_output, next_token_ids = self.output_queue.get()
+        while not copy_event.query():
+            time.sleep(1e-5)
+        self.launch_event.wait()
+
+        if logits_output.next_token_logprobs is not None:
+            logits_output.next_token_logprobs = (
+                logits_output.next_token_logprobs.tolist()
+            )
+        if logits_output.input_token_logprobs is not None:
+            logits_output.input_token_logprobs = (
+                logits_output.input_token_logprobs.tolist()
+            )
+            logits_output.normalized_prompt_logprobs = (
+                logits_output.normalized_prompt_logprobs.tolist()
+            )
+        next_token_ids = next_token_ids.tolist()
        return logits_output, next_token_ids

    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
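The net effect of this file's changes is that the separate copy thread disappears: the forward thread records a CUDA event right after the non-blocking copy of next_token_ids to the CPU and puts the event on the output queue, and resulve_batch_result polls that event before converting the tensors to Python lists. The sketch below shows the same event-based handoff between a producer and a consumer thread in isolation; it assumes a CUDA device and uses illustrative names rather than sglang's API:

import queue
import threading
import time

import torch

output_queue = queue.Queue()

def producer():
    # GPU work plus a non-blocking device-to-host copy (truly asynchronous
    # only if the destination ends up in pinned memory).
    ids = torch.randint(0, 32000, (8,), device="cuda")
    ids_cpu = ids.to("cpu", non_blocking=True)

    # Record an event after the copy so the consumer can tell when the CPU
    # tensor is safe to read.
    copy_event = torch.cuda.Event()
    copy_event.record()
    output_queue.put((copy_event, ids_cpu))

def consumer():
    copy_event, ids_cpu = output_queue.get()
    # Poll the event with a tiny sleep, mirroring the copy_event.query()
    # loop in the diff above.
    while not copy_event.query():
        time.sleep(1e-5)
    print(ids_cpu.tolist())  # safe: the copy has finished

t = threading.Thread(target=producer)
t.start()
consumer()
t.join()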
