Skip to content

Commit

Permalink
[perf]fix current stream (vllm-project#11870)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Jan 9, 2025
1 parent a732900 commit 310aca8
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 15 deletions.
15 changes: 8 additions & 7 deletions vllm/distributed/device_communicators/pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import current_stream

logger = init_logger(__name__)

Expand Down Expand Up @@ -96,7 +97,7 @@ def __init__(
self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
self.world_size, self.unique_id, self.rank)

stream = torch.cuda.current_stream()
stream = current_stream()
# A small all_reduce for warmup.
data = torch.zeros(1, device=device)
self.all_reduce(data)
Expand All @@ -119,7 +120,7 @@ def all_reduce(self,
out_tensor = torch.empty_like(in_tensor)

if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
buffer_type(out_tensor.data_ptr()),
in_tensor.numel(),
Expand All @@ -141,7 +142,7 @@ def all_gather(self,
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
Expand All @@ -162,7 +163,7 @@ def reduce_scatter(self,
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
Expand All @@ -177,7 +178,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))
Expand All @@ -189,7 +190,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
Expand All @@ -201,7 +202,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = torch.cuda.current_stream()
stream = current_stream()
if src == self.rank:
sendbuff = buffer_type(tensor.data_ptr())
# NCCL requires the sender also to have a receive buffer
Expand Down
5 changes: 1 addition & 4 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
# TODO: pynccl should not use `stream=`
# it can just always use the current stream.
out = pynccl_comm.all_reduce(input_,
stream=torch.cuda.current_stream())
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
Expand Down
33 changes: 33 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,39 @@ def find_nccl_library() -> str:
return so_file


prev_set_stream = torch.cuda.set_stream

_current_stream = None


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
global _current_stream
_current_stream = stream
prev_set_stream(stream)


torch.cuda.set_stream = _patched_set_stream


def current_stream() -> torch.cuda.Stream:
"""
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
it turns out that `torch.cuda.current_stream()` is quite expensive,
as it will construct a new stream object at each call.
here we patch `torch.cuda.set_stream` to keep track of the current stream
directly, so that we can avoid calling `torch.cuda.current_stream()`.
the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
from C/C++ code.
"""
global _current_stream
if _current_stream is None:
# when this function is called before any stream is set,
# we return the default stream.
_current_stream = torch.cuda.current_stream()
return _current_stream


def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
Expand Down
8 changes: 4 additions & 4 deletions vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_pythonized_sample_results)
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.utils import PyObjectCache, async_tensor_h2d
from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
Expand Down Expand Up @@ -498,7 +498,7 @@ def execute_model(
# appended sampler output from last iteration
# - also maybe pythonize if CPU is ahead of GPU

current_stream = torch.cuda.current_stream()
stream = current_stream()
if not model_input.is_first_multi_step:
# Explicitly block on the previous step's forward to make sure we
# don't clobber any GPU tensors still in use.
Expand Down Expand Up @@ -541,7 +541,7 @@ def execute_model(
num_steps=1)

# record the event for the current step so that the next step can sync
model_input.record_step_event(current_stream)
model_input.record_step_event(stream)

if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
Expand All @@ -552,7 +552,7 @@ def execute_model(
# event for the pythonization so that we only pythonize if the
# tensors are ready. May be able to be combined with the step event
output_ready_event = torch.cuda.Event()
output_ready_event.record(current_stream)
output_ready_event.record(stream)
if self.parallel_config.pipeline_parallel_size > 1:
output[0].sampled_token_ids_cpu = output[
0].sampled_token_ids.cpu()
Expand Down

0 comments on commit 310aca8

Please sign in to comment.