From 3e135aea80d463d85416c08c9e0bf12d08f3ae3b Mon Sep 17 00:00:00 2001 From: yuwenzho Date: Tue, 29 Oct 2024 14:41:07 +0800 Subject: [PATCH] Fix one_hot bug in torch compile mode (#427) Fix one_hot bug in torch compile mode ``` > block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, num_classes=batch_size) E RuntimeError: Class values must be non-negative. ../../vllm/worker/hpu_model_runner.py:311: RuntimeError ``` --- vllm/worker/hpu_model_runner.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index b5100491c4135..78e8620d7c43c 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -318,18 +318,19 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - if is_fake_hpu(): - # Unfortunately one_hot on CPU doesn't handle - # out of bounds classes. We need to mask those - # values manually - oob_values = metadata.block_mapping.lt(0) - block_mapping = metadata.block_mapping.masked_fill(oob_values, 0) - block_mapping = torch.nn.functional.one_hot(block_mapping, + + if not is_fake_hpu() and htorch.utils.internal.is_lazy(): + block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, num_classes=batch_size) - block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) else: - block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + # Unfortunately one_hot on CPU/torch.compile mode/eager mode + # doesn't handle out of bounds classes, + # so we convert all negative values to 0. + block_mapping = torch.nn.functional.relu(metadata.block_mapping) + block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) + oob_values = metadata.block_mapping.lt(0) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias)