Optimize for topk=1 case if we do not handle duplicates (#603)
Original: #599

We have a case where topk=1 and topp<1.

This adds special handling for the case topk=1 with handle_duplicates=0
(handle_duplicates defaults to 0, to support num-scheduling-steps).
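As a hedged aside on why the shortcut is safe (my reading, not stated in the commit): once only the single max logit survives, softmax assigns it probability 1.0, so a subsequent top-p filter with any p <= 1 cannot remove it, and the full top-p/top-k machinery can be skipped. A minimal illustration:

```python
import torch

# Illustrative only: after masking everything but the max logit to -inf,
# softmax puts all probability mass on that one token, so top-p with
# p < 1 is trivially satisfied and filters nothing further.
masked = torch.tensor([[-float("inf"), 2.0, -float("inf")]])
probs = torch.softmax(masked, dim=1)
assert torch.allclose(probs, torch.tensor([[0.0, 1.0, 0.0]]))
```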
ssarkar2 authored Jan 7, 2025
1 parent 27a22ab commit 9d6917f
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -405,6 +405,14 @@ def __init__(self, increment: int):
         self._increment = increment

     def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                                    -float("inf"),
+                                    device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
         if k > ApplyToppTopkScalar._padded_k:
             ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                 logits.shape[1])
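The fast path above replaces a full top-k sort with a single `torch.max` plus `scatter_`. The standalone sketch below (my reconstruction, not the patched class itself) checks that for k=1 it produces the same masked logits as a generic topk-based mask, assuming no tied maxima per row:

```python
import torch

def topk1_fast_path(logits: torch.Tensor) -> torch.Tensor:
    # Fast path for k == 1: keep only each row's max logit,
    # set every other position to -inf.
    new_logits = torch.full(logits.shape,
                            -float("inf"),
                            device=logits.device)
    vals, idx = torch.max(logits, keepdim=True, dim=1)
    new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
    return new_logits

def topk_general(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Generic reference: keep the top-k logits per row, mask the rest.
    vals, idx = torch.topk(logits, k, dim=1)
    new_logits = torch.full_like(logits, -float("inf"))
    new_logits.scatter_(1, idx, vals)
    return new_logits

# No ties within a row, so both paths select the same index.
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.4]])
assert torch.equal(topk1_fast_path(logits), topk_general(logits, 1))
```

The win is that `torch.max` is O(n) per row and avoids the padded-k bookkeeping the general path maintains; it is only valid when duplicate handling is disabled, since with ties the chosen index need not match the general path's.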
