From 6ea5e3cba919cfdc38ac50320ddf189306f332b1 Mon Sep 17 00:00:00 2001
From: 黄宇扬
Date: Sat, 18 May 2024 17:42:42 +0800
Subject: [PATCH] Support storing the previous pastKeyCache in chat mode
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore               |  6 +---
 include/models/basellm.h |  7 +++++
 main.cpp                 |  3 +-
 src/models/basellm.cpp   | 67 ++++++++++++++++++++++++++++++++--------
 4 files changed, 64 insertions(+), 19 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2a2949a5..68833c74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,11 +2,7 @@
 *.pyc
 token
 /cmake-build-debug/
-/build-tfacc/
-/build-android/
-/build-x86/
-/build-py/
-/build/
+/build*
 /pyfastllm/build/
 /pyfastllm/dist/
 /.idea/
diff --git a/include/models/basellm.h b/include/models/basellm.h
index 61302542..e010188a 100644
--- a/include/models/basellm.h
+++ b/include/models/basellm.h
@@ -128,6 +128,8 @@ namespace fastllm {
 
         virtual void DisableAdapter();
 
+        virtual bool SetSaveHistoryChat(bool save);
+
         std::string model_type;
 
         std::string pre_prompt; // 最初对话的提示语
@@ -159,6 +161,11 @@ namespace fastllm {
 
         std::string adapterName;
 
         int tokensLimit = -1;
+
+        std::string lastPrompt = "";
+        std::vector <std::pair <Data, Data> > *lastKeyValues = nullptr;
+        int lastPromptTokens = 0;
+        bool saveHistoryChat = false;
     };
 }
diff --git a/main.cpp b/main.cpp
index e0bf06d7..4d4dee73 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,4 +1,4 @@
-#include "model.h" 
+#include "model.h"
 
 struct RunConfig {
     std::string path = "chatglm-6b-int4.bin"; // 模型文件路径
@@ -62,6 +62,7 @@ int main(int argc, char **argv) {
     fastllm::SetThreads(config.threads);
     fastllm::SetLowMemMode(config.lowMemMode);
     auto model = fastllm::CreateLLMModelFromFile(config.path);
+    model->SetSaveHistoryChat(true);
 
     static std::string modelType = model->model_type;
     printf("欢迎使用 %s 模型. 输入内容对话,reset清空历史记录,stop退出程序.\n", model->model_type.c_str());
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index 040c8dd3..3cc7e5b9 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -56,8 +56,30 @@ namespace fastllm {
         preTokens = 0;
     }
 
-    std::string basellm::Response(const std::string &input, RuntimeResult retCb,
+    std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
                                   const fastllm::GenerationConfig &generationConfig) {
+        std::string input = oriInput;
+        if (this->saveHistoryChat) {
+            if (lastKeyValues != nullptr) {
+                if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
+                    lastPrompt = "";
+                    lastPromptTokens = 0;
+                    delete lastKeyValues;
+                    lastKeyValues = nullptr;
+                } else {
+                    input = input.substr(lastPrompt.size());
+                }
+            }
+        } else {
+            lastPrompt = "";
+            lastPromptTokens = 0;
+            delete lastKeyValues;
+            lastKeyValues = nullptr;
+        }
+
+        //printf("lastPrompt = %s\n", lastPrompt.c_str());
+        //printf("input = %s\n", input.c_str());
+
 #ifdef USE_CUDA
         FastllmCudaClearBigBuffer();
 #endif
@@ -77,18 +99,21 @@
         }
 
         DataType testDataType = DataType::FLOAT32;
-        std::vector <std::pair <Data, Data> > pastKeyValues;
-        for (int i = 0; i < block_cnt; i++) {
-            pastKeyValues.push_back(std::make_pair(Data(testDataType),
-                                                   Data(testDataType)));
-            pastKeyValues.back().first.SetKVCache();
-            pastKeyValues.back().second.SetKVCache();
+        if (lastKeyValues == nullptr) {
+            lastKeyValues = new std::vector <std::pair <Data, Data> >();
+            for (int i = 0; i < block_cnt; i++) {
+                lastKeyValues->push_back(std::make_pair(Data(testDataType),
+                                                        Data(testDataType)));
+                lastKeyValues->back().first.SetKVCache();
+                lastKeyValues->back().second.SetKVCache();
+            }
         }
+        std::vector <std::pair <Data, Data> > &pastKeyValues = (*lastKeyValues);
 
         std::string retString = "";
         std::vector <float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
-        int promptLen = inputTokens[0].size(), index = 0;
+        int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
         FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
         while (true) {
             auto st = std::chrono::system_clock::now();
@@ -141,6 +166,9 @@
 #else
                 retCb(-1, retString.c_str());
 #endif
+
+        lastPrompt += (input + retString);
+        lastPromptTokens = promptLen + index;
         return retString;
     }
 
@@ -784,19 +812,19 @@ printf("tot = %d\n", tot);
         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
 
-        if (index == 0) {
+        if (inputTokens[0].size() > 1) {
             int seqLen = inputTokens[0].size();
-            std::vector <float> vmask = std::vector <float> (seqLen * seqLen, 0);
+            std::vector <float> vmask = std::vector <float> (seqLen * promptLen, 0);
             std::vector <float> vpids = std::vector <float> (seqLen, 0);
             for (int i = 0; i < seqLen; i++) {
-                vpids[i] = i;
+                vpids[i] = promptLen - seqLen + i;
                 for (int j = i + 1; j < seqLen; j++) {
-                    vmask[i * seqLen + j] = 1;
+                    vmask[i * promptLen + (promptLen - seqLen + j)] = 1;
                 }
             }
 
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, inputTokens[0]));
-            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, seqLen}, vmask));
+            attentionMask.CopyFrom(Data(DataType::FLOAT32, {seqLen, promptLen}, vmask));
             positionIds.CopyFrom(Data(DataType::FLOAT32, {1, seqLen}, vpids));
         } else {
             inputIds.CopyFrom(Data(DataType::FLOAT32, {1, 1}, inputTokens[0]));
@@ -821,4 +849,17 @@ printf("tot = %d\n", tot);
     void basellm::DisableAdapter() {
         adapterName = "";
     }
+
+    bool basellm::SetSaveHistoryChat(bool save) {
+        if (this->model_type == "llama" ||
+            this->model_type == "moe" ||
+            this->model_type == "internlm" ||
+            this->model_type == "qwen2_moe" ||
+            this->model_type == "deepseek_v2" ||
+            this->model_type == "qwen") {
+            this->saveHistoryChat = save;
+            return true;
+        }
+        return false;
+    }
 }
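
Usage note (not part of the patch): with SetSaveHistoryChat(true), Response() keeps the per-layer pastKeyValues alive between calls and, when the new prompt starts with the previously submitted prompt plus the previous reply, only prefills the newly appended tokens; otherwise it deletes the cache and falls back to a full prefill. A minimal caller sketch follows, assuming the existing fastllm API as used in main.cpp (CreateLLMModelFromFile, MakeInput/MakeHistory, the RuntimeResult callback); the model path and the exact loop are illustrative assumptions, not part of this change.

    #include "model.h"
    #include <cstdio>
    #include <iostream>
    #include <string>

    int main() {
        fastllm::SetThreads(4);
        // Hypothetical model path; any model type accepted by SetSaveHistoryChat() works.
        auto model = fastllm::CreateLLMModelFromFile("qwen-7b-int4.flm");
        if (!model->SetSaveHistoryChat(true)) {
            printf("This model type does not cache history; every turn will do a full prefill.\n");
        }

        std::string history = "";
        int round = 0;
        while (true) {
            std::string input;
            if (!std::getline(std::cin, input) || input == "stop") break;

            // Build the full prompt from the whole conversation so far. For the supported
            // chat templates this extends the previous prompt + reply, so Response() can
            // reuse the cached pastKeyValues instead of re-encoding the history.
            std::string prompt = model->MakeInput(history, round, input);
            std::string reply = model->Response(prompt, [](int index, const char *content) {
                if (index >= 0) { printf("%s", content); fflush(stdout); }
                else { printf("\n"); }
            }, fastllm::GenerationConfig());

            // Fold this turn back into the running history for the next round.
            history = model->MakeHistory(history, round, input, reply);
            round++;
        }
        return 0;
    }

The design choice this illustrates: the caller still passes the complete conversation text every turn, and the prefix check inside Response() decides transparently whether the stored cache can be reused, so existing callers keep working unchanged when history caching is off or unsupported.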