Save directly loaded glm4-family models in flm format (#465)
cgli authored and TylunasLi committed Jul 7, 2024
1 parent 2793c40 commit e787aa6
Showing 4 changed files with 37 additions and 10 deletions.
31 changes: 25 additions & 6 deletions src/model.cpp
@@ -66,11 +66,21 @@ namespace fastllm {
 
     void basellm::InitParams() {
         if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) {
-            if(this->weight.dicts["bos_token_id"]!="None"){
+            if (this->weight.dicts["bos_token_id"]!="None") {
                 this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
             }
-            if(this->weight.dicts["eos_token_id"]!="None"){
-                this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
+        }
+        if (this->weight.dicts.find("eos_token_id") != this->weight.dicts.end()) {
+            if (this->weight.dicts["eos_token_id"]!="None") {
+                if (this->weight.dicts["eos_token_id"][0] == '[' && this->eos_token_ids.empty()) {
+                    std::string error;
+                    json11::Json ids = json11::Json::parse(this->weight.dicts["eos_token_id"], error);
+                    for (auto &it : ids.array_items()) {
+                        this->eos_token_ids.insert(it.int_value());
+                    }
+                } else {
+                    this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
+                }
             }
         }
         if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) {
@@ -127,6 +137,16 @@ namespace fastllm {
     }
 
     void basellm::SaveModel(const std::string &fileName) {
+        if (this->weight.tokenizer.chatTemplate.empty()) {
+            if (this->weight.dicts.find("pre_prompt") == this->weight.dicts.end())
+                this->weight.dicts["pre_prompt"] = pre_prompt;
+            if (this->weight.dicts.find("user_role") == this->weight.dicts.end())
+                this->weight.dicts["user_role"] = user_role;
+            if (this->weight.dicts.find("bot_role") == this->weight.dicts.end())
+                this->weight.dicts["bot_role"] = bot_role;
+            if (this->weight.dicts.find("history_sep") == this->weight.dicts.end())
+                this->weight.dicts["history_sep"] = history_sep;
+        }
         this->weight.SaveLowBitModel(fileName, 0);
     }

@@ -451,7 +471,6 @@
             }
         } else if (tokenizerClass == "ChatGLM4Tokenizer") {
             // The tokenizer dedicated to GLM4
-            model->bot_role = " ";
             std::vector <std::string> lines, line;
             SplitString(ReadAllFile(path + "tokenizer.model"), {'\r', '\n'}, lines);
             for (int i = 0; i < lines.size(); i++) {
@@ -463,8 +482,8 @@
                 spTokens[it.second["content"].string_value()] = atoi(it.first.c_str());
             }
             model->weight.tokenizer.SetSpecialTokens(spTokens);
-            ((ChatGLMModel*)model)->gmask_token_id = model->weight.tokenizer.GetTokenId("[gMASK]");
-            ((ChatGLMModel*)model)->bos_token_id = model->weight.tokenizer.GetTokenId("<sop>");
             model->weight.AddDict("tokenizer_has_special_tokens", "1");
+            model->weight.AddDict("tokenizer_class", tokenizerClass);
+            ((ChatGLMModel*)model)->tokenizerClass = tokenizerClass;
 
             // ChatGLM builds prompts by concatenating tokens, so the separator tokens' IDs must be set explicitly
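Taken together, the model.cpp changes let a GLM4 checkpoint that was read directly from Hugging Face format round-trip through a .flm file: SaveModel() now persists the default prompt-role strings when no chat template exists, and InitParams() can recover a JSON-array eos_token_id on reload. A minimal C++ sketch of that flow, assuming the CreateLLMModelFromHF factory from this repository's model.h (the checkpoint path, data type, and two-argument call are illustrative, not taken from this diff):

    #include "model.h"

    int main() {
        // Read a GLM4 HF checkpoint directly (path is illustrative).
        auto model = fastllm::CreateLLMModelFromHF("./glm-4-9b-chat/",
                                                   fastllm::DataType::FLOAT16);
        // SaveModel() now also records pre_prompt / user_role / bot_role /
        // history_sep when the tokenizer carries no chat template, so the
        // resulting .flm file reloads standalone.
        model->SaveModel("glm-4-9b-chat-fp16.flm");
        return 0;
    }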
8 changes: 8 additions & 0 deletions src/models/chatglm.cpp
@@ -80,6 +80,9 @@ namespace fastllm {
 
     void ChatGLMModel::InitParams() {
         basellm::InitParams();
+        if (this->weight.dicts.find("tokenizer_class") != this->weight.dicts.end()) {
+            this->tokenizerClass = this->weight.dicts["tokenizer_class"];
+        }
         if (GetVersion() == 1) {
             if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
                 this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
@@ -97,6 +100,11 @@ namespace fastllm {
         if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
             UpdateRotaryPosEmb(atof(this->weight.dicts["rope_ratio"].c_str()));
         }
+        if (this->tokenizerClass == "ChatGLM4Tokenizer") {
+            this->gmask_token_id = this->weight.tokenizer.GetTokenId("[gMASK]");
+            this->bos_token_id = this->weight.tokenizer.GetTokenId("<sop>");
+            this->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
+        }
     }
 
     int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
4 changes: 2 additions & 2 deletions tools/fastllm_pytools/hf_model.py
@@ -94,11 +94,11 @@ def create(model,
         modelInfo["history_sep"] = "";
     if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "name") and tokenizer.name == "GLM4Tokenizer"):
         # glm-4-chat
-        modelInfo["pre_prompt"] = "[gMASK]<sop>";
+        modelInfo["pre_prompt"] = "";
         modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|user|>")) + ">\n");
         modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|assistant|>")) + ">");
         modelInfo["history_sep"] = "";
-        modelInfo["eos_token_id"] = "151336"
+        modelInfo["tokenizer_class"] = tokenizer.name;
     if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
         rope_scaling = modelInfo.pop("rope_scaling")
         modelInfo["rope_scaling.type"] = rope_scaling["type"]
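For in-memory conversion the same metadata flows through llm.from_hf(), which wraps hf_model.create(). A hedged usage sketch (the checkpoint name and dtype are illustrative, and the save() method is assumed to be the flm export entry point, not shown in this diff):

    from transformers import AutoModel, AutoTokenizer
    from fastllm_pytools import llm

    path = "THUDM/glm-4-9b-chat"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    hf_model = AutoModel.from_pretrained(path, trust_remote_code=True).float()

    # create() now records tokenizer_class = "GLM4Tokenizer" in the model info,
    # so the C++ loader can select the GLM4 tokenizer from the .flm header.
    model = llm.from_hf(hf_model, tokenizer, dtype="float16")
    model.save("glm-4-9b-chat-fp16.flm")  # assumed export method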
4 changes: 2 additions & 2 deletions tools/fastllm_pytools/torch2flm.py
@@ -179,11 +179,11 @@ def tofile(exportPath,
         modelInfo["history_sep"] = "";
     if (modelInfo["model_type"] == "chatglm" and hasattr(tokenizer, "name") and tokenizer.name == "GLM4Tokenizer"):
         # glm-4-chat
-        modelInfo["pre_prompt"] = "[gMASK]<sop>";
+        modelInfo["pre_prompt"] = "";
         modelInfo["user_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|user|>")) + ">\n");
         modelInfo["bot_role"] = ("<FLM_FIX_TOKEN_" + str(tokenizer.convert_tokens_to_ids("<|assistant|>")) + ">");
         modelInfo["history_sep"] = "";
-        modelInfo["eos_token_id"] = "151336"
+        modelInfo["tokenizer_class"] = tokenizer.name;
     if "rope_scaling" in modelInfo and isinstance(modelInfo["rope_scaling"], builtins.dict):
         rope_scaling = modelInfo.pop("rope_scaling")
        modelInfo["rope_scaling.type"] = rope_scaling["type"]
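torch2flm.tofile() is the offline counterpart of the hf_model.py path; a minimal sketch under the same assumptions (checkpoint path and dtype are illustrative):

    from transformers import AutoModel, AutoTokenizer
    from fastllm_pytools import torch2flm

    path = "THUDM/glm-4-9b-chat"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModel.from_pretrained(path, trust_remote_code=True).float()

    # tofile() now stamps tokenizer_class into the header instead of
    # hard-coding eos_token_id = 151336; the eos ids are recovered from the
    # model config by InitParams() on the C++ side.
    torch2flm.tofile("glm-4-9b-chat-fp16.flm", model, tokenizer, dtype="float16")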
