From e128b35bba33924a97df666490485547dd83f1e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Sat, 1 Jun 2024 20:15:35 +0800
Subject: [PATCH] fix llama

---
 src/models/llama.cpp | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index 33684196..57cce25a 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -522,10 +522,18 @@ namespace fastllm {
         for (int i = 0; i < batch; i++) {
             all1 &= (seqLens[i] == 1);
         }
-        for (int b = 0; b < batch; b++) {
-            contexts[b] = positionIds[b];
+        if (all1) {
+            for (int b = 0; b < batch; b++) {
+                contexts[b] = positionIds[b];
+            }
+            CatBatch(contexts, 1, allPositionIds);
+        } else {
+            allPositionIds.CopyFrom(*(Data*)positionIds[0]);
+            allPositionIds.Expansion({1, seqLen});
+            for (int i = 1; i < batch; i++) {
+                CatDirect(allPositionIds, *(Data*)positionIds[i], 1);
+            }
         }
-        CatBatch(contexts, 1, allPositionIds);
 
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
         ToDataType(hiddenStates, this->dataType);
@@ -710,7 +718,11 @@ namespace fastllm {
                 // 1.2 Attention
                 // 1.2.0 q * k^T
                 if (alibiData.dims.size() == 0) {
-                    Attention(q, pastKey, pastValue, attentionMask[b] == nullptr ? Data() : *attentionMask[b], curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    if (attentionMask[b] == nullptr) {
+                        Attention(q, pastKey, pastValue, Data(), curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    } else {
+                        Attention(q, pastKey, pastValue, *attentionMask[b], curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    }
                 } else {
                     MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
                     attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
@@ -732,6 +744,7 @@ namespace fastllm {
                         std::vector <int> dims = curAttenOutput.dims;
                         dims[1] = total;
                         attenOutput.Expansion(dims);
+                        attenOutput.ToDevice(q.dataDevice);
                     }
                     CatDirect(attenOutput, curAttenOutput, 1);
                 }