diff --git a/CMakeLists.txt b/CMakeLists.txt
index de3a5d43..f15373c6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,7 +43,7 @@ message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
 set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
         src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
         src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp
-        src/models/glm.cpp src/models/minicpm.cpp src/models/internlm2.cpp
+        src/models/glm.cpp src/models/minicpm.cpp src/models/internlm2.cpp src/models/bert.cpp
         third_party/json11/json11.cpp)

 include_directories(include)
diff --git a/include/devices/cpu/cpudevice.h b/include/devices/cpu/cpudevice.h
index 91c13813..e35fe688 100644
--- a/include/devices/cpu/cpudevice.h
+++ b/include/devices/cpu/cpudevice.h
@@ -89,6 +89,14 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };

+    class CpuTanHOp : BaseOperator {
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
+    class CpuGeluOp : BaseOperator {
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CpuGeluNewOp : BaseOperator {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
@@ -114,6 +122,10 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };

+    class CpuAttentionExtendedMaskOp : BaseOperator {
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CpuAlibiMaskOp : BaseOperator {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
diff --git a/include/fastllm.h b/include/fastllm.h
index fafbb850..3640bedb 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -329,7 +329,8 @@ namespace fastllm {
             BPE = 0,
             NORMAL = 1,
             QWEN = 2,
-            GLM = 3
+            GLM = 3,
+            BERT = 4
         };

         struct TrieNode {
@@ -505,6 +506,10 @@ namespace fastllm {
     void Silu(const fastllm::Data &input, fastllm::Data &output);

+    void TanH(const Data &input, Data &output);
+
+    void Gelu(const Data &input, Data &output);
+
     void GeluNew(const Data &input, Data &output);

     void Swiglu(const fastllm::Data &input, fastllm::Data &output);
@@ -517,6 +522,8 @@ namespace fastllm {
     void AttentionMask(Data &input, const Data &mask, float maskValue); // set positions of input where mask == 1 to maskValue

+    void AttentionExtendedMask(Data &input, const Data &mask); // BERT-style extended attention mask
+
     void AlibiMask(Data &input, const Data &mask, float maskValue); // alibi mask

     void Permute(const Data &input, const std::vector <int> &axis, Data &output); // transpose
diff --git a/include/model.h b/include/model.h
index 1d896665..e80c0578 100644
--- a/include/model.h
+++ b/include/model.h
@@ -6,8 +6,11 @@
 #define FASTLLM_MODEL_H

 #include "basellm.h"
+#include "bert.h"

 namespace fastllm {
+    std::unique_ptr <BertModel> CreateEmbeddingModelFromFile(const std::string &fileName);
+
     std::unique_ptr <basellm> CreateLLMModelFromFile(const std::string &fileName);

     std::unique_ptr <basellm> CreateEmptyLLMModel(const std::string &modelType);
diff --git a/include/models/bert.h b/include/models/bert.h
new file mode 100644
index 00000000..e5febfa7
--- /dev/null
+++ b/include/models/bert.h
@@ -0,0 +1,53 @@
+
+#ifndef FASTLLM_BERT_H
+#define FASTLLM_BERT_H
+
+#include "basellm.h"
+#include "fastllm.h"
+
+namespace fastllm {
+    class BertModel {
+    public:
+        BertModel() {};
+
+        ~BertModel() {
+            this->weight.ReleaseWeight();
+        };
+
+        void InitParams(); // initialize hyper-parameters from the weight file
+
+        // inference
+        std::vector <std::vector <float> > Forward(
+                const Data &inputIds,
+                const Data &attentionMask,
+                const Data &tokenTypeIds,
+                const Data &positionIds);
+
+        std::vector <float> EmbeddingSentence(const std::string &context);
+
+        std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts);
+
+        void LoadFromFile(const std::string &fileName); // load from file
+
+        void SaveLowBitModel(const std::string &fileName, int bit); // save as a quantized model
+
+        void SaveModel(const std::string &fileName); // export directly
+
+        void WarmUp() {}; // warm up
+
+        std::string model_type;
+
+        float layer_norm_eps = 1e-12;
+
+        int embed_dim = 512;
+        int num_attention_heads = 64;
+        int head_dim = embed_dim / num_attention_heads;
+        int max_positions = 32768;
+        int block_cnt = 12;
+
+        WeightMap weight; // weights
+        std::map <std::string, int> deviceMap;
+    };
+}
+
+#endif //FASTLLM_BERT_H
\ No newline at end of file
diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 195d8735..77803417 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -35,12 +35,15 @@ namespace fastllm {
         this->ops["MatMulTransB"] = (BaseOperator*)(new CpuMatMulTransBOp());
         this->ops["SoftMax"] = (BaseOperator*)(new CpuSoftMaxOp());
         this->ops["Silu"] = (BaseOperator*)(new CpuSiluOp());
+        this->ops["TanH"] = (BaseOperator*)(new CpuTanHOp());
+        this->ops["Gelu"] = (BaseOperator*)(new CpuGeluOp());
         this->ops["GeluNew"] = (BaseOperator*)(new CpuGeluNewOp());
         this->ops["Swiglu"] = (BaseOperator*)(new CpuSwigluOp());
         this->ops["Mul"] = (BaseOperator*)(new CpuMulOp());
         this->ops["MulTo"] = (BaseOperator*)(new CpuMulToOp());
         this->ops["AddTo"] = (BaseOperator*)(new CpuAddToOp());
         this->ops["AttentionMask"] = (BaseOperator*)(new CpuAttentionMaskOp());
+        this->ops["AttentionExtendedMask"] = (BaseOperator*)(new CpuAttentionExtendedMaskOp());
         this->ops["AlibiMask"] = (BaseOperator*)(new CpuAlibiMaskOp());
         this->ops["TopK"] = (BaseOperator*)(new CpuTopKOp());
         this->ops["Permute"] = (BaseOperator*)(new CpuPermuteOp());
@@ -2505,6 +2508,74 @@ namespace fastllm {
         }
     }

+    void CpuTanHOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                        const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &input = *(datas.find("input")->second);
+        Data &output = *(datas.find("output")->second);
+        output.Allocate();
+        AssertInFastLLM(input.dataType == DataType::FLOAT32, "TanH error: Data's type should be float32.\n");
+
+        float *inputData = (float*)input.cpuData;
+        float *outputData = (float*)output.cpuData;
+        int len = input.Count(0);
+        for (int i = 0; i < len; i++) {
+            outputData[i] = tanhf(inputData[i]);
+        }
+    }
+
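+    // erf() below is a single-precision polynomial approximation evaluated with fmaf in two
+    // branches split at |x| = 475/512: the tail branch computes 1 - expf(p(|x|)) with the sign
+    // restored via copysignf, the core branch an odd polynomial in x. It backs the exact
+    // (erf-based) GELU used by BERT's intermediate layers.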
+    float erf(float a)
+    {
+        float r, s, t, u;
+
+        t = fabsf(a);
+        s = a * a;
+        if (t > 0.927734375f)
+        { // 475/512
+            // maximum error 0.99527 ulp
+            r = fmaf(-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16, 0x1.91cfb2p-12
+            u = fmaf(-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+            r = fmaf(r, s, u);
+            r = fmaf(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+            r = fmaf(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+            r = fmaf(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+            r = fmaf(r, t, -t);
+            r = 1.0f - expf(r);
+            r = copysignf(r, a);
+        }
+        else
+        {
+            // maximum error 0.98929 ulp
+            r = -5.96761703e-4f; // -0x1.38e000p-11
+            r = fmaf(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+            r = fmaf(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+            r = fmaf(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+            r = fmaf(r, s, -3.76125336e-1f); // -0x1.812700p-2
+            r = fmaf(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+            r = fmaf(r, a, a);
+        }
+        return r;
+    }
+
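+    // Exact GELU: gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). This matches BERT's "gelu"
+    // activation; CpuGeluNewOp below implements the faster tanh approximation instead.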
+    void CpuGeluOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                        const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &input = *(datas.find("input")->second);
+        Data &output = *(datas.find("output")->second);
+        output.Allocate();
+        AssertInFastLLM(input.dataType == DataType::FLOAT32, "Gelu error: Data's type should be float32.\n");
+
+        float *inputData = (float*)input.cpuData;
+        float *outputData = (float*)output.cpuData;
+        int len = input.Count(0);
+        for (int i = 0; i < len; i++) {
+            float x = inputData[i];
+            outputData[i] = x * 0.5f * (1.0f + erf(x / sqrtf(2.0f)));
+        }
+    }
+
     void CpuGeluNewOp::Run(const std::string &opType, const fastllm::DataDict &datas,
                            const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         Data &input = *(datas.find("input")->second);
@@ -2769,6 +2840,29 @@ namespace fastllm {
         }
     }

+    void CpuAttentionExtendedMaskOp::Run(const std::string &opType, const fastllm::DataDict &datas,
+                                         const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &input = *(datas.find("input")->second);
+        Data &mask = *(datas.find("mask")->second);
+        int spatial = input.dims[3], n = input.dims[0], m = input.dims[1] * input.dims[2];
+
+        AssertInFastLLM(mask.dataType == DataType::FLOAT32, "AttentionExtendedMask: mask's datatype should be float32.");
+        if (input.dataType == DataType::FLOAT32) {
+            // Broadcast-add the per-batch additive mask [n, spatial] over the heads and
+            // query positions of the attention scores [n, heads, query, spatial].
+            float *maskData = (float *) mask.cpuData;
+            float *attnData = (float *) input.cpuData;
+            for (int on = 0; on < n; on++) {
+                for (int om = 0; om < m; om++) {
+                    int o = on * m + om;
+                    for (int i = 0; i < spatial; i++) {
+                        attnData[o * spatial + i] += maskData[on * spatial + i];
+                    }
+                }
+            }
+        } else {
+            ErrorInFastLLM("AttentionExtendedMask error: unsupported dataType.\n");
+        }
+    }
+
     void CpuAlibiMaskOp::Run(const std::string &opType, const fastllm::DataDict &datas,
                              const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         Data &input = *(datas.find("input")->second);
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index bdac3aa7..04e46b07 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -1012,6 +1012,10 @@ namespace fastllm {
         return s;
     }

+    bool isDigitOrChar(char c) {
+        return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
+    }
+
     Data Tokenizer::Encode(const std::string &ori) {
         if (this->type == TokenizerType::BPE) {
             std::string s = Normalize(ori);
@@ -1329,6 +1333,33 @@ namespace fastllm {
                 }
             }

             return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
+        } else if (this->type == TokenizerType::BERT) {
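+            // Greedy longest-match WordPiece: when the current character continues a word
+            // (both it and the previous character are alphanumeric), the lookup starts from
+            // the "##" continuation subtree of the token trie.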
+            std::vector <float> v;
+            for (int i = 0; i < ori.size(); i++) {
+                int tokenId = -999999, pos = i - 1;
+                TrieNode *now = this->root;
+
+                if (i > 0 && isDigitOrChar(ori[i - 1]) && isDigitOrChar(ori[i])) {
+                    now = now->next['#']->next['#'];
+                }
+                for (int j = i; j < ori.size(); j++) {
+                    if (now->next.find(ori[j]) != now->next.end()) {
+                        now = now->next[ori[j]];
+                        if (now->tokenId != -999999) {
+                            tokenId = now->tokenId;
+                            pos = j;
+                        }
+                    } else {
+                        break;
+                    }
+                }
+                if (pos >= i) {
+                    i = pos;
+                    v.push_back(tokenId);
+                }
+            }
+
+            return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
         } else {
             std::vector <float> v;
@@ -2177,6 +2208,18 @@ namespace fastllm {
         }, {}, {});
     }

+    void TanH(const Data &input, Data &output) {
+        curExecutor->Run("TanH", {
+                {"input", (Data*)&input}, {"output", &output}
+        }, {}, {});
+    }
+
+    void Gelu(const fastllm::Data &input, fastllm::Data &output) {
+        curExecutor->Run("Gelu", {
+                {"input", (Data*)&input}, {"output", &output}
+        }, {}, {});
+    }
+
     void GeluNew(const fastllm::Data &input, fastllm::Data &output) {
         curExecutor->Run("GeluNew", {
                 {"input", (Data*)&input}, {"output", &output}
@@ -2213,6 +2256,12 @@ namespace fastllm {
         }, {{"maskValue", maskValue}}, {});
     }

+    void AttentionExtendedMask(Data &input, const Data &mask) {
+        curExecutor->Run("AttentionExtendedMask", {
+                {"input", &input}, {"mask", (Data*)&mask}
+        }, {}, {});
+    }
+
     void AlibiMask(Data &input, const Data &mask, float maskValue) {
         curExecutor->Run("AlibiMask", {
                 {"input", &input}, {"mask", (Data*)&mask}
diff --git a/src/model.cpp b/src/model.cpp
index 8721b6fe..e7d0e84d 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -11,6 +11,7 @@
 #include "glm.h"
 #include "minicpm.h"
 #include "internlm2.h"
+#include "bert.h"

 namespace fastllm {
     void basellm::LoadFromFile(const std::string &fileName) {
@@ -118,12 +119,22 @@ namespace fastllm {
             model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
         } else if (modelType == "glm") {
             model = (basellm*)(new GLMModel());
+        } else if (modelType == "bert") {
+            model = (basellm*)(new BertModel());
         } else {
             ErrorInFastLLM("Unkown model type: " + modelType);
         }
         return model;
     }

+    std::unique_ptr <BertModel> CreateEmbeddingModelFromFile(const std::string &fileName) {
+        BertModel *model = new BertModel();
+        model->weight.tokenizer.type = Tokenizer::TokenizerType::BERT;
+        model->LoadFromFile(fileName);
+        model->WarmUp();
+        return std::unique_ptr <BertModel> (model);
+    }
+
     std::unique_ptr <basellm> CreateLLMModelFromFile(const std::string &fileName) {
         std::string modelType = GetModelTypeFromFile(fileName);
         basellm *model = CreateModelWithType(modelType);
diff --git a/src/models/bert.cpp b/src/models/bert.cpp
new file mode 100644
index 00000000..88e4ace9
--- /dev/null
+++ b/src/models/bert.cpp
@@ -0,0 +1,159 @@
+//
+// Created by huangyuyang on 4/25/24.
+//
+
+#include "bert.h"
+#include "utils.h"
+#include <cmath>
+#include <cstring>
+
+namespace fastllm {
+    void BertModel::LoadFromFile(const std::string &fileName) {
+        this->weight.LoadFromFile(fileName);
+        InitParams();
+    }
+
+    void BertModel::InitParams() {
+        if (this->weight.dicts.find("layer_norm_eps") != this->weight.dicts.end()) {
+            this->layer_norm_eps = atof(this->weight.dicts["layer_norm_eps"].c_str());
+        }
+        if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) {
+            block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str());
+        } else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
+            block_cnt = atoi(this->weight.dicts["num_layers"].c_str());
+        }
+        if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) {
+            embed_dim = atoi(this->weight.dicts["hidden_size"].c_str());
+        }
+        if (this->weight.dicts.find("num_attention_heads") != this->weight.dicts.end()) {
+            num_attention_heads = atoi(this->weight.dicts["num_attention_heads"].c_str());
+        }
+        this->head_dim = embed_dim / num_attention_heads;
+    }
+
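+    // Encoder pass: summed word / token-type / position embeddings -> LayerNorm ->
+    // block_cnt post-LN transformer layers. Returns the final hidden state of the first
+    // ([CLS]) token per batch entry; the dense+tanh pooler is computed into `pooler`,
+    // but the vectors copied out below come from firstStates.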
+    std::vector <std::vector <float> > BertModel::Forward(
+            const Data &inputIds,
+            const Data &attentionMask,
+            const Data &tokenTypeIds,
+            const Data &positionIds) {
+        // embedding
+        Data inputEmbeddings, tokenTypeEmbeddings, positionIdEmbeddings;
+        Embedding(inputIds, this->weight["embeddings.word_embeddings.weight"], inputEmbeddings);
+        Embedding(tokenTypeIds, this->weight["embeddings.token_type_embeddings.weight"], tokenTypeEmbeddings);
+        Embedding(positionIds, this->weight["embeddings.position_embeddings.weight"], positionIdEmbeddings);
+        AddTo(inputEmbeddings, tokenTypeEmbeddings);
+        AddTo(inputEmbeddings, positionIdEmbeddings);
+
+        Data hiddenStates, firstStates;
+        LayerNorm(inputEmbeddings, this->weight["embeddings.LayerNorm.weight"], this->weight["embeddings.LayerNorm.bias"], -1, hiddenStates);
+
+        Data q, k, v, qk, qkv, attnOutput, inter, pooler;
+
+        for (int i = 0; i < this->block_cnt; i++) {
+            std::string queryWeightName = "encoder.layer." + std::to_string(i) + ".attention.self.query.weight";
+            std::string queryBiasName = "encoder.layer." + std::to_string(i) + ".attention.self.query.bias";
+            std::string keyWeightName = "encoder.layer." + std::to_string(i) + ".attention.self.key.weight";
+            std::string keyBiasName = "encoder.layer." + std::to_string(i) + ".attention.self.key.bias";
+            std::string valueWeightName = "encoder.layer." + std::to_string(i) + ".attention.self.value.weight";
+            std::string valueBiasName = "encoder.layer." + std::to_string(i) + ".attention.self.value.bias";
+            std::string attnOutputWeightName = "encoder.layer." + std::to_string(i) + ".attention.output.dense.weight";
+            std::string attnOutputbiasName = "encoder.layer." + std::to_string(i) + ".attention.output.dense.bias";
+            std::string attnLNWeightName = "encoder.layer." + std::to_string(i) + ".attention.output.LayerNorm.weight";
+            std::string attnLNbiasName = "encoder.layer." + std::to_string(i) + ".attention.output.LayerNorm.bias";
+            std::string interDenseWeightName = "encoder.layer." + std::to_string(i) + ".intermediate.dense.weight";
+            std::string interDenseBiasName = "encoder.layer." + std::to_string(i) + ".intermediate.dense.bias";
+            std::string outputWeightName = "encoder.layer." + std::to_string(i) + ".output.dense.weight";
+            std::string outputbiasName = "encoder.layer." + std::to_string(i) + ".output.dense.bias";
+            std::string outputLNWeightName = "encoder.layer." + std::to_string(i) + ".output.LayerNorm.weight";
+            std::string outputLNbiasName = "encoder.layer." + std::to_string(i) + ".output.LayerNorm.bias";
+
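+            // Self-attention: project to Q/K/V, reshape to [batch, seq, heads, head_dim],
+            // permute to [batch, heads, seq, head_dim]; scores are scaled by 1/sqrt(head_dim)
+            // and offset by the additive extended mask before the softmax.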
+            Linear(hiddenStates, this->weight[queryWeightName], this->weight[queryBiasName], q);
+            Linear(hiddenStates, this->weight[keyWeightName], this->weight[keyBiasName], k);
+            Linear(hiddenStates, this->weight[valueWeightName], this->weight[valueBiasName], v);
+
+            std::vector <int> qdims = {q.dims[0], q.dims[1], this->num_attention_heads, this->head_dim};
+            q.Reshape(qdims);
+            k.Reshape(qdims);
+            v.Reshape(qdims);
+            PermuteSelf(q, {0, 2, 1, 3});
+            PermuteSelf(k, {0, 2, 1, 3});
+            PermuteSelf(v, {0, 2, 1, 3});
+
+            MatMulTransB(q, k, qk, 1.0 / sqrt(this->head_dim), 1);
+            AttentionExtendedMask(qk, attentionMask);
+            Softmax(qk, qk, -1);
+            MatMul(qk, v, qkv, 1.0, 1);
+
+            PermuteSelf(qkv, {0, 2, 1, 3});
+            qkv.Reshape({qkv.dims[0], qkv.dims[1], -1});
+
+            Linear(qkv, this->weight[attnOutputWeightName], this->weight[attnOutputbiasName], attnOutput);
+            AddTo(hiddenStates, attnOutput);
+            LayerNorm(hiddenStates, this->weight[attnLNWeightName], this->weight[attnLNbiasName], -1, hiddenStates);
+
+            Linear(hiddenStates, this->weight[interDenseWeightName], this->weight[interDenseBiasName], inter);
+            Gelu(inter, inter);
+
+            Linear(inter, this->weight[outputWeightName], this->weight[outputbiasName], attnOutput);
+            AddTo(hiddenStates, attnOutput);
+            LayerNorm(hiddenStates, this->weight[outputLNWeightName], this->weight[outputLNbiasName], -1, hiddenStates);
+        }
+
+        Split(hiddenStates, 1, 0, 1, firstStates);
+        firstStates.Reshape({firstStates.dims[0], -1});
+        Linear(firstStates, this->weight["pooler.dense.weight"], this->weight["pooler.dense.bias"], pooler);
+        TanH(pooler, pooler);
+
+        firstStates.ToDevice(DataDevice::CPU);
+        float *fret = (float*)firstStates.cpuData;
+        int batch = firstStates.dims[0], outputDim = firstStates.dims[1];
+        std::vector <std::vector <float> > ret;
+        ret.resize(batch, std::vector <float> (outputDim, 0.0f));
+        for (int i = 0; i < batch; i++) {
+            memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
+        }
+
+        return ret;
+    }
+
+    std::vector <float> BertModel::EmbeddingSentence(const std::string &context) {
+        std::vector <std::string> contexts;
+        contexts.push_back(context);
+        return EmbeddingSentenceBatch(contexts)[0];
+    }
+
+    std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts) {
+        int batch = contexts.size(), len = 0;
+        std::vector <std::vector <int> > tokens;
+        tokens.resize(batch);
+        for (int i = 0; i < batch; i++) {
+            Data ids = this->weight.tokenizer.Encode("[CLS]" + contexts[i] + "[SEP]");
+            for (int j = 0; j < ids.Count(0); j++) {
+                tokens[i].push_back((int)(((float*)ids.cpuData)[j]));
+            }
+            len = std::max(len, (int)tokens[i].size());
+        }
+
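+        // Right-pad every sequence to the batch maximum. Padded positions keep the additive
+        // mask value -1e10 (suppressed to ~0 by the softmax); real positions get mask 0 and
+        // sequential position ids.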
+        std::vector <float> ids = std::vector <float> (batch * len, 0.0f);
+        std::vector <float> seqLens = std::vector <float> (batch, 0.0f);
+        std::vector <float> token_type_ids = std::vector <float> (batch * len, 0.0f);
+        std::vector <float> attention_mask = std::vector <float> (batch * len, -1e10f);
+        std::vector <float> position_ids = std::vector <float> (batch * len, 0.0f);
+        for (int i = 0; i < batch; i++) {
+            seqLens[i] = tokens[i].size();
+            for (int j = 0; j < tokens[i].size(); j++) {
+                ids[i * len + j] = tokens[i][j];
+                attention_mask[i * len + j] = 0;
+                position_ids[i * len + j] = j;
+            }
+        }
+
+        fastllm::Data inputIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, ids);
+        fastllm::Data attentionMask = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, attention_mask);
+        fastllm::Data tokenTypeIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, token_type_ids);
+        fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);
+
+//        ClearProfiler(); Forward(inputIds, attentionMask, tokenTypeIds, positionIds); PrintProfiler();
+        return Forward(inputIds, attentionMask, tokenTypeIds, positionIds);
+    }
+}
\ No newline at end of file
diff --git a/tools/scripts/bert2flm.py b/tools/scripts/bert2flm.py
new file mode 100644
index 00000000..55b53094
--- /dev/null
+++ b/tools/scripts/bert2flm.py
@@ -0,0 +1,13 @@
+# Usage: python3 bert2flm.py [exportPath] [dtype] [modelpath]
+import sys
+from transformers import AutoTokenizer, AutoModel
+from fastllm_pytools import torch2flm

+if __name__ == "__main__":
+    modelpath = sys.argv[3] if len(sys.argv) >= 4 else 'BAAI/bge-small-zh-v1.5'
+    tokenizer = AutoTokenizer.from_pretrained(modelpath)
+    model = AutoModel.from_pretrained(modelpath).cpu().float()
+    model = model.eval()
+
+    dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
+    exportPath = sys.argv[1] if len(sys.argv) >= 2 else "bert-" + dtype + ".flm"
+    torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
\ No newline at end of file