diff --git a/src/models/xlmroberta.cpp b/src/models/xlmroberta.cpp index 09cec7b..77d2b71 100644 --- a/src/models/xlmroberta.cpp +++ b/src/models/xlmroberta.cpp @@ -112,7 +112,9 @@ namespace fastllm { PermuteSelf(k, {0, 2, 1, 3}); PermuteSelf(v, {0, 2, 1, 3}); - if (bsz == 1) { + if (false) { + // TODO: 这里使用的AttentionMask不是因果Mask,无法直接调用Attention函数 + // 后续需要修改AttentionMask使得可以直接调用Attention函数 q.Reshape({-1, q.dims[2], q.dims[3]}); k.Reshape({-1, k.dims[2], k.dims[3]}); v.Reshape({-1, v.dims[2], v.dims[3]});