fix llama
黄宇扬 committed Jun 1, 2024
1 parent ee76be1 commit e128b35
Showing 1 changed file with 17 additions and 4 deletions.
src/models/llama.cpp: 21 changes (17 additions & 4 deletions)
@@ -522,10 +522,18 @@ namespace fastllm {
             all1 &= (seqLens[i] == 1);
         }
 
-        for (int b = 0; b < batch; b++) {
-            contexts[b] = positionIds[b];
+        if (all1) {
+            for (int b = 0; b < batch; b++) {
+                contexts[b] = positionIds[b];
+            }
+            CatBatch(contexts, 1, allPositionIds);
+        } else {
+            allPositionIds.CopyFrom(*(Data*)positionIds[0]);
+            allPositionIds.Expansion({1, seqLen});
+            for (int i = 1; i < batch; i++) {
+                CatDirect(allPositionIds, *(Data*)positionIds[i], 1);
+            }
         }
-        CatBatch(contexts, 1, allPositionIds);
 
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
         ToDataType(hiddenStates, this->dataType);
@@ -710,7 +718,11 @@ namespace fastllm {
                 // 1.2 Attention
                 // 1.2.0 q * k^T
                 if (alibiData.dims.size() == 0) {
-                    Attention(q, pastKey, pastValue, attentionMask[b] == nullptr ? Data() : *attentionMask[b], curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    if (attentionMask[b] == nullptr) {
+                        Attention(q, pastKey, pastValue, Data(), curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    } else {
+                        Attention(q, pastKey, pastValue, *attentionMask[b], curAttenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                    }
                 } else {
                     MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
                     attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
@@ -732,6 +744,7 @@ namespace fastllm {
                         std::vector <int> dims = curAttenOutput.dims;
                         dims[1] = total;
                         attenOutput.Expansion(dims);
+                        attenOutput.ToDevice(q.dataDevice);
                     }
                     CatDirect(attenOutput, curAttenOutput, 1);
                 }
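
Note: a plausible motivation for splitting the Attention call above is that the original ternary attentionMask[b] == nullptr ? Data() : *attentionMask[b] yields a prvalue, so the existing mask gets copied into a temporary Data before binding to the function parameter; the if/else passes the stored mask by reference and only constructs an empty Data when there is no mask. A minimal standalone sketch of the effect (hypothetical Mask/attend stand-ins, not fastllm code):

    #include <iostream>

    // Stand-in for a heavy tensor type such as fastllm's Data.
    struct Mask {
        Mask() { std::cout << "default-construct\n"; }
        Mask(const Mask&) { std::cout << "copy-construct\n"; }  // expensive for a real tensor
    };

    // Stand-in for Attention(..., const Data &mask, ...).
    void attend(const Mask& /*mask*/) {}

    int main() {
        Mask stored;            // prints "default-construct"
        Mask* maybe = &stored;

        // The conditional operator produces a prvalue, so *maybe is
        // copy-constructed into a temporary before the call.
        attend(maybe == nullptr ? Mask() : *maybe);   // prints "copy-construct"

        // Splitting the branches binds the existing object directly;
        // an empty Mask is built only when there is no mask at all.
        if (maybe == nullptr) {
            attend(Mask());
        } else {
            attend(*maybe);     // no copy
        }
        return 0;
    }

Run as written, the first call prints copy-construct while the split version does not, which is consistent with the commit's switch to explicit branches.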