
Commit

Merge pull request #474 from TylunasLi/develop
Support saving a model read directly from safetensors to the flm format, and loading it for inference
ztxz16 authored Jul 8, 2024
2 parents 677bae9 + e787aa6 commit c738c59
Showing 11 changed files with 85 additions and 15 deletions.
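
For orientation, here is a minimal sketch of the round trip this commit enables: read a safetensors checkpoint directly, export it as a .flm file, then reload that file for inference. Only basellm::SaveModel appears verbatim in the diff below; the factory names CreateLLMModelFromHF / CreateLLMModel and their signatures are assumptions about the surrounding fastllm API and may differ from the actual headers.

// Hedged sketch of the safetensors -> .flm -> inference workflow (assumed factory
// names; check include/model.h for the real signatures in this version of fastllm).
#include <cstdio>
#include <string>
#include "model.h"

int main() {
    // 1. Load a HuggingFace-style directory (config.json + *.safetensors) directly.
    auto hfModel = fastllm::CreateLLMModelFromHF("/path/to/hf-model/",
                                                 fastllm::DataType::FLOAT16);

    // 2. Save it as a fastllm .flm file. basellm::SaveModel() (changed in this PR)
    //    now also persists chat_template and the pre_prompt/user_role/bot_role/
    //    history_sep defaults, so the exported file is self-describing.
    hfModel->SaveModel("exported.flm");

    // 3. Reload the exported file and run inference as usual.
    auto flmModel = fastllm::CreateLLMModel("exported.flm");
    std::string reply = flmModel->Response(flmModel->MakeInput("", 0, "Hello"),
                                           [](int index, const char *content) {});
    printf("%s\n", reply.c_str());
    return 0;
}
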
4 changes: 4 additions & 0 deletions example/Win32Demo/fastllm-gpu.vcxproj
@@ -201,13 +201,15 @@
<ClInclude Include="..\..\include\devices\cuda\cudadevice.h" />
<ClInclude Include="..\..\include\executor.h" />
<ClInclude Include="..\..\include\fastllm.h" />
<ClInclude Include="..\..\include\graph.h" />
<ClInclude Include="..\..\include\model.h" />
<ClInclude Include="..\..\include\models\basellm.h" />
<ClInclude Include="..\..\include\models\bert.h" />
<ClInclude Include="..\..\include\models\chatglm.h" />
<ClInclude Include="..\..\include\models\deepseekv2.h" />
<ClInclude Include="..\..\include\models\factoryllm.h" />
<ClInclude Include="..\..\include\models\glm.h" />
<ClInclude Include="..\..\include\models\graphllm.h" />
<ClInclude Include="..\..\include\models\internlm2.h" />
<ClInclude Include="..\..\include\models\llama.h" />
<ClInclude Include="..\..\include\models\minicpm.h" />
@@ -227,12 +229,14 @@
<ClCompile Include="..\..\src\devices\cuda\cudadevicebatch.cpp" />
<ClCompile Include="..\..\src\executor.cpp" />
<ClCompile Include="..\..\src\fastllm.cpp" />
<ClCompile Include="..\..\src\graph.cpp" />
<ClCompile Include="..\..\src\model.cpp" />
<ClCompile Include="..\..\src\models\basellm.cpp" />
<ClCompile Include="..\..\src\models\bert.cpp" />
<ClCompile Include="..\..\src\models\chatglm.cpp" />
<ClCompile Include="..\..\src\models\deepseekv2.cpp" />
<ClCompile Include="..\..\src\models\glm.cpp" />
<ClCompile Include="..\..\src\models\graphllm.cpp" />
<ClCompile Include="..\..\src\models\internlm2.cpp" />
<ClCompile Include="..\..\src\models\llama.cpp" />
<ClCompile Include="..\..\src\models\minicpm.cpp" />
12 changes: 12 additions & 0 deletions example/Win32Demo/fastllm-gpu.vcxproj.filters
@@ -57,6 +57,9 @@
<ClInclude Include="..\..\include\fastllm.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="..\..\include\graph.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="..\..\include\model.h">
<Filter>头文件</Filter>
</ClInclude>
@@ -81,6 +84,9 @@
<ClInclude Include="..\..\include\models\glm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\graphllm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\internlm2.h">
<Filter>头文件\models</Filter>
</ClInclude>
@@ -134,6 +140,9 @@
<ClCompile Include="..\..\src\fastllm.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="..\..\src\graph.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="..\..\src\model.cpp">
<Filter>源文件</Filter>
</ClCompile>
@@ -155,6 +164,9 @@
<ClCompile Include="..\..\src\models\glm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\graphllm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\internlm2.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
4 changes: 4 additions & 0 deletions example/Win32Demo/fastllm.vcxproj
@@ -177,13 +177,15 @@
<ClInclude Include="..\..\include\devices\cpu\cputhreadpool.h" />
<ClInclude Include="..\..\include\executor.h" />
<ClInclude Include="..\..\include\fastllm.h" />
<ClInclude Include="..\..\include\graph.h" />
<ClInclude Include="..\..\include\model.h" />
<ClInclude Include="..\..\include\models\basellm.h" />
<ClInclude Include="..\..\include\models\bert.h" />
<ClInclude Include="..\..\include\models\chatglm.h" />
<ClInclude Include="..\..\include\models\deepseekv2.h" />
<ClInclude Include="..\..\include\models\factoryllm.h" />
<ClInclude Include="..\..\include\models\glm.h" />
<ClInclude Include="..\..\include\models\graphllm.h" />
<ClInclude Include="..\..\include\models\internlm2.h" />
<ClInclude Include="..\..\include\models\llama.h" />
<ClInclude Include="..\..\include\models\minicpm.h" />
@@ -201,12 +203,14 @@
<ClCompile Include="..\..\src\devices\cpu\cpudevicebatch.cpp" />
<ClCompile Include="..\..\src\executor.cpp" />
<ClCompile Include="..\..\src\fastllm.cpp" />
<ClCompile Include="..\..\src\graph.cpp" />
<ClCompile Include="..\..\src\model.cpp" />
<ClCompile Include="..\..\src\models\basellm.cpp" />
<ClCompile Include="..\..\src\models\bert.cpp" />
<ClCompile Include="..\..\src\models\chatglm.cpp" />
<ClCompile Include="..\..\src\models\deepseekv2.cpp" />
<ClCompile Include="..\..\src\models\glm.cpp" />
<ClCompile Include="..\..\src\models\graphllm.cpp" />
<ClCompile Include="..\..\src\models\internlm2.cpp" />
<ClCompile Include="..\..\src\models\llama.cpp" />
<ClCompile Include="..\..\src\models\minicpm.cpp" />
12 changes: 12 additions & 0 deletions example/Win32Demo/fastllm.vcxproj.filters
@@ -57,6 +57,9 @@
<ClInclude Include="..\..\include\fastllm.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="..\..\include\graph.h">
<Filter>头文件</Filter>
</ClInclude>
<ClInclude Include="..\..\include\model.h">
<Filter>头文件</Filter>
</ClInclude>
@@ -81,6 +84,9 @@
<ClInclude Include="..\..\include\models\glm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\graphllm.h">
<Filter>头文件\models</Filter>
</ClInclude>
<ClInclude Include="..\..\include\models\internlm2.h">
<Filter>头文件\models</Filter>
</ClInclude>
@@ -128,6 +134,9 @@
<ClCompile Include="..\..\src\fastllm.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="..\..\src\graph.cpp">
<Filter>源文件</Filter>
</ClCompile>
<ClCompile Include="..\..\src\model.cpp">
<Filter>源文件</Filter>
</ClCompile>
@@ -149,6 +158,9 @@
<ClCompile Include="..\..\src\models\glm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\graphllm.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
<ClCompile Include="..\..\src\models\internlm2.cpp">
<Filter>源文件\models</Filter>
</ClCompile>
2 changes: 1 addition & 1 deletion include/devices/cpu/alivethreadpool.h
@@ -62,7 +62,7 @@ namespace fastllm {
auto duration = std::chrono::duration_cast<std::chrono::microseconds> (std::chrono::system_clock::now() - lastRunTime);
double gap = double(duration.count()) * std::chrono::microseconds::period::num / std::chrono::microseconds::period::den;
if (gap > 3) {
std::this_thread::sleep_for(std::chrono::seconds(0));
std::this_thread::sleep_for(std::chrono::microseconds(2));
}
}
}
2 changes: 2 additions & 0 deletions src/fastllm.cpp
@@ -1976,6 +1976,8 @@ namespace fastllm {
}
tokenizer.SetSpecialTokens(specialTokens);
}
if (this->dicts.find("chat_template") != this->dicts.end())
tokenizer.chatTemplate = this->dicts["chat_template"];

int len = buffer.ReadInt();
for (int i = 0; i < len; i++) {
42 changes: 33 additions & 9 deletions src/model.cpp
@@ -66,11 +66,21 @@ namespace fastllm {

void basellm::InitParams() {
if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) {
if(this->weight.dicts["bos_token_id"]!="None"){
if (this->weight.dicts["bos_token_id"]!="None") {
this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
}
if(this->weight.dicts["eos_token_id"]!="None"){
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
}
if (this->weight.dicts.find("eos_token_id") != this->weight.dicts.end()) {
if (this->weight.dicts["eos_token_id"]!="None") {
if (this->weight.dicts["eos_token_id"][0] == '[' && this->eos_token_ids.empty()) {
std::string error;
json11::Json ids = json11::Json::parse(this->weight.dicts["eos_token_id"], error);
for (auto &it : ids.array_items()) {
this->eos_token_ids.insert(it.int_value());
}
} else {
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
}
}
}
if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) {
@@ -127,6 +137,16 @@
}

void basellm::SaveModel(const std::string &fileName) {
if (this->weight.tokenizer.chatTemplate.empty()) {
if (this->weight.dicts.find("pre_prompt") == this->weight.dicts.end())
this->weight.dicts["pre_prompt"] = pre_prompt;
if (this->weight.dicts.find("user_role") == this->weight.dicts.end())
this->weight.dicts["user_role"] = user_role;
if (this->weight.dicts.find("bot_role") == this->weight.dicts.end())
this->weight.dicts["bot_role"] = bot_role;
if (this->weight.dicts.find("history_sep") == this->weight.dicts.end())
this->weight.dicts["history_sep"] = history_sep;
}
this->weight.SaveLowBitModel(fileName, 0);
}

@@ -262,9 +282,9 @@ namespace fastllm {
ClearBuffer();
buffer = new uint8_t[len * unitSize];

FILE *fi = fopen(this->fileName.c_str(), "r");
FILE *fi = fopen(this->fileName.c_str(), "rb");
int ret;
#if defined(_WIN32) or defined(_WIN64)
#if defined(_WIN32) || defined(_WIN64)
_fseeki64(fi, this->data_offsets[0], 0);
#else
fseek(fi, this->data_offsets[0], 0);
@@ -424,6 +444,8 @@ namespace fastllm {
std::string tokenizerConfigFile = path + "tokenizer_config.json";
auto tokenizerConfig = json11::Json::parse(ReadAllFile(tokenizerConfigFile), error);
model->weight.tokenizer.SetTokenizerConfig(tokenizerConfig);
if (!model->weight.tokenizer.chatTemplate.empty() && model->weight.dicts.find("chat_template") == model->weight.dicts.end())
model->weight.AddDict("chat_template", model->weight.tokenizer.chatTemplate);
std::string tokenizerClass = tokenizerConfig["tokenizer_class"].string_value();
if (tokenizerClass == "PreTrainedTokenizerFast"
|| tokenizerClass == "Qwen2Tokenizer"
@@ -439,14 +461,16 @@
spTokens[it["content"].string_value()] = it["id"].int_value();
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
if (!spTokens.empty())
model->weight.AddDict("tokenizer_has_special_tokens", "1");

if (!tokenizer["decoder"].is_null() && !tokenizer["decoder"]["type"].is_null() &&
tokenizer["decoder"]["type"].string_value() == "ByteLevel") {
model->weight.tokenizer.byteAsChar = true;
model->weight.AddDict("tokenizer_byte_as_char", "True");
}
} else if (tokenizerClass == "ChatGLM4Tokenizer") {
// Tokenizer used exclusively by GLM4
model->bot_role = " ";
std::vector <std::string> lines, line;
SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
for (int i = 0; i < lines.size(); i++) {
@@ -458,8 +482,8 @@
spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
}
model->weight.tokenizer.SetSpecialTokens(spTokens);
((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
model->weight.AddDict("tokenizer_has_special_tokens", "1");
model->weight.AddDict("tokenizer_class", tokenizerClass);
((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;

// ChatGLM builds prompts by concatenating tokens, so the separator token IDs must be set explicitly
@@ -515,7 +539,7 @@
auto config = json11::Json::parse(ReadAllFile(configFile), error);
basellm *model = CreateModelWithType(config["model_type"].string_value());
for (auto &it : config.object_items()) {
model->weight.AddDict(it.first, it.second.dump().c_str());
model->weight.AddDict(it.first, it.second.is_string() ? it.second.string_value() : it.second.dump());
}
// 设置eos_token_id
if (config["eos_token_id"].is_array()) {
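
As background for the eos_token_id change in basellm::InitParams() above: newer HF configs (glm-4 style) store eos_token_id as a JSON array rather than a single integer, and the new branch parses it with json11. Below is a small, self-contained sketch of that parsing path; it uses only json11 calls that appear in the diff, and the sample IDs and include path are illustrative.

// Standalone illustration of the array-vs-scalar eos_token_id handling.
#include <cstdio>
#include <cstdlib>
#include <set>
#include <string>
#include "json11.hpp"   // vendored with fastllm; the in-tree path may differ

int main() {
    // e.g. weight.dicts["eos_token_id"] after reading config.json (values made up)
    std::string raw = "[151329, 151336, 151338]";
    std::set<int> eosTokenIds;

    if (!raw.empty() && raw[0] == '[') {        // array form: collect every ID
        std::string error;
        json11::Json ids = json11::Json::parse(raw, error);
        for (auto &it : ids.array_items()) {
            eosTokenIds.insert(it.int_value());
        }
    } else {                                    // single-integer form (older configs)
        eosTokenIds.insert(atoi(raw.c_str()));
    }

    for (int id : eosTokenIds) {
        printf("eos token id: %d\n", id);
    }
    return 0;
}
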
8 changes: 8 additions & 0 deletions src/models/chatglm.cpp
@@ -80,6 +80,9 @@ namespace fastllm {

void ChatGLMModel::InitParams() {
basellm::InitParams();
if (this->weight.dicts.find("tokenizer_class") != this->weight.dicts.end()) {
this->tokenizerClass = this->weight.dicts["tokenizer_class"];
}
if (GetVersion() == 1) {
if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
@@ -97,6 +100,11 @@
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str()));
}
if (this->tokenizerClass == "ChatGLM4Tokenizer") {
this->gmask_token_id = this->weight.tokenizer.GetTokenId("[gMASK]");
this->bos_token_id = this->weight.tokenizer.GetTokenId("<sop>");
this->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
}
}

int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
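
The chatglm.cpp hunk above is the load-side half of a save/load symmetry: when converting from HF, src/model.cpp now records tokenizer_class (and chat_template) in the weight dictionary instead of patching gmask/bos IDs in place, and ChatGLMModel::InitParams() replays that information on every load path, including .flm files saved by this PR. The toy program below illustrates the idea with a plain std::map; it is not fastllm code, and the dictionary values are placeholders.

// Toy illustration of keying GLM4 special-casing off persisted dictionary entries.
#include <cstdio>
#include <map>
#include <string>

int main() {
    // Written by the HF converter and preserved verbatim in the .flm header:
    std::map<std::string, std::string> dicts = {
        {"tokenizer_class", "ChatGLM4Tokenizer"},
        {"chat_template",   "<placeholder Jinja template>"},
    };

    // Replayed at load time (HF directory or .flm file alike):
    std::string tokenizerClass =
        dicts.count("tokenizer_class") ? dicts["tokenizer_class"] : "";
    if (tokenizerClass == "ChatGLM4Tokenizer") {
        // The real InitParams() resolves "[gMASK]" / "<sop>" through the tokenizer
        // and switches it to the QWEN-style tokenization path at this point.
        printf("GLM4 tokenizer detected: would look up [gMASK] / <sop> token ids\n");
    }
    return 0;
}
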
6 changes: 5 additions & 1 deletion src/models/llama.cpp
@@ -151,7 +151,7 @@ namespace fastllm {
std::string mergeQkvWeightName = "model.layers." + std::to_string(i) + ".self_attn.mergeqkv.weight";
std::string mergeQkvBiasName = "model.layers." + std::to_string(i) + ".self_attn.mergeqkv.bias";

if (weight.weight.find(qkvWeightName) != weight.weight.end()) {
if (weight.weight.find(qkvWeightName) != weight.weight.end() || weight.weight.find(mergeQkvWeightName) != weight.weight.end()) {
mergeQKV = true;
break;
} else {
@@ -214,6 +214,10 @@
std::string w3WeightName = "model.layers." + std::to_string(i) + ".mlp.up_proj.weight";
std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";

if (weight.weight.find(swigluWeightName) != weight.weight.end()) {
mergeQKV = true;
break;
}
Data &w1 = weight.weight[w1WeightName], &w3 = weight.weight[w3WeightName];
if ((w1.dataType == DataType::INT4_GROUP && w1.dims[1] % w1.groupCnt != 0) ||
(w3.dataType == DataType::INT4_GROUP && w3.dims[1] % w3.groupCnt != 0)) {
4 changes: 2 additions & 2 deletions tools/fastllm_pytools/hf_model.py
@@ -94,11 +94,11 @@ def create(model,
modelInfo["history_sep"] = "";
if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "name") and tokenizer.name == "GLM4Tokenizer"):
# glm-4-chat
modelInfo["pre_prompt"] = "[gMASK]<sop>";
modelInfo["pre_prompt"] = "";
modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|user|>")) + ">\n");
modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|assistant|>")) + ">");
modelInfo["history_sep"] = "";
modelInfo["eos_token_id"] = "151336"
modelInfo["tokenizer_class"] = tokenizer.name;
if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
rope_scaling = modelInfo.pop("rope_scaling")
modelInfo["rope_scaling.type"] = rope_scaling["type"]
4 changes: 2 additions & 2 deletions tools/fastllm_pytools/torch2flm.py
@@ -179,11 +179,11 @@ def tofile(exportPath,
modelInfo["history_sep"] = "";
if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "name") and tokenizer.name == "GLM4Tokenizer"):
# glm-4-chat
modelInfo["pre_prompt"] = "[gMASK]<sop>";
modelInfo["pre_prompt"] = "";
modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|user|>")) + ">\n");
modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|assistant|>")) + ">");
modelInfo["history_sep"] = "";
modelInfo["eos_token_id"] = "151336"
modelInfo["tokenizer_class"] = tokenizer.name;
if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
rope_scaling = modelInfo.pop("rope_scaling")
modelInfo["rope_scaling.type"] = rope_scaling["type"]
