diff --git a/include/fastllm.h b/include/fastllm.h
index f730171a..16820335 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -46,7 +46,6 @@ namespace fastllm {
         float temperature = 1.0; // temperature; typically 0.1 ~ 1.0, larger values yield more diverse results
         bool output_logits = false; // whether to return logits
         bool enable_hash_id = false; // attach a hash id to the session
-        bool add_special_tokens = true; // add special tokens to the prompt (only takes effect for chatglm models)
         std::multiset <int> stop_token_ids;

         bool IsSimpleGreedy() const {
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index e0dfbb5f..9c6e83d8 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -89,9 +89,7 @@ namespace fastllm {
         std::vector <float> results;
         LastTokensManager tokens(1, generationConfig.last_n);
         int promptLen = inputTokens[0].size(), index = 0;
-        int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
-        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
-                      inputIds, attentionMask, positionIds);
+        FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
         while (true) {
             auto st = std::chrono::system_clock::now();
             int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
@@ -123,8 +121,7 @@
             results.clear();

             inputTokens[0] = std::vector <float> {(float) ret};
-            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
-                          inputIds, attentionMask, positionIds);
+            FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
             if (index == generationConfig.output_token_limit) {
                 break;
             }
@@ -198,7 +195,6 @@
             }
             params[0]["index"] = 0;
             int index = 0;
-            params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 1: 0;

             LastTokensManager tokensManager (batch, generationConfig.last_n);
             std::vector <bool> isEnding = std::vector <bool> (batch, false);
diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp
index 2e31c55f..5e26644d 100644
--- a/src/models/chatglm.cpp
+++ b/src/models/chatglm.cpp
@@ -751,19 +751,16 @@ namespace fastllm {
         int index = params.find("index")->second;
         int promptLen = params.find("promptLen")->second;
-        bool add_special_tokens = params.find("add_special_tokens")->second == 0? false: true;

         if (index == 0) {
-            if (add_special_tokens) {
-                for (auto &ids: inputTokens) {
-                    if (GetVersion() == 1) {
-                        ids.push_back(gmask_token_id);
-                        ids.push_back(bos_token_id);
-                    } else if (GetVersion() == 2) {
-                        if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
-                            ids.insert(ids.begin(), this->bos_token_id);
-                            ids.insert(ids.begin(), this->gmask_token_id);
-                        }
+            for (auto &ids: inputTokens) {
+                if (GetVersion() == 1) {
+                    ids.push_back(gmask_token_id);
+                    ids.push_back(bos_token_id);
+                } else if (GetVersion() == 2) {
+                    if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
+                        ids.insert(ids.begin(), this->bos_token_id);
+                        ids.insert(ids.begin(), this->gmask_token_id);
                     }
                 }
             }

@@ -812,17 +809,12 @@
         int batch = inputTokens.size();
         int index = params[0].find("index")->second;
-        bool add_special_tokens = params[0].find("add_special_tokens")->second == 0? false: true;
-        int special_tokens_offset = 0;
-        if (add_special_tokens) {
-            special_tokens_offset = 2;
-        }

         if (index == 0) {
             std::vector <int> seqLens;
             seqLens.resize(batch);
             int maxLen = 0;
             for (int i = 0; i < batch; i++) {
-                maxLen = std::max(maxLen, (int) inputTokens[i].size() + special_tokens_offset);
+                maxLen = std::max(maxLen, (int) inputTokens[i].size() + 2);
                 seqLens[i] = (int) inputTokens[i].size();
             }

@@ -832,15 +824,13 @@
             for (int i = 0; i < batch; i++) {
                 if (GetVersion() == 1) {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
+                    int len = tokens.size(), base = maxLen - 2 - len;
                     for (int j = 0; j < len; j++) {
                         ids[i * maxLen + base + j] = tokens[j];
                     }
-                    if (add_special_tokens) {
-                        ids[i * maxLen + base + len] = gmask_token_id;
-                        ids[i * maxLen + base + len + 1] = bos_token_id;
-                    }
-                    len += special_tokens_offset;
+                    ids[i * maxLen + base + len] = gmask_token_id;
+                    ids[i * maxLen + base + len + 1] = bos_token_id;
+                    len += 2;
                     for (int j = 0; j < len - 1; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }
@@ -857,15 +847,13 @@
                     }
                 } else {
                     auto &tokens = inputTokens[i];
-                    int len = tokens.size(), base = maxLen - special_tokens_offset - len;
-                    if (add_special_tokens) {
-                        ids[i * maxLen + base] = gmask_token_id;
-                        ids[i * maxLen + base + 1] = bos_token_id;
-                    }
+                    int len = tokens.size(), base = maxLen - 2 - len;
+                    ids[i * maxLen + base] = gmask_token_id;
+                    ids[i * maxLen + base + 1] = bos_token_id;
                     for (int j = 0; j < len; j++) {
-                        ids[i * maxLen + base + special_tokens_offset + j] = tokens[j];
+                        ids[i * maxLen + base + 2 + j] = tokens[j];
                     }
-                    len += special_tokens_offset;
+                    len += 2;
                     for (int j = 0; j < len; j++) {
                         vpids[i * 2 * maxLen + base + j] = j;
                     }