From 8d27109d5d5e09a0c37a18210f9e0000d47c9968 Mon Sep 17 00:00:00 2001 From: Agoniii <815244047@qq.com> Date: Fri, 18 Oct 2024 03:33:39 +0000 Subject: [PATCH] attention_mask fill with -inf for UnfusedDotProductAttention Signed-off-by: Agoniii <815244047@qq.com> --- transformer_engine/pytorch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..3a9e3f9585 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -48,7 +48,7 @@ def attention_mask_func( attention_scores: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: """Get attention mask""" - attention_scores.masked_fill_(attention_mask, -10000.0) + attention_scores.masked_fill_(attention_mask, float("-inf")) return attention_scores