
Commit

Optimize the session cache and update the main program (support loading Hugging Face models directly, setting a system prompt, eos_token, etc.)
黄宇扬 committed May 31, 2024
1 parent 41c45db commit 1d373ba
Showing 4 changed files with 200 additions and 41 deletions.
38 changes: 35 additions & 3 deletions include/models/basellm.h
@@ -48,6 +48,40 @@ namespace fastllm {
void RemoveHandle(int handleId);
};

struct PastKVCacheMemory {
std::string prompt;
int tokens;
int recordTimes = 0;
long long flushTime;
std::vector<std::pair<Data, Data> > kv;

PastKVCacheMemory () {}

PastKVCacheMemory (const std::string &prompt, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv);
};

struct PastKVCacheManager {
std::mutex locker;
int maxRecordNum = 5;
long long flushTime = 0;
std::map <std::string, PastKVCacheMemory*> memorys;

// Set the maximum number of records to keep
void SetMaxRecordNum(int maxRecordNum);

// Insert a record; if it already exists, increase its reference count instead
void Record(const std::string &prompt, int tokens, std::vector<std::pair<Data, Data> > *kv);

// Try to remove a record; it is not actually deleted while its reference count is non-zero
void Remove(std::string prompt);

// Get the Memory with the longest matching prefix, and take the lock
PastKVCacheMemory *Get(const std::string &prompt);

// Release the lock
void Unlock();
};

class basellm {
public:
basellm() {};
@@ -176,9 +210,7 @@ namespace fastllm {

int tokensLimit = -1;

std::string lastPrompt = "";
std::vector<std::pair<Data, Data> > *lastKeyValues = nullptr;
int lastPromptTokens = 0;
PastKVCacheManager pastKVCacheManager;
bool saveHistoryChat = false;

DataType dataType = DataType::FLOAT32;
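For illustration, here is a minimal standalone sketch of the longest-prefix lookup that PastKVCacheManager::Get declares above. It is not the fastllm implementation: the Data KV tensors, mutex, and flushTime bookkeeping are omitted, and the cached-prompt map is a simplified stand-in.

#include <iostream>
#include <map>
#include <string>

// Return the longest cached prompt that is a prefix of `prompt`,
// or an empty string when nothing matches.
std::string LongestCachedPrefix(const std::map<std::string, int> &cache,
                                const std::string &prompt) {
    std::string best;
    for (const auto &it : cache) {
        const std::string &cur = it.first;
        if (cur.size() > best.size() && cur.size() <= prompt.size() &&
            prompt.compare(0, cur.size(), cur) == 0) {
            best = cur;
        }
    }
    return best;
}

int main() {
    std::map<std::string, int> cache = {
        {"<system>hello", 12},
        {"<system>hello<user>hi", 20}
    };
    // The second entry wins: it is the longest cached prefix of the new prompt.
    std::cout << LongestCachedPrefix(cache, "<system>hello<user>hi<user>again") << "\n";
    return 0;
}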
53 changes: 44 additions & 9 deletions main.cpp
@@ -1,9 +1,24 @@
#include "model.h"

std::map <std::string, fastllm::DataType> dataTypeDict = {
{"float32", fastllm::DataType::FLOAT32},
{"half", fastllm::DataType::FLOAT16},
{"float16", fastllm::DataType::FLOAT16},
{"int8", fastllm::DataType::INT8},
{"int4", fastllm::DataType::INT4_NOZERO},
{"int4z", fastllm::DataType::INT4},
{"int4g", fastllm::DataType::INT4_GROUP}
};

struct RunConfig {
std::string path = "chatglm-6b-int4.bin"; // path to the model file
std::string systemPrompt = "";
std::set <std::string> eosToken;
int threads = 4; // number of threads to use
bool lowMemMode = false; // whether to use low-memory mode

fastllm::DataType dtype = fastllm::DataType::FLOAT16;
int groupCnt = -1;
};

void Usage() {
@@ -12,6 +27,9 @@ void Usage() {
std::cout << "<-p|--path> <args>: 模型文件的路径" << std::endl;
std::cout << "<-t|--threads> <args>: 使用的线程数量" << std::endl;
std::cout << "<-l|--low>: 使用低内存模式" << std::endl;
std::cout << "<--system> <args>: 设置系统提示词(system prompt)" << std::endl;
std::cout << "<--eos_token> <args>: 设置eos token" << std::endl;
std::cout << "<--dtype> <args>: 设置权重类型(读取hf文件时生效)" << std::endl;
std::cout << "<--top_p> <args>: 采样参数top_p" << std::endl;
std::cout << "<--top_k> <args>: 采样参数top_k" << std::endl;
std::cout << "<--temperature> <args>: 采样参数温度,越高结果越不固定" << std::endl;
@@ -43,6 +61,19 @@ void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConf
generationConfig.temperature = atof(sargv[++i].c_str());
} else if (sargv[i] == "--repeat_penalty") {
generationConfig.repeat_penalty = atof(sargv[++i].c_str());
} else if (sargv[i] == "--system") {
config.systemPrompt = sargv[++i];
} else if (sargv[i] == "--eos_token") {
config.eosToken.insert(sargv[++i]);
} else if (sargv[i] == "--dtype") {
std::string dtypeStr = sargv[++i];
if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") {
config.groupCnt = atoi(dtypeStr.substr(5).c_str());
dtypeStr = dtypeStr.substr(0, 5);
}
fastllm::AssertInFastLLM(dataTypeDict.find(dtypeStr) != dataTypeDict.end(),
"Unsupport data type: " + dtypeStr);
config.dtype = dataTypeDict[dtypeStr];
} else {
Usage();
exit(-1);
@@ -51,34 +82,39 @@ void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConf
}

int main(int argc, char **argv) {
int round = 0;
std::string history = "";

RunConfig config;
fastllm::GenerationConfig generationConfig;
ParseArgs(argc, argv, config, generationConfig);

fastllm::PrintInstructionInfo();
fastllm::SetThreads(config.threads);
fastllm::SetLowMemMode(config.lowMemMode);
auto model = fastllm::CreateLLMModelFromFile(config.path);
bool isHFDir = access((config.path + "/config.json").c_str(), R_OK) == 0 || access((config.path + "config.json").c_str(), R_OK) == 0;
auto model = !isHFDir ? fastllm::CreateLLMModelFromFile(config.path) : fastllm::CreateLLMModelFromHF(config.path, config.dtype, config.groupCnt);
model->SetSaveHistoryChat(true);

for (auto &it : config.eosToken) {
generationConfig.stop_token_ids.insert(model->weight.tokenizer.GetTokenId(it));
}
std::string systemConfig = config.systemPrompt;
fastllm::ChatMessages messages = {{"system", systemConfig}};

static std::string modelType = model->model_type;
printf("欢迎使用 %s 模型. 输入内容对话,reset清空历史记录,stop退出程序.\n", model->model_type.c_str());

while (true) {
printf("用户: ");
std::string input;
std::getline(std::cin, input);
if (input == "reset") {
history = "";
round = 0;
messages = {{"system", config.systemPrompt}};
continue;
}
if (input == "stop") {
break;
}
std::string ret = model->Response(model->MakeInput(history, round, input), [](int index, const char* content) {
messages.push_back(std::make_pair("user", input));
std::string ret = model->Response(model->ApplyChatTemplate(messages), [](int index, const char* content) {
if (index == 0) {
printf("%s:%s", modelType.c_str(), content);
fflush(stdout);
@@ -91,8 +127,7 @@ int main(int argc, char **argv) {
printf("\n");
}
}, generationConfig);
history = model->MakeHistory(history, round, input, ret);
round++;
messages.push_back(std::make_pair("assistant", ret));
}

return 0;
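For illustration, a standalone sketch of how the new --dtype handling above splits an argument such as "int4g128" into a base type name plus a group count. The fastllm::DataType lookup is replaced with plain strings so the snippet compiles on its own; the value "int4g128" is just an assumed example.

#include <cstdlib>
#include <iostream>
#include <string>

int main() {
    std::string dtypeStr = "int4g128";   // e.g. the value passed after --dtype
    int groupCnt = -1;                   // -1 means "no group quantization"
    if (dtypeStr.size() > 5 && dtypeStr.substr(0, 5) == "int4g") {
        groupCnt = std::atoi(dtypeStr.substr(5).c_str());  // "128" -> 128
        dtypeStr = dtypeStr.substr(0, 5);                   // keep the "int4g" base name
    }
    std::cout << "dtype = " << dtypeStr << ", groupCnt = " << groupCnt << "\n";
    return 0;
}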
12 changes: 9 additions & 3 deletions src/fastllm.cpp
@@ -299,7 +299,7 @@ namespace fastllm {
this->cacheUid = ori.cacheUid;

// std::cout << "copy constructor called" << std::endl;
if (ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
if (ori.expansionDims != this->expansionDims || ori.dims != this->dims || this->cpuData == nullptr || ori.dataType != this->dataType) {
if (ori.dims.size() == 0) {
delete[] this->cpuData;
this->dataType = ori.dataType;
@@ -309,8 +309,14 @@ namespace fastllm {
return;
}
this->dataType = ori.dataType;
this->Resize(ori.dims);
this->Allocate();
if (ori.expansionDims.size() > 0 && ori.expansionDims != ori.dims) {
this->Expansion(ori.expansionDims);
this->Resize(ori.dims);
this->Allocate();
} else {
this->Resize(ori.dims);
this->Allocate();
}
}
std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
}
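As a rough standalone analogy for the CopyFrom change above, where a source tensor may carry reserved expansion capacity (expansionDims) beyond its logical dims: the copy now reserves that capacity first and then sizes the data, much like reserving a std::vector before resizing it. This only illustrates the reserve-then-size idea and is not the fastllm Data implementation.

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    std::vector<float> src;
    src.reserve(1024);            // "expansionDims": capacity reserved for future tokens
    src.resize(640, 1.0f);        // "dims": the data actually present

    std::vector<float> dst;
    dst.reserve(src.capacity());  // copy the reserved capacity as well...
    dst.resize(src.size());       // ...then match the logical size
    std::copy(src.begin(), src.end(), dst.begin());

    // Prints "640 / <capacity>"; the capacity is at least 1024.
    std::cout << dst.size() << " / " << dst.capacity() << "\n";
    return 0;
}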
138 changes: 112 additions & 26 deletions src/models/basellm.cpp
@@ -55,26 +55,94 @@ namespace fastllm {
isEnding = false;
preTokens = 0;
}

PastKVCacheMemory::PastKVCacheMemory(const std::string &prompt, int tokens, long long flushTime, std::vector<std::pair<Data, Data> > *kv) {
this->prompt = prompt;
this->tokens = tokens;
this->flushTime = flushTime;
this->recordTimes = 1;
auto dataType = (*kv)[0].first.dataType;
for (int i = 0; i < kv->size(); i++) {
this->kv.push_back(std::make_pair(Data(dataType), Data(dataType)));
}
for (int i = 0; i < kv->size(); i++) {
this->kv[i].first.CopyFrom((*kv)[i].first);
this->kv[i].second.CopyFrom((*kv)[i].second);
}
}

void PastKVCacheManager::SetMaxRecordNum(int maxRecordNum) {
std::lock_guard <std::mutex> lock(this->locker);
this->maxRecordNum = maxRecordNum;
}

void PastKVCacheManager::Record(const std::string &prompt, int tokens, std::vector<std::pair<Data, Data> > *kv) {
std::lock_guard <std::mutex> lock(this->locker);
if (this->memorys.find(prompt) != this->memorys.end()) {
this->memorys[prompt]->recordTimes++;
this->memorys[prompt]->flushTime = ++flushTime;
return;
}

if (this->memorys.size() >= this->maxRecordNum) {
std::string prompt = "";
long long minFlushTime = (1LL << 60);
for (auto &it : this->memorys) {
if (it.second->flushTime < minFlushTime) {
minFlushTime = it.second->flushTime;
prompt = it.first;
}
}
delete this->memorys[prompt];
this->memorys.erase(this->memorys.find(prompt));
}

this->memorys[prompt] = new PastKVCacheMemory(prompt, tokens, ++flushTime, kv);
}

void PastKVCacheManager::Remove(std::string prompt) {
std::lock_guard <std::mutex> lock(this->locker);
if (this->memorys.find(prompt) != this->memorys.end()) {
if ((--this->memorys[prompt]->recordTimes) <= 0) {
delete this->memorys[prompt];
this->memorys.erase(this->memorys.find(prompt));
}
}
}

PastKVCacheMemory *PastKVCacheManager::Get(const std::string &prompt) {
locker.lock();
std::string maxPrompt = "";
for (auto &it : this->memorys) {
const std::string &cur = it.first;
if (cur.size() > maxPrompt.size() && cur.size() <= prompt.size() && prompt.substr(0, cur.size()) == cur) {
maxPrompt = cur;
}
}
if (maxPrompt == "") {
return nullptr;
}
this->memorys[maxPrompt]->flushTime = ++this->flushTime;
return this->memorys[maxPrompt];
}

void PastKVCacheManager::Unlock() {
locker.unlock();
}

std::string basellm::Response(const std::string &oriInput, RuntimeResult retCb,
const fastllm::GenerationConfig &generationConfig) {
std::string input = oriInput;
PastKVCacheMemory *memory;
std::string oldPrompt;
int oldTokens = 0;
if (this->saveHistoryChat) {
if (lastKeyValues != nullptr) {
if (input.size() < lastPrompt.size() || (input.substr(0, lastPrompt.size()) != lastPrompt)) {
lastPrompt = "";
lastPromptTokens = 0;
delete lastKeyValues;
lastKeyValues = nullptr;
} else {
input = input.substr(lastPrompt.size());
}
memory = pastKVCacheManager.Get(input);
if (memory != nullptr) {
oldPrompt = memory->prompt;
oldTokens = memory->tokens;
input = input.substr(memory->prompt.size());
}
} else {
lastPrompt = "";
lastPromptTokens = 0;
delete lastKeyValues;
lastKeyValues = nullptr;
}

//printf("lastPrompt = %s\n", lastPrompt.c_str());
@@ -97,21 +165,32 @@ namespace fastllm {
for (int i = 0; i < inputTokenData.Count(0); i++) {
inputTokens[0].push_back(((float *) inputTokenData.cpuData)[i]);
}

if (lastKeyValues == nullptr) {
lastKeyValues = new std::vector<std::pair<Data, Data> >();
for (int i = 0; i < block_cnt; i++) {
lastKeyValues->push_back(std::make_pair(Data(this->dataType), Data(this->dataType)));
lastKeyValues->back().first.SetKVCache();
lastKeyValues->back().second.SetKVCache();

std::vector <std::pair <Data, Data> > pastKeyValues;
for (int i = 0; i < block_cnt; i++) {
pastKeyValues.push_back(std::make_pair(Data(this->dataType),
Data(this->dataType)));
}

if (this->saveHistoryChat) {
if (memory != nullptr) {
for (int i = 0; i < block_cnt; i++) {
pastKeyValues[i].first.CopyFrom(memory->kv[i].first);
pastKeyValues[i].second.CopyFrom(memory->kv[i].second);
}
}
pastKVCacheManager.Unlock();
}

for (int i = 0; i < block_cnt; i++) {
pastKeyValues[i].first.SetKVCache();
pastKeyValues[i].second.SetKVCache();
}

std::vector<std::pair<Data, Data> > &pastKeyValues = (*lastKeyValues);
std::string retString = "";
std::vector<float> results;
LastTokensManager tokens(1, generationConfig.last_n);
int promptLen = lastPromptTokens + inputTokens[0].size(), index = 0;
int promptLen = oldTokens + inputTokens[0].size(), index = 0;
int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
inputIds, attentionMask, positionIds);
@@ -121,7 +200,7 @@ namespace fastllm {
int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
tokens.units[0].Push(ret);
if (ret == eos_token_id
|| generationConfig.stop_token_ids.find(index) != generationConfig.stop_token_ids.end()) {
|| generationConfig.stop_token_ids.find(ret) != generationConfig.stop_token_ids.end()) {
break;
}

@@ -171,8 +250,15 @@ namespace fastllm {
retCb(-1, retString.c_str());
#endif

lastPrompt += (input + retString);
lastPromptTokens = promptLen + index;
if (this->saveHistoryChat) {
std::string currentPrompt;
int currentTokens;
if (oldPrompt != "") {
pastKVCacheManager.Remove(oldPrompt);
}
pastKVCacheManager.Record(oriInput + retString, promptLen + index, &pastKeyValues);
}

return retString;
}

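For illustration, a standalone sketch of the eviction rule implemented in PastKVCacheManager::Record above: once maxRecordNum entries are cached, the entry with the smallest flushTime (the least recently touched one) is dropped before the new prompt is recorded. KV tensors, reference counting, and locking are omitted; the Entry type and the prompts are hypothetical.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>

struct Entry { long long flushTime; };

int main() {
    const std::size_t maxRecordNum = 2;
    long long flushTime = 0;
    std::map<std::string, Entry> memorys;

    auto record = [&](const std::string &prompt) {
        if (memorys.size() >= maxRecordNum) {
            // Find and drop the entry with the smallest flushTime.
            auto victim = memorys.begin();
            for (auto it = memorys.begin(); it != memorys.end(); ++it) {
                if (it->second.flushTime < victim->second.flushTime) {
                    victim = it;
                }
            }
            std::cout << "evict: " << victim->first << "\n";
            memorys.erase(victim);
        }
        memorys[prompt] = {++flushTime};
    };

    record("prompt A");
    record("prompt B");
    record("prompt C");   // the cache is full, so "prompt A" is evicted
    return 0;
}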
