Support storing the previous pastKeyCache in conversation mode
黄宇扬 committed May 18, 2024
1 parent 5a2cd63 commit 6ea5e3c
Showing 4 changed files with 64 additions and 19 deletions.
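
The point of the change: in chat mode the model now keeps the KV cache (lastKeyValues) and the already-processed prompt (lastPrompt) between calls, so each new round only prefills the tokens appended since the previous call. Below is a minimal sketch of how the new flag is meant to be driven from a chat loop, following the calls visible in the diffs (CreateLLMModelFromFile, SetSaveHistoryChat, Response, and the retCb(index, content) callback shape); the model path, the plain string concatenation without any chat template, and the callback body are illustrative assumptions, not part of this commit.

#include "model.h"   // fastllm entry point, as included in main.cpp below
#include <cstdio>
#include <iostream>
#include <string>

int main() {
    // Hypothetical model file; any model of a type listed in SetSaveHistoryChat() works.
    auto model = fastllm::CreateLLMModelFromFile("qwen-7b-int4.flm");
    // Returns false for model types that do not support cached history.
    model->SetSaveHistoryChat(true);

    fastllm::GenerationConfig config;
    std::string conversation;   // full prompt so far; must keep growing as a prefix

    for (std::string user; std::getline(std::cin, user) && user != "stop"; ) {
        conversation += user;
        // Response() strips the cached lastPrompt prefix and only runs the new
        // tokens through the model, reusing lastKeyValues for the old ones.
        std::string reply = model->Response(conversation,
            [](int index, const char *content) {
                if (index == -1) { printf("\n"); } else { printf("%s", content); }
                fflush(stdout);
            }, config);
        conversation += reply;  // keep the reply so the next prompt still starts with lastPrompt
    }
    return 0;
}

If the new prompt no longer starts with lastPrompt (for example after a reset), Response() simply drops the cache and prefills from scratch, as the basellm.cpp hunks below show.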
6 changes: 1 addition & 5 deletions .gitignore
@@ -2,11 +2,7 @@
 *.pyc
 token
 /cmake-build-debug/
-/build-tfacc/
-/build-android/
-/build-x86/
-/build-py/
-/build/
+/build*
 /pyfastllm/build/
 /pyfastllm/dist/
 /.idea/
7 changes: 7 additions & 0 deletions include/models/basellm.h
@@ -128,6 +128,8 @@ namespace fastllm {
 
         virtual void DisableAdapter();
 
+        virtual bool SetSaveHistoryChat(bool save);
+
         std::string model_type;
 
         std::string pre_prompt; // prompt for the very first round of the conversation
@@ -159,6 +161,11 @@ namespace fastllm {
         std::string adapterName;
 
         int tokensLimit = -1;
+
+        std::string lastPrompt = "";
+        std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;
+        int lastPromptTokens = 0;
+        bool saveHistoryChat = false;
     };
 }

3 changes: 2 additions & 1 deletion main.cpp
@@ -1,4 +1,4 @@
-#include "model.h"
+#include "model.h"
 
 struct RunConfig {
     std::string path = "chatglm-6b-int4.bin"; // model file path
@@ -62,6 +62,7 @@ int main(int argc, char **argv) {
     fastllm::SetThreads(config.threads);
     fastllm::SetLowMemMode(config.lowMemMode);
     auto model = fastllm::CreateLLMModelFromFile(config.path);
+    model->SetSaveHistoryChat(true);
 
     static std::string modelType = model->model_type;
     printf("Welcome to the %s model. Type a message to chat, 'reset' to clear the history, 'stop' to quit.\n", model->model_type.c_str());
67 changes: 54 additions & 13 deletions src/models/basellm.cpp
@@ -56,8 +56,30 @@ namespace fastllm {
         preTokens = 0;
     }
 
-    std::string basellm::Response(const std::string &input, RuntimeResult retCb,
+    std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
                                   const fastllm::GenerationConfig &generationConfig) {
+        std::string input = oriInput;
+        if (this->saveHistoryChat) {
+            if (lastKeyValues != nullptr) {
+                if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
+                    lastPrompt = "";
+                    lastPromptTokens = 0;
+                    delete lastKeyValues;
+                    lastKeyValues = nullptr;
+                } else {
+                    input = input.substr(lastPrompt.size());
+                }
+            }
+        } else {
+            lastPrompt = "";
+            lastPromptTokens = 0;
+            delete lastKeyValues;
+            lastKeyValues = nullptr;
+        }
+
+        //printf("lastPrompt = %s\n", lastPrompt.c_str());
+        //printf("input = %s\n", input.c_str());
+
 #ifdef USE_CUDA
         FastllmCudaClearBigBuffer();
 #endif
@@ -77,18 +99,21 @@ namespace fastllm {
         }
 
         DataType testDataType = DataType::FLOAT32;
-        std::vector<std::pair<Data, Data> > pastKeyValues;
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.push_back(std::make_pair(Data(testDataType),
-                                                   Data(testDataType)));
-            pastKeyValues.back().first.SetKVCache();
-            pastKeyValues.back().second.SetKVCache();
+        if (lastKeyValues == nullptr) {
+            lastKeyValues = new std::vector<std::pair<Data, Data> >();
+            for (int i = 0; i < block_cnt; i++) {
+                lastKeyValues->push_back(std::make_pair(Data(testDataType),
+                                                        Data(testDataType)));
+                lastKeyValues->back().first.SetKVCache();
+                lastKeyValues->back().second.SetKVCache();
+            }
         }
 
+        std::vector<std::pair<Data, Data> > &pastKeyValues = (*lastKeyValues);
         std::string retString = "";
         std::vector<float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
-        int promptLen = inputTokens[0].size(), index = 0;
+        int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
         while (true) {
             auto st = std::chrono::system_clock::now();
@@ -141,6 +166,9 @@ namespace fastllm {
 #else
         retCb(-1, retString.c_str());
 #endif
+
+        lastPrompt += (input + retString);
+        lastPromptTokens = promptLen + index;
         return retString;
     }

@@ -784,19 +812,19 @@ printf("tot = %d\n", tot);
         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
 
-        if (index == 0) {
+        if (inputTokens[0].size() > 1) {
             int seqLen = inputTokens[0].size();
 
-            std::vector <float> vmask = std::vector <float> (seqLen * seqLen, 0);
+            std::vector <float> vmask = std::vector <float> (seqLen * promptLen, 0);
             std::vector <float> vpids = std::vector <float> (seqLen, 0);
             for (int i = 0; i < seqLen; i++) {
-                vpids[i] = i;
+                vpids[i] = promptLen - seqLen + i;
                 for (int j = i + 1; j < seqLen; j++) {
-                    vmask[i * seqLen + j] = 1;
+                    vmask[i * promptLen + (promptLen - seqLen + j)] = 1;
                 }
             }
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
-            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
+            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask));
             positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids));
         } else {
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
@@ -821,4 +849,17 @@
     void basellm::DisableAdapter() {
         adapterName = "";
     }
+
+    bool basellm::SetSaveHistoryChat(bool save) {
+        if (this->model_type == "llama" ||
+            this->model_type == "moe" ||
+            this->model_type == "internlm" ||
+            this->model_type == "qwen2_moe" ||
+            this->model_type == "deepseek_v2" ||
+            this->model_type == "qwen") {
+            this->saveHistoryChat = save;
+            return true;
+        }
+        return false;
+    }
 }
