Commit 04aa027

Add an embed Python API; add normalize support to the BERT model
jiewlmrh committed Sep 20, 2024
1 parent 4aafb94 commit 04aa027
Showing 8 changed files with 125 additions and 20 deletions.
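In short: the BERT path gains an optional L2 normalization step, and the Python bindings expose embeddings end-to-end, from a new ctypes entry point up to a `/v1/embed` server route. A rough usage sketch of the new Python API (the model path here is hypothetical; any BERT-style embedding model converted to the fastllm format should behave the same way):

```python
from ftllm import llm

# Hypothetical path; substitute any BERT-style embedding model
# converted to the fastllm format.
model = llm.model("bge-base-zh.flm")

# New in this commit: sentence embeddings with optional L2 normalization.
emb = model.embedding_sentence("hello world", normalize = True)
print(len(emb), emb[:4])
```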
27 changes: 19 additions & 8 deletions include/models/bert.h
@@ -6,7 +6,7 @@
#include "fastllm.h"

namespace fastllm {
-    class BertModel {
+    class BertModel: public basellm {
public:
BertModel() {};

@@ -17,23 +17,34 @@ namespace fastllm {
        void InitParams(); // Initialize parameter info

        // Inference
-        std::vector <std::vector <float> > Forward(
+        std::vector <std::vector <float> > ForwardAll(
            const Data &inputIds,
            const Data &attentionMask,
            const Data &tokenTypeIds,
-            const Data &positionIds);
+            const Data &positionIds,
+            bool normalize);

+        // Inference
+        virtual int Forward(
+            const Data &inputIds,
+            const Data &attentionMask,
+            const Data &positionIds,
+            std::vector <std::pair <Data, Data> > &pastKeyValues,
+            const GenerationConfig &generationConfig = GenerationConfig(),
+            const LastTokensManager &lastTokens = LastTokensManager(),
+            std::vector <float> *logits = nullptr);

-        std::vector <float> EmbeddingSentence(const std::string &context);
+        std::vector <float> EmbeddingSentence(const std::string &context, bool normalize);

-        std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts);
+        std::vector <std::vector <float> > EmbeddingSentenceBatch(const std::vector <std::string> &contexts, bool normalize);

        void LoadFromFile(const std::string &fileName); // Load from file

        void SaveLowBitModel(const std::string &fileName, int bit); // Save as a quantized model
+        void WarmUp(); // Warm up

        void SaveModel(const std::string &fileName); // Export directly
+        virtual std::string MakeInput(const std::string &history, int round, const std::string &input);

-        void WarmUp(); // Warm up
+        virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);

std::string model_type;

11 changes: 9 additions & 2 deletions src/model.cpp
@@ -213,8 +213,15 @@ namespace fastllm {
std::unique_ptr<fastllm::basellm> CreateLLMModelFromFile(const std::string &fileName) {
std::string modelType = GetModelTypeFromFile(fileName);
basellm *model = CreateModelWithType(modelType);
-        model->LoadFromFile(fileName);
-        model->WarmUp();
+        if (modelType == "bert") {
+            BertModel *bertModel = (BertModel*)model;
+            bertModel->weight.tokenizer.type = Tokenizer::BERT;
+            bertModel->LoadFromFile(fileName);
+            bertModel->WarmUp();
+        } else {
+            model->LoadFromFile(fileName);
+            model->WarmUp();
+        }
return std::unique_ptr<fastllm::basellm> (model);
}

46 changes: 39 additions & 7 deletions src/models/bert.cpp
@@ -31,11 +31,25 @@ namespace fastllm {
this->head_dim = embed_dim / num_attention_heads;
}

-    std::vector <std::vector <float> > BertModel::Forward(
+    void Normalize(float *data, int dataLen) {
+        float sum = 0.0;
+        for (int i = 0; i < dataLen; i++)
+            sum += data[i] * data[i];
+
+        if (sum < 1e-6) sum = 1e-6;
+        else sum = sqrt(sum);
+
+        for (int i = 0; i < dataLen; i++)
+            data[i] = data[i] / sum;
+    }
+
+    std::vector <std::vector <float> > BertModel::ForwardAll(
        const Data &inputIds,
        const Data &attentionMask,
        const Data &tokenTypeIds,
-        const Data &positionIds) {
+        const Data &positionIds,
+        bool normalize) {
// embedding
Data inputEmbeddings, tokenTypeEmbeddings, positionIdEmbeddings;
Embedding(inputIds, this->weight["embeddings.word_embeddings.weight"], inputEmbeddings);
@@ -114,19 +114,20 @@
std::vector <std::vector <float> > ret;
ret.resize(batch, std::vector <float> (outputDim, 0.0f));
for (int i = 0; i < batch; i++) {
+            if (normalize) Normalize(fret + i * outputDim, outputDim);
memcpy(ret[i].data(), fret + i * outputDim, outputDim * sizeof(float));
}

return ret;
}

-    std::vector <float> BertModel::EmbeddingSentence(const std::string &context) {
+    std::vector <float> BertModel::EmbeddingSentence(const std::string &context, bool normalize) {
std::vector <std::string> contexts;
contexts.push_back(context);
-        return EmbeddingSentenceBatch(contexts)[0];
+        return EmbeddingSentenceBatch(contexts, normalize)[0];
}

-    std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts) {
+    std::vector <std::vector <float> > BertModel::EmbeddingSentenceBatch(const std::vector <std::string> &contexts, bool normalize) {
int batch = contexts.size(), len = 0;
std::vector <std::vector <int> > tokens;
tokens.resize(batch);
@@ -158,12 +173,29 @@
fastllm::Data positionIds = fastllm::Data(fastllm::DataType::FLOAT32, {batch, len}, position_ids);

// printf("bs = %d, len = %d\n", batch, len); ClearProfiler(); Forward(inputIds, attentionMask, tokenTypeIds, positionIds); PrintProfiler();
-        return Forward(inputIds, attentionMask, tokenTypeIds, positionIds);
+        return ForwardAll(inputIds, attentionMask, tokenTypeIds, positionIds, normalize);
}

void BertModel::WarmUp() {
printf("Warmup...\n");
-        EmbeddingSentence({"1"});
+        EmbeddingSentence({"1"}, true);
printf("finish.\n");
}

+    int BertModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
+                           const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
+                           const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
+                           std::vector <float> *retLogits) {
+        return -1;
+    }
+
+    std::string BertModel::MakeInput(const std::string &history, int round, const std::string &input) {
+        return "";
+    }
+
+    std::string BertModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {
+        return "";
+    }
}
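The new `Normalize` helper rescales an embedding to unit length: it sums the squared components, then divides each component by the Euclidean norm; if the squared sum falls below 1e-6, the divisor is pinned to 1e-6 so an all-zero vector cannot trigger a division by zero. After this step the dot product of two embeddings equals their cosine similarity. A minimal Python sketch of the same logic:

```python
import math

def l2_normalize(vec):
    # Mirrors the C++ Normalize above: if the squared sum is below 1e-6
    # the divisor is pinned to 1e-6, otherwise it is sqrt(sum).
    s = sum(x * x for x in vec)
    s = 1e-6 if s < 1e-6 else math.sqrt(s)
    return [x / s for x in vec]

print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8] -- unit length
```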
12 changes: 12 additions & 0 deletions tools/fastllm_pytools/llm.py
@@ -93,6 +93,9 @@
fastllm_lib.get_max_input_len_llm_model.argtypes = [ctypes.c_int]
fastllm_lib.get_max_input_len_llm_model.restype = ctypes.c_int

+fastllm_lib.embedding_sentence.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_bool, ctypes.POINTER(ctypes.c_int)]
+fastllm_lib.embedding_sentence.restype = ctypes.POINTER(ctypes.c_float)

def softmax(a):
    max_value = a[0]
    for i in a:
@@ -1081,6 +1084,15 @@ def set_verbose(self, verbose: int):
    def get_max_input_len(self):
        return fastllm_lib.get_max_input_len_llm_model(self.model)

+    def embedding_sentence(self, input: str, normalize = True):
+        embedding_len = ctypes.c_int(0)
+        embedding_c_float = fastllm_lib.embedding_sentence(self.model, input.encode(), normalize, embedding_len)
+        embedding = []
+        for i in range(embedding_len.value):
+            embedding.append(embedding_c_float[i])
+            #print("{:.7f}".format(embedding[i]), end=" ")
+        return embedding

def GraphNode(name: str,
              type: str = "data",
              value = None):
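A note on the ctypes boundary: the argtypes declare the final parameter as `POINTER(ctypes.c_int)`, so passing the `ctypes.c_int` instance directly works; ctypes forwards its address, equivalent to `ctypes.byref(embedding_len)`. The element-by-element copy loop is correct but slow for long vectors. Below is a sketch of a bulk-copy variant; `embedding_sentence_np` is a hypothetical helper, numpy is assumed available, and `fastllm_lib` is the library handle loaded earlier in this file. Note the C++ side allocates the result with `new[]` and nothing frees it, so each call leaks the returned buffer:

```python
import ctypes
import numpy as np

def embedding_sentence_np(model_handle, text, normalize = True):
    # Hypothetical bulk-copy variant of embedding_sentence above.
    n = ctypes.c_int(0)
    buf = fastllm_lib.embedding_sentence(model_handle, text.encode(),
                                         normalize, ctypes.byref(n))
    # View the C buffer as a numpy array, then copy so the result does
    # not alias memory owned by the C++ side.
    return np.ctypeslib.as_array(buf, shape = (n.value,)).copy()
```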
18 changes: 18 additions & 0 deletions tools/fastllm_pytools/openai_server/fastllm_embed.py
@@ -0,0 +1,18 @@
+import asyncio
+import logging
+import json
+import traceback
+from fastapi import Request
+
+from .protocal.openai_protocol import *
+from ftllm import llm
+
+class FastLLmEmbed:
+    def __init__(self,
+                 model_name,
+                 model):
+        self.model_name = model_name
+        self.model = model
+
+    def embedding_sentence(self, request: EmbedRequest, raw_request: Request):
+        return self.model.embedding_sentence(request.inputs, request.normalize)
9 changes: 8 additions & 1 deletion tools/fastllm_pytools/openai_server/protocal/openai_protocol.py
@@ -199,4 +199,11 @@ class CompletionStreamResponse(BaseModel):
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
-    choices: List[CompletionResponseStreamChoice]
+    choices: List[CompletionResponseStreamChoice]
+
+class EmbedRequest(BaseModel):
+    inputs: str
+    normalize: Optional[bool]
+    prompt_name: Optional[str]
+    truncate: Optional[bool]
+    truncation_direction: Optional[str]
12 changes: 10 additions & 2 deletions tools/fastllm_pytools/server.py
@@ -9,6 +9,7 @@

from .openai_server.protocal.openai_protocol import *
from .openai_server.fastllm_completion import FastLLmCompletion
+from .openai_server.fastllm_embed import FastLLmEmbed
from .util import make_normal_parser
from .util import make_normal_llm_model

@@ -30,6 +31,7 @@ def parse_args():
)

fastllm_completion:FastLLmCompletion
+fastllm_embed:FastLLmEmbed

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
@@ -47,6 +49,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
    assert isinstance(generator, ChatCompletionResponse)
    return JSONResponse(content = generator.model_dump())

@app.post("/v1/embed")
async def create_embed(request: EmbedRequest,
raw_request: Request):
embedding = fastllm_embed.embedding_sentence(request, raw_request)
return JSONResponse(embedding)

def init_logging(log_level = logging.INFO, log_file:str = None):
    logging_format = '%(asctime)s %(process)d %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'
    root = logging.getLogger()
@@ -64,6 +72,6 @@ def init_logging(log_level = logging.INFO, log_file:str = None):
    logging.info(args)
    model = make_normal_llm_model(args)
    model.set_verbose(True)
-    fastllm_completion = FastLLmCompletion(model_name = args.model_name,
-                                           model = model)
+    fastllm_completion = FastLLmCompletion(model_name = args.model_name, model = model)
+    fastllm_embed = FastLLmEmbed(model_name = args.model_name, model = model)
    uvicorn.run(app, host = args.host, port = args.port)
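Once the server is up, the new route can be exercised with any HTTP client; the request body follows the `EmbedRequest` schema above, and the response is the bare embedding list. A sketch, assuming a local deployment (adjust the URL to the host and port the server was started with):

```python
import requests

resp = requests.post("http://127.0.0.1:8080/v1/embed",
                     json = {"inputs": "hello world", "normalize": True})
embedding = resp.json()  # a plain JSON list of floats
print(len(embedding), embedding[:4])
```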
10 changes: 10 additions & 0 deletions tools/src/pytools.cpp
@@ -425,4 +425,14 @@ extern "C" {
        auto model = models.GetModel(modelId);
        return model->max_positions;
    }

+    DLL_EXPORT float* embedding_sentence(int modelId, char *input, bool normalize, int *embeddingLen) {
+        fastllm::BertModel *model = (fastllm::BertModel*)models.GetModel(modelId);
+        std::string str(input);
+        std::vector <float> result = model->EmbeddingSentence(str, normalize);
+        float *fvalue = new float[result.size()];
+        memcpy(fvalue, result.data(), result.size() * sizeof(float));
+        *embeddingLen = result.size();
+        return fvalue;
+    }
};
