From 9d6917f29be4f9b0193c423c69db98f2b3ac8795 Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Tue, 7 Jan 2025 01:37:26 -0800
Subject: [PATCH] Optimize for topk=1 case if we do not handle duplicates
 (#603)

Original: https://github.com/HabanaAI/vllm-fork/pull/599

We have a case where topk=1 and topp < 1. This adds special handling for
topk=1 when handle_duplicates=0 (the default, in order to support
num-scheduler-steps).

---
 vllm/model_executor/layers/sampler.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b61fb0e47d07a..ce6ec1a89ff87 100755
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -405,6 +405,14 @@ def __init__(self, increment: int):
         self._increment = increment
 
     def __call__(self, logits: torch.Tensor, p: float, k: int):
+        if k == 1 and not ApplyToppTopkScalar._handle_duplicates:
+            new_logits = torch.full(logits.shape,
+                                    -float("inf"),
+                                    device=logits.device)
+            vals, idx = torch.max(logits, keepdim=True, dim=1)
+            new_logits.scatter_(1, idx, vals.to(new_logits.dtype))
+            return new_logits
+
        if k > ApplyToppTopkScalar._padded_k:
            ApplyToppTopkScalar._padded_k = min(k + self._increment,
                                                logits.shape[1])
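
Below is a minimal standalone sketch of why this fast path is safe: with
topk=1, top-k keeps only the argmax logit, and top-p always retains at
least the highest-probability token, so any topp < 1 cannot change the
result. The helper name topk1_fast_path and the shapes are illustrative,
not vLLM internals; torch.full_like is used here instead of the patch's
torch.full(..., device=...) so the sketch also preserves the input dtype.

    import torch


    def topk1_fast_path(logits: torch.Tensor) -> torch.Tensor:
        """Keep each row's max logit and set every other entry to -inf."""
        new_logits = torch.full_like(logits, -float("inf"))
        # torch.max returns a single index per row, so ties between
        # duplicate max values are broken arbitrarily -- which is why the
        # patch only takes this path when _handle_duplicates is disabled.
        vals, idx = torch.max(logits, dim=1, keepdim=True)
        new_logits.scatter_(1, idx, vals)
        return new_logits


    if __name__ == "__main__":
        logits = torch.randn(4, 32000)  # (batch, vocab)
        out = topk1_fast_path(logits)
        probs = torch.softmax(out, dim=1)
        # The result is a one-hot distribution on the argmax: exactly what
        # top-k=1 followed by any top-p filter would have produced.
        assert torch.allclose(probs.max(dim=1).values, torch.ones(4))
        assert (probs.argmax(dim=1) == logits.argmax(dim=1)).all()

This skips the general path's padded topk/sort work entirely, replacing it
with a single max and scatter per batch.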