From 5912fd341efae2125f8c86fcc0f22efe2adc3226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Wed, 25 Sep 2024 16:54:20 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9A=82=E6=97=B6=E6=B3=A8=E9=87=8Axlmroberta?= =?UTF-8?q?=E4=B8=AD=E7=9A=84attention=E5=8A=A0=E9=80=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/models/xlmroberta.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/models/xlmroberta.cpp b/src/models/xlmroberta.cpp index 09cec7b..77d2b71 100644 --- a/src/models/xlmroberta.cpp +++ b/src/models/xlmroberta.cpp @@ -112,7 +112,9 @@ namespace fastllm { PermuteSelf(k, {0, 2, 1, 3}); PermuteSelf(v, {0, 2, 1, 3}); - if (bsz == 1) { + if (false) { + // TODO: 这里使用的AttentionMask不是因果Mask,无法直接调用Attention函数 + // 后续需要修改AttentionMask使得可以直接调用Attention函数 q.Reshape({-1, q.dims[2], q.dims[3]}); k.Reshape({-1, k.dims[2], k.dims[3]}); v.Reshape({-1, v.dims[2], v.dims[3]});