Skip to content

Commit

Permalink
Remove CPU sync before Sampler (HabanaAI#414)
Browse files Browse the repository at this point in the history
Currently before each Sampler call we have a CPU sync, which causes a
host gap:
<img width="226" alt="image"
src="https://github.com/user-attachments/assets/4509e69b-0f16-4ac9-812e-a2a9bc43a6ad">

This PR is removing that sync, so the host gap is no longer visible:
<img width="133" alt="image"
src="https://github.com/user-attachments/assets/66c19e4b-d832-4955-848d-8ae4acd8d264">

NOTE: class `ApplyToppTopkScalar` still has some CPU syncs inside. It
means that the biggest gain will be observed in the scenario without
`top_p` or `top_k` parameters. I think it is worth to investigate if we
can remove the syncs from this function too.
  • Loading branch information
kdamaszk authored Oct 22, 2024
1 parent aecd667 commit 0cf5261
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,13 @@ def _init_sampling_tensors(
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p
self._top_p_scalar = sampling_tensors.top_ps[0].item()
self._top_k_scalar = sampling_tensors.top_ks[0].item()
self._top_p_scalar = sampling_tensors.top_ps[0]
self._top_k_scalar = sampling_tensors.top_ks[0]
scalar_p = torch.all(sampling_tensors.top_ps == self._top_p_scalar)
scalar_k = torch.all(sampling_tensors.top_ks == self._top_k_scalar)
self._scalar_p_and_k = (scalar_p and scalar_k).item()
if self._scalar_p_and_k and self._do_top_p_top_k:
self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
self._scalar_p_and_k = torch.logical_and(scalar_p, scalar_k)

self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)

def forward(
self,
Expand Down Expand Up @@ -266,13 +266,13 @@ def forward(
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
if self._scalar_p_and_k:
logits = self._apply_top_k_top_p_opt(logits,
self._top_p_scalar,
self._top_k_scalar)
else:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
# If we have a scalar p and k, we can use the optimized version.
logits = torch.where(
self._scalar_p_and_k,
self._apply_top_k_top_p_opt(logits, self._top_p_scalar,
self._top_k_scalar),
_apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks))

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
Expand Down

0 comments on commit 0cf5261

Please sign in to comment.