diff --git a/include/models/bert.h b/include/models/bert.h
index 9c21e0f4..00b66aaf 100644
--- a/include/models/bert.h
+++ b/include/models/bert.h
@@ -6,7 +6,7 @@
 #include "fastllm.h"
 
 namespace fastllm {
-    class BertModel {
+    class BertModel: public basellm {
     public:
        BertModel() {};
 
@@ -17,23 +17,34 @@ namespace fastllm {
        void InitParams(); // Initialize parameter info
 
        // Inference
-        std::vector <std::vector <float> > Forward(
+        std::vector <std::vector <float> > ForwardAll(
            const Data &inputIds,
            const Data &attentionMask,
            const Data &tokenTypeIds,
-            const Data &positionIds);
+            const Data &positionIds,
+            bool normalize);
+
+        // Inference
+        virtual int Forward(
+            const Data &inputIds,
+            const Data &attentionMask,
+            const Data &positionIds,
+            std::vector <std::pair <Data, Data> > &pastKeyValues,
+            const GenerationConfig &generationConfig = GenerationConfig(),
+            const LastTokensManager &lastTokens = LastTokensManager(),
+            std::vector <float> *logits = nullptr);
 
-        std::vector <float> EmbeddingSentence(const std::string &context);
+        std::vector <float> EmbeddingSentence(const std::string &context, bool normalize);
 
-        std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts);
+        std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts, bool normalize);
 
        void LoadFromFile(const std::string &fileName); // Load from file
 
-        void SaveLowBitModel(const std::string &fileName, int bit); // Save as a quantized model
+        void WarmUp(); // Warm up
 
-        void SaveModel(const std::string &fileName); // Export the model directly
+        virtual std::string MakeInput(const std::string &history, int round, const std::string &input);
 
-        void WarmUp(); // Warm up
+        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);
 
        std::string model_type;
diff --git a/src/model.cpp b/src/model.cpp
index dda6d7c5..c06ff58d 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -213,8 +213,15 @@ namespace fastllm {
     std::unique_ptr <basellm> CreateLLMModelFromFile(const std::string &fileName) {
         std::string modelType = GetModelTypeFromFile(fileName);
         basellm *model = CreateModelWithType(modelType);
-        model->LoadFromFile(fileName);
-        model->WarmUp();
+        if(modelType == "bert"){
+            BertModel *bertModel = (BertModel*)model;
+            bertModel->weight.tokenizer.type = Tokenizer::BERT;
+            bertModel->LoadFromFile(fileName);
+            bertModel->WarmUp();
+        }else{
+            model->LoadFromFile(fileName);
+            model->WarmUp();
+        }
         return std::unique_ptr <basellm> (model);
     }
diff --git a/src/models/bert.cpp b/src/models/bert.cpp
index 849a4eb1..9573cac9 100644
--- a/src/models/bert.cpp
+++ b/src/models/bert.cpp
@@ -31,11 +31,25 @@ namespace fastllm {
         this->head_dim = embed_dim / num_attention_heads;
     }
 
-    std::vector <std::vector <float> > BertModel::Forward(
+    void Normalize(float *data, int dataLen)
+    {
+        float sum = 0.0;
+        for(int i = 0; i < dataLen; i++)
+            sum += data[i] * data[i];
+
+        if (sum < 1e-6) sum = 1e-6;
+        else sum = sqrt(sum);
+
+        for(int i = 0; i < dataLen; i++)
+            data[i] = data[i] / sum;
+    }
+
+    std::vector <std::vector <float> > BertModel::ForwardAll(
         const Data &inputIds,
         const Data &attentionMask,
         const Data &tokenTypeIds,
-        const Data &positionIds) {
+        const Data &positionIds,
+        bool normalize) {
         // embedding
         Data inputEmbeddings, tokenTypeEmbeddings, positionIdEmbeddings;
         Embedding(inputIds, this->weight["embeddings.word_embeddings.weight"], inputEmbeddings);
@@ -114,19 +128,20 @@ namespace fastllm {
         std::vector <std::vector <float> > ret;
         ret.resize(batch, std::vector <float> (outputDim, 0.0f));
         for (int i = 0; i < batch; i++) {
+            if(normalize) Normalize(fret + i * outputDim, outputDim);
             memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
         }
 
         return ret;
     }
 
-    std::vector <float> BertModel::EmbeddingSentence(const std::string &context) {
+    std::vector <float> BertModel::EmbeddingSentence(const std::string &context, bool normalize) {
         std::vector <std::string> contexts;
         contexts.push_back(context);
-        return EmbeddingSentenceBatch(contexts)[0];
+        return EmbeddingSentenceBatch(contexts, normalize)[0];
     }
 
-    std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts) {
+    std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts, bool normalize) {
         int batch = contexts.size(), len = 0;
         std::vector <std::vector <int> > tokens;
         tokens.resize(batch);
@@ -158,12 +173,29 @@ namespace fastllm {
         fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);
         // printf("bs = %d, len = %d\n", batch, len); ClearProfiler(); Forward(inputIds, attentionMask, tokenTypeIds, positionIds); PrintProfiler();
-        return Forward(inputIds, attentionMask, tokenTypeIds, positionIds);
+        return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize);
     }
 
     void BertModel::WarmUp() {
        printf("Warmup...\n");
-       EmbeddingSentence({"1"});
+       EmbeddingSentence({"1"}, true);
        printf("finish.\n");
     }
+
+    int BertModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
+                           const fastllm::Data &positionIds, std::vector <std::pair <Data, Data> > &pastKeyValues,
+                           const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
+                           std::vector <float> *retLogits) {
+        return -1;
+    }
+
+    std::string BertModel::MakeInput(const std::string &history, int round, const std::string &input)
+    {
+        return "";
+    }
+
+    std::string BertModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output)
+    {
+        return "";
+    }
 }
\ No newline at end of file
diff --git a/tools/fastllm_pytools/llm.py b/tools/fastllm_pytools/llm.py
index 9f8bd5b9..d1b8a68f 100644
--- a/tools/fastllm_pytools/llm.py
+++ b/tools/fastllm_pytools/llm.py
@@ -93,6 +93,9 @@
 fastllm_lib.get_max_input_len_llm_model.argtypes = [ctypes.c_int]
 fastllm_lib.get_max_input_len_llm_model.restype = ctypes.c_int
 
+fastllm_lib.embedding_sentence.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_bool, ctypes.POINTER(ctypes.c_int)]
+fastllm_lib.embedding_sentence.restype = ctypes.POINTER(ctypes.c_float)
+
 def softmax(a):
     max_value = a[0]
     for i in a:
@@ -1081,6 +1084,15 @@ def set_verbose(self, verbose: int):
     def get_max_input_len(self):
         return fastllm_lib.get_max_input_len_llm_model(self.model)
 
+    def embedding_sentence(self, input: str, normalize = True):
+        embedding_len = ctypes.c_int(0)
+        embedding_c_float = fastllm_lib.embedding_sentence(self.model, input.encode(), normalize, embedding_len)
+        embedding = []
+        for i in range(embedding_len.value):
+            embedding.append(embedding_c_float[i])
+            #print("{:.7f}".format(embedding[i]), end=" ")
+        return embedding
+
 
 def GraphNode(name: str,
               type: str = "data",
               value = None):
diff --git a/tools/fastllm_pytools/openai_server/fastllm_embed.py b/tools/fastllm_pytools/openai_server/fastllm_embed.py
new file mode 100644
index 00000000..ba959ccb
--- /dev/null
+++ b/tools/fastllm_pytools/openai_server/fastllm_embed.py
@@ -0,0 +1,18 @@
+import asyncio
+import logging
+import json
+import traceback
+from fastapi import Request
+
+from .protocal.openai_protocol import *
+from ftllm import llm
+
+class FastLLmEmbed:
+    def __init__(self,
+                 model_name,
+                 model):
+        self.model_name = model_name
+        self.model = model
+
+    def embedding_sentence(self, request: EmbedRequest, raw_request: Request):
+        return self.model.embedding_sentence(request.inputs, request.normalize)
diff --git a/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py b/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
index 5631a258..36cbe989 100644
--- a/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
+++ b/tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
@@ -199,4 +199,11 @@ class CompletionStreamResponse(BaseModel):
     object: str = "text_completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[CompletionResponseStreamChoice]
\ No newline at end of file
+    choices: List[CompletionResponseStreamChoice]
+
+class EmbedRequest(BaseModel):
+    inputs: str
+    normalize: Optional[bool]
+    prompt_name: Optional[str]
+    truncate: Optional[bool]
+    truncation_direction: Optional[str]
diff --git a/tools/fastllm_pytools/server.py b/tools/fastllm_pytools/server.py
index 2c6ae3a4..7ab1e4f6 100644
--- a/tools/fastllm_pytools/server.py
+++ b/tools/fastllm_pytools/server.py
@@ -9,6 +9,7 @@
 from .openai_server.protocal.openai_protocol import *
 from .openai_server.fastllm_completion import FastLLmCompletion
+from .openai_server.fastllm_embed import FastLLmEmbed
 from .util import make_normal_parser
 from .util import make_normal_llm_model
 
@@ -30,6 +31,7 @@ def parse_args():
     )
 
 fastllm_completion:FastLLmCompletion
+fastllm_embed:FastLLmEmbed
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
@@ -47,6 +49,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
         assert isinstance(generator, ChatCompletionResponse)
         return JSONResponse(content = generator.model_dump())
 
+@app.post("/v1/embed")
+async def create_embed(request: EmbedRequest,
+                       raw_request: Request):
+    embedding = fastllm_embed.embedding_sentence(request, raw_request)
+    return JSONResponse(embedding)
+
 def init_logging(log_level = logging.INFO, log_file:str = None):
     logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
     root = logging.getLogger()
@@ -64,6 +72,6 @@ def init_logging(log_level = logging.INFO, log_file:str = None):
     logging.info(args)
     model = make_normal_llm_model(args)
     model.set_verbose(True)
-    fastllm_completion = FastLLmCompletion(model_name = args.model_name,
-                                           model = model)
+    fastllm_completion = FastLLmCompletion(model_name = args.model_name, model = model)
+    fastllm_embed = FastLLmEmbed(model_name = args.model_name, model = model)
     uvicorn.run(app, host = args.host, port = args.port)
diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp
index 4f1a8671..8bd133ad 100644
--- a/tools/src/pytools.cpp
+++ b/tools/src/pytools.cpp
@@ -425,4 +425,14 @@ extern "C" {
         auto model = models.GetModel(modelId);
         return model->max_positions;
     }
+
+    DLL_EXPORT float* embedding_sentence(int modelId, char *input, bool normalize, int *embeddingLen) {
+        fastllm::BertModel *model = (fastllm::BertModel*)models.GetModel(modelId);
+        std::string str(input);
+        std::vector <float> result = model->EmbeddingSentence(str, normalize);
+        float *fvalue = new float[result.size()];
+        memcpy(fvalue, result.data(), result.size() * sizeof(float));
+        *embeddingLen = result.size();
+        return fvalue;
+    }
};
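Usage sketch (not part of the patch): a minimal example of how the new embedding path could be exercised once this change is built, first through the embedding_sentence binding added to llm.py and then through the /v1/embed route added to server.py. The .flm model path, host and port below are placeholder assumptions, and llm.model(...) is the usual ftllm entry point for loading a converted model.

    # Hypothetical usage sketch; the .flm path, host and port are placeholders.
    from ftllm import llm
    import requests

    # Direct call through the ctypes binding added in llm.py.
    model = llm.model("/path/to/bert-embedding.flm")
    vec = model.embedding_sentence("hello world", normalize = True)
    print(len(vec), vec[:4])

    # Same text embedded through the new /v1/embed endpoint added in server.py
    # (assumes the server has been started separately with a BERT model loaded).
    resp = requests.post("http://127.0.0.1:8080/v1/embed",
                         json = {"inputs": "hello world", "normalize": True})
    print(resp.json()[:4])

Since create_embed returns JSONResponse(embedding), the endpoint responds with a plain JSON array of floats rather than an OpenAI-style embeddings object.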